 repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala                             | 14
 repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala                              | 24
 repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala                               |  8
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala |  4
 4 files changed, 43 insertions(+), 7 deletions(-)
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 4ee605fd7f..829b12269f 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -1221,10 +1221,16 @@ import org.apache.spark.annotation.DeveloperApi
         )
       }
 
-      val preamble = """
-        |class %s extends Serializable {
-        |  %s%s%s
-      """.stripMargin.format(lineRep.readName, envLines.map("  " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute))
+      val preamble = s"""
+        |class ${lineRep.readName} extends Serializable {
+        |  ${envLines.map("  " + _ + ";\n").mkString}
+        |  $importsPreamble
+        |
+        |  // If we need to construct any objects defined in the REPL on an executor we will need
+        |  // to pass the outer scope to the appropriate encoder.
+        |  org.apache.spark.sql.catalyst.encoders.OuterScopes.addOuterScope(this)
+        |  ${indentCode(toCompute)}
+      """.stripMargin
       val postamble = importsTrailer + "\n}" + "\n" +
         "object " + lineRep.readName + " {\n" +
         "  val INSTANCE = new " + lineRep.readName + "();\n" +
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 5674dcd669..081aa03002 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -262,6 +262,9 @@ class ReplSuite extends SparkFunSuite {
         |import sqlContext.implicits._
         |case class TestCaseClass(value: Int)
         |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+        |
+        |// Test Dataset Serialization in the REPL
+        |Seq(TestCaseClass(1)).toDS().collect()
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -278,6 +281,27 @@
     assertDoesNotContain("java.lang.ClassNotFoundException", output)
   }
 
+  test("Datasets and encoders") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
+        |  def zero: Int = 0                      // The initial value.
+        |  def reduce(b: Int, a: Int) = b + a     // Add an element to the running total.
+        |  def merge(b1: Int, b2: Int) = b1 + b2  // Merge intermediate values.
+        |  def finish(b: Int) = b                 // Return the final result.
+        |}.toColumn
+        |
+        |val ds = Seq(1, 2, 3, 4).toDS()
+        |ds.select(simpleSum).collect
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
   test("SPARK-2632 importing a method from non serializable class and not using it.") {
     val output = runInterpreter("local",
       """
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 3d2d235a00..a976e96809 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -65,7 +65,13 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
       case e: ClassNotFoundException => {
         val classOption = findClassLocally(name)
         classOption match {
-          case None => throw new ClassNotFoundException(name, e)
+          case None =>
+            // If this exception carries a cause, it will break an internal assumption of
+            // Janino (the compiler used for Spark SQL code generation): see
+            // org.codehaus.janino.ClassLoaderIClassLoader's findIClass, whose behavior
+            // changes when a ClassNotFoundException has a cause, making compilation of
+            // the generated class fail.
+            throw new ClassNotFoundException(name)
           case Some(a) => a
         }
       }
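The reason the cause must be dropped: Janino's ClassLoaderIClassLoader.findIClass distinguishes a bare ClassNotFoundException, which it treats as "class genuinely absent" and keeps resolving, from one that carries a cause, which it treats as a hard loading failure that aborts compilation of the generated class. An illustrative sketch of that contract (findClassSketch is a made-up helper, not Janino's actual code):

    def findClassSketch(loader: ClassLoader, name: String): Option[Class[_]] =
      try {
        Some(loader.loadClass(name))
      } catch {
        case e: ClassNotFoundException if e.getException == null =>
          None     // no cause: treated as "not found", resolution continues elsewhere
        case e: ClassNotFoundException =>
          throw e  // cause present: treated as a failure, codegen compilation aborts
      }
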
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 1b7260cdfe..2f3d6aeb86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, ArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types._
-
+import org.apache.spark.util.Utils
 
 /**
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
@@ -536,7 +536,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
    */
   private[this] def doCompile(code: String): GeneratedClass = {
     val evaluator = new ClassBodyEvaluator()
-    evaluator.setParentClassLoader(getClass.getClassLoader)
+    evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
     // Cannot be under package codegen, or fail with java.lang.InstantiationException
     evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass")
     evaluator.setDefaultImports(Array(
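Switching the parent class loader matters because, in the REPL, the thread's context class loader is the interpreter's class loader that can serve REPL-defined classes, whereas getClass.getClassLoader only sees Spark's own classes. A sketch of what Utils.getContextOrSparkClassLoader resolves to (the Sketch suffix marks it as illustrative, not the helper's actual source):

    // Prefer the current thread's context class loader (the REPL class loader
    // in spark-shell), falling back to the loader that loaded Spark itself.
    def getContextOrSparkClassLoaderSketch: ClassLoader =
      Option(Thread.currentThread().getContextClassLoader)
        .getOrElse(getClass.getClassLoader)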