
I have a Spark application. My use case is to allow users to define an arbitrary function of type Record => Record as a 'rule', which is then applied to each record of an RDD/Dataset.

Following is the code:


    //Sample rows with Id, Name, DOB and address
    val row1 = "19283,Alan,1989-01-20,445 Mount Eden Road Mount Eden Auckland"
    val row2 = "15689,Ben,1989-01-20,445 Mount Eden Road Mount Eden Auckland"

    val record1 = new Record(
      new RecordMetadata(),
      row1,
      true
    )
    val record2 = new Record(
      new RecordMetadata(),
      row2,
      true
    )

    val inputRecsList = List(record1, record2)
    val inputRecs = spark.sparkContext.parallelize(inputRecsList)

    val rule = ScalaExpression(
      //Sample rule: a lambda (Record => Record) supplied as a string
      """
        | import model.Record
        | { record: Record => record }
      """.stripMargin
    )

    val outputRecs = inputRecs.map(rule.transformation)

Following are the definitions of the 'Record', 'RecordMetadata' and 'ScalaExpression' classes:

case class Record(
                   val metadata: RecordMetadata,
                   val row: String,
                   val isValidRecord: Boolean = true
                 ) extends Serializable

case class RecordMetadata() extends Serializable

case class ScalaExpression(function: Function1[Record, Record]) extends Rule {

  def transformation = function
}

object ScalaExpression {

  import scala.reflect.runtime.currentMirror
  import scala.tools.reflect.ToolBox

  /**
    * @param string Scala expression as a string
    * @return Evaluated result of type Record => Record (i.e. Function1[Record, Record])
    */
  def apply(string: String): ScalaExpression = {
    val toolbox = currentMirror.mkToolBox()
    val tree = toolbox.parse(string)
    val fn = toolbox.eval(tree).asInstanceOf[Record => Record]
    new ScalaExpression(fn)
  }
}

The code above, throws a cryptic exception:

java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
    at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2287)
    at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1417)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2293)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2211)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2069)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1573)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2287)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2211)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2069)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1573)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:431)
    at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
    at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:80)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

The code, however, works well if the rule is defined directly in code: val rule = ScalaExpression({ record: Record => record })

The code also works well when the map (with the runtime-evaluated rule) is applied to a List instead of an RDD/Dataset.
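
For reference, a minimal sketch of that working case, reusing inputRecsList and rule from the snippet above:

    // Plain Scala collection: no Spark serialization involved, everything runs
    // inside the driver JVM, so the toolbox-generated class is always visible
    val outputList = inputRecsList.map(rule.transformation)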

I have been stuck for a while trying to make this work. Any help would be appreciated.

EDIT: The 'possible duplicate' flagged on this question solves a completely different problem. My use case fetches a rule (a valid Scala expression that converts one record into another) from a user at runtime, and that causes serialization issues when the rule is applied to each record of a dataset.

Best Regards.

Ankit Khettry
  • Would it work if you do a mapPartitions instead of a map? The rule would be applied to each element of an iterator instead of an RDD. – Nonontb Jan 18 '19 at 10:07
  • Possible duplicate of [How to fix java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List to field type scala.collection.Seq?](https://stackoverflow.com/questions/39953245/how-to-fix-java-lang-classcastexception-cannot-assign-instance-of-scala-collect) – 10465355 Jan 18 '19 at 10:58
  • @Nonontb mapPartitions indeed solved my problem. Creating an instance of the rule separately in every partition works like a breeze (see the sketch after these comments). Thanks for the help :) – Ankit Khettry Jan 21 '19 at 08:36
  • @AnkitKhettry It would be nice to answer your own question with the final code solution. It may help other people find your question and the solution – Nonontb Jan 21 '19 at 08:54
  • Sure, will do. I thought it would be fair to let you answer the question and accept your answer, since it's your suggestion that helped me resolve the issue :) @Nonontb – Ankit Khettry Jan 23 '19 at 12:14
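
Following the mapPartitions suggestion above, here is a minimal sketch of that approach, assuming the Record and ScalaExpression definitions above and that scala-compiler/scala-reflect are available on the executor classpath. The rule source string (rather than the compiled function) is shipped to the executors and compiled once per partition, so the toolbox-generated class never has to be deserialized there:

    // Sketch: compile the user-supplied rule on the executors instead of the driver
    val ruleString =
      """
        | import model.Record
        | { record: Record => record }
      """.stripMargin

    val outputRecs = inputRecs.mapPartitions { records =>
      // Compiled once per partition, inside the executor JVM
      val rule = ScalaExpression(ruleString)
      records.map(rule.transformation)
    }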

1 Answer


There is an open issue in the Spark JIRA to fix this: SPARK-20525. The root cause is a mismatch between Spark's classloader and the interpreter's classloader when the runtime-compiled UDF is loaded.

The workaround is to create your Spark session after your interpreter. Please find the example code below. You can also refer to my GitHub project SparkCustomTransformations for a working example.

import org.apache.spark.sql.{DataFrame, SparkSession}

trait CustomTransformations extends Serializable {
  def execute(spark: SparkSession, df: DataFrame, udfFunctions: AnyRef*): DataFrame
}

// IMPORTANT: the Spark session must be lazily evaluated so that it is
// created *after* the interpreter below
lazy val spark = getSparkSession // getSparkSession: helper (not shown) that builds the SparkSession

def getInterpreter: scala.tools.nsc.interpreter.IMain = {

  import scala.tools.nsc.GenericRunnerSettings
  import scala.tools.nsc.interpreter.IMain

  val settings = new GenericRunnerSettings(println _)
  settings.usejavacp.value = true

  val intp = new IMain(settings, new java.io.PrintWriter(System.out))
  intp.setContextClassLoader
  intp.initializeSynchronous

  intp
}

val intp = getInterpreter

val udf_str =
  """
    (str:String)=>{
      str.toLowerCase
    }
    """
val customTransStr =
  """
    |import org.apache.spark.SparkConf
    |import org.apache.spark.sql.{DataFrame, SparkSession}
    |import org.apache.spark.sql.functions._
    |
    |new CustomTransformations {
    |    override def execute(spark: SparkSession, df: DataFrame, func: AnyRef*): DataFrame = {
    |
    |      //reading your UDF
    |      val str_lower_udf = spark.udf.register("str_lower", func(0).asInstanceOf[Function1[String,String]])
    |
    |      df.createOrReplaceTempView("df")
    |      val df_with_UDF_cols = spark.sql("select a.*, str_lower(a.fakeEventTag) as customUDFCol1 from df a").withColumn("customUDFCol2", str_lower_udf(col("fakeEventTag")))
    |
    |      df_with_UDF_cols.show()
    |      df_with_UDF_cols
    |    }
    |}
  """.stripMargin

intp.interpret(udf_str)
val udf_obj = intp.eval(udf_str)

// com.twitter.util.Eval comes from Twitter's util-eval library
val eval = new com.twitter.util.Eval
val customTransform: CustomTransformations = eval[CustomTransformations](customTransStr)


val sampleSparkDF = getSampleSparkDF // helper (not shown) that supplies a sample DataFrame
val outputDF = customTransform.execute(spark, sampleSparkDF, udf_obj)

outputDF.printSchema()
outputDF.show()
pavan