Scala DataSet with case class inheritance

ghz 8months ago ⋅ 111 views

I'd like to be able to store different related types in a Spark DataFrame but work with strongly typed case classes via a Dataset. For example, say I have a Base trait and two case classes A and B that extend it:

trait Base {
  def name: String
}

case class A(name: String, number: Int) extends Base

case class B(name: String, text: String) extends Base

I'd like to create a val lb = List[Base](A("Alice", 20), B("Bob", "Foo")) and then create a Dataset via lb.toDS(). Not surprisingly, this doesn't work, as there is no encoder for the trait that covers its different extending classes.
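For comparison, Spark can encode the trait if you fall back to a generic binary encoder. The sketch below assumes Spark 3.x on Scala 2.12/2.13; Encoders.kryo is part of Spark's public API, while the object name KryoFallback is hypothetical. It also shows why this is not what you want here: each object is serialized into one binary column named value, so none of its fields can be queried as columns.

```scala
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}

sealed trait Base { def name: String }
case class A(name: String, number: Int) extends Base
case class B(name: String, text: String) extends Base

object KryoFallback {
  // Builds a Dataset[Base] with a Kryo encoder: each row is one serialized object.
  def kryoDataset(spark: SparkSession, items: Seq[Base]): Dataset[Base] = {
    implicit val baseEncoder: Encoder[Base] = Encoders.kryo[Base]
    spark.createDataset(items)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("kryo-fallback").getOrCreate()
    val ds = kryoDataset(spark, List[Base](A("Alice", 20), B("Bob", "Foo")))
    ds.printSchema()               // a single binary column named "value"
    ds.collect().foreach(println)  // the objects themselves come back intact
    spark.stop()
  }
}
```

This compiles and round-trips the values, but the stored layout is opaque to Spark SQL, which is why a flat case class is usually the better storage format.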

I could manually create a case class representing a structure that can hold information for both A and B:

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None)

And I could add some functions to create a Struct from an instance of the Base trait and vice versa:

trait Base {
  def name: String

  def asStruct: Struct = {
    this match {
      case A(name, number) => Struct("A", name, number = Some(number))
      case B(name, text) => Struct("B", name, text = Some(text))
    }
  }
}

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None) {
  def asBase: Base = {
    this match {
      case Struct("A", name, Some(number), None) => A(name, number)
      case Struct("B", name, None, Some(text)) => B(name, text)
      case _ => throw new Exception(s"Invalid Base structure $this")
    }
  }
}
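Before involving Spark at all, it is worth checking that asStruct and asBase really are inverses. A minimal plain-Scala sketch, reusing the types above with the trait marked sealed; the object RoundTripCheck and its roundTrips helper are hypothetical names introduced for illustration.

```scala
sealed trait Base {
  def name: String

  def asStruct: Struct = this match {
    case A(name, number) => Struct("A", name, number = Some(number))
    case B(name, text)   => Struct("B", name, text = Some(text))
  }
}

case class A(name: String, number: Int) extends Base
case class B(name: String, text: String) extends Base

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None) {
  def asBase: Base = this match {
    case Struct("A", name, Some(number), None) => A(name, number)
    case Struct("B", name, None, Some(text))   => B(name, text)
    case _ => throw new IllegalArgumentException(s"Invalid Base structure $this")
  }
}

object RoundTripCheck {
  // True when converting to Struct and back yields the original value.
  def roundTrips(b: Base): Boolean = b.asStruct.asBase == b

  def main(args: Array[String]): Unit = {
    val samples = List[Base](A("Alice", 32), B("Bob", "foo"))
    samples.foreach(b => println(s"$b -> ${b.asStruct} -> roundTrips=${roundTrips(b)}"))
  }
}
```

Marking the trait sealed lets the compiler warn if a new subclass is added without a matching case in asStruct.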

Then I can create my DataFrame as follows:

    val a = A("Alice", 32)
    val b = B("Bob", "foo")

    val ls = List[Struct](a.asStruct, b.asStruct)

    val sparkSession = spark
    import sparkSession.implicits._

    val df = ls.toDS()

|typ| name|number|text|
|  A|Alice|    32|NULL|
|  B|  Bob|  NULL| foo|

I can work with this approach but I wondered if it is possible to write an encoder that automatically treats a Base class as a Struct using the asStruct method written above?


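Not quite in the way you describe. toDS() on a List[Base] looks for an implicit Encoder[Base], so an Encoder[Struct], custom or not, is never used for Base values. Spark's public Encoders API offers no simple way to build an Encoder[Base] that produces the Struct columns: Encoders.kryo[Base] compiles, but it stores each object as a single binary column. What does work is keeping Struct as the stored representation, giving it an ordinary product encoder, and routing every Base through asStruct where the Dataset is created.

Here's how you can define the encoder and the types: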

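import org.apache.spark.sql.{Encoder, Encoders}

// Encoder for the flat Struct representation. import spark.implicits._ would
// derive the same product encoder; declaring it here just makes it explicit.
implicit val structEncoder: Encoder[Struct] = Encoders.product[Struct]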

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None)

trait Base {
  def name: String
  def asStruct: Struct
}

case class A(name: String, number: Int) extends Base {
  def asStruct: Struct = Struct("A", name, number = Some(number))
}

case class B(name: String, text: String) extends Base {
  def asStruct: Struct = Struct("B", name, text = Some(text))
}

With this setup, map each Base to a Struct with asStruct before calling toDS(). Spark then encodes the resulting List[Struct] with the Struct encoder; calling toDS() directly on a List[Base] still fails to compile, because there is no Encoder[Base]. For example:

val a = A("Alice", 32)
val b = B("Bob", "foo")

val ls = List[Base](a, b)

val sparkSession = spark
import sparkSession.implicits._

val df = ls.map(_.asStruct).toDS()

Calling df.show() on the resulting Dataset[Struct] should print:

|typ| name|number|text|
|  A|Alice|    32|NULL|
|  B|  Bob|  NULL| foo|
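If you want the call site to look like a single call on a List[Base], a small extension method can hide the asStruct mapping. A minimal sketch, assuming an implicit SparkSession is in scope; BaseDatasetSyntax, BaseSeqOps and toStructDS are hypothetical names, not Spark API.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None)

sealed trait Base {
  def name: String
  def asStruct: Struct
}

case class A(name: String, number: Int) extends Base {
  def asStruct: Struct = Struct("A", name, number = Some(number))
}

case class B(name: String, text: String) extends Base {
  def asStruct: Struct = Struct("B", name, text = Some(text))
}

object BaseDatasetSyntax {
  // Hypothetical extension: adds toStructDS to any Seq[Base].
  implicit class BaseSeqOps(private val bases: Seq[Base]) extends AnyVal {
    def toStructDS(implicit spark: SparkSession): Dataset[Struct] = {
      import spark.implicits._ // supplies Encoder[Struct] for the case class
      spark.createDataset(bases.map(_.asStruct))
    }
  }
}

object ExtensionDemo {
  import BaseDatasetSyntax._

  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession =
      SparkSession.builder().master("local[1]").appName("to-struct-ds").getOrCreate()
    val ds = List[Base](A("Alice", 32), B("Bob", "foo")).toStructDS
    ds.show()
    spark.stop()
  }
}
```

The conversion is still explicit in the sense that asStruct runs on the driver before Spark sees the data; the extension only removes the boilerplate from each call site.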

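Spark never calls asStruct on its own: the conversion happens in the map step, before toDS() sees the data. The rest of your code can keep working with Base values and switch to Struct only where the Dataset is created.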
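Going the other way, you can still get strongly typed Datasets for the individual subclasses, because A and B are case classes and Spark derives encoders for them. A minimal sketch, assuming every row tagged "A" carries a number and every row tagged "B" carries a text (the invariant asStruct maintains); TypedViews, onlyA and onlyB are hypothetical names.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

case class Struct(typ: String, name: String, number: Option[Int] = None, text: Option[String] = None)

sealed trait Base {
  def name: String
  def asStruct: Struct
}

case class A(name: String, number: Int) extends Base {
  def asStruct: Struct = Struct("A", name, number = Some(number))
}

case class B(name: String, text: String) extends Base {
  def asStruct: Struct = Struct("B", name, text = Some(text))
}

object TypedViews {
  // Rows tagged "A" become a Dataset[A]; Encoder[A] comes from spark.implicits._.
  def onlyA(spark: SparkSession, ds: Dataset[Struct]): Dataset[A] = {
    import spark.implicits._
    ds.filter(_.typ == "A").map(s => A(s.name, s.number.get)) // "A" rows always have a number
  }

  // Same idea for rows tagged "B".
  def onlyB(spark: SparkSession, ds: Dataset[Struct]): Dataset[B] = {
    import spark.implicits._
    ds.filter(_.typ == "B").map(s => B(s.name, s.text.get)) // "B" rows always have a text
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("typed-views").getOrCreate()
    import spark.implicits._
    val structs = spark.createDataset(List[Base](A("Alice", 32), B("Bob", "foo")).map(_.asStruct))
    onlyA(spark, structs).show()
    onlyB(spark, structs).show()
    spark.stop()
  }
}
```

A Dataset[Base] that mixes both subclasses is not possible this way without a Base encoder, but per-subclass views cover most typed transformations.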