diff options
author | Kousuke Saruta <sarutak@oss.nttdata.co.jp> | 2015-12-18 14:05:06 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-12-18 14:05:06 -0800 |
commit | 6eba655259d2bcea27d0147b37d5d1e476e85422 (patch) | |
tree | f25a36ec77cb23da0ec848fa5b0b2f72cc8cb07d /sql/catalyst | |
parent | 41ee7c57abd9f52065fd7ffb71a8af229603371d (diff) | |
download | spark-6eba655259d2bcea27d0147b37d5d1e476e85422.tar.gz spark-6eba655259d2bcea27d0147b37d5d1e476e85422.tar.bz2 spark-6eba655259d2bcea27d0147b37d5d1e476e85422.zip |
[SPARK-12404][SQL] Ensure objects passed to StaticInvoke is Serializable
Now `StaticInvoke` receives `Any` as a object and `StaticInvoke` can be serialized but sometimes the object passed is not serializable.
For example, following code raises Exception because `RowEncoder#extractorsFor` invoked indirectly makes `StaticInvoke`.
```
case class TimestampContainer(timestamp: java.sql.Timestamp)
val rdd = sc.parallelize(1 to 2).map(_ => TimestampContainer(System.currentTimeMillis))
val df = rdd.toDF
val ds = df.as[TimestampContainer]
val rdd2 = ds.rdd <----------------- invokes extractorsFor indirectory
```
I'll add test cases.
Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Author: Michael Armbrust <michael@databricks.com>
Closes #10357 from sarutak/SPARK-12404.
Diffstat (limited to 'sql/catalyst')
4 files changed, 24 insertions, 26 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index c8ee87e881..f566d1b3ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -194,7 +194,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaDate", getPath :: Nil, @@ -202,7 +202,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", getPath :: Nil, @@ -276,7 +276,7 @@ object JavaTypeInference { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", keyData :: valueData :: Nil) @@ -341,21 +341,21 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ecff860570..c1b1d5cd2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -223,7 +223,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", getPath :: Nil, @@ -231,7 +231,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", getPath :: Nil, @@ -287,7 +287,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -315,7 +315,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) @@ -548,28 +548,28 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d34ec9408a..63bdf05ca7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -61,21 +61,21 @@ object RowEncoder { case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case _: DecimalType => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) @@ -172,14 +172,14 @@ object RowEncoder { case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", input :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", input :: Nil) @@ -197,7 +197,7 @@ object RowEncoder { "array", ObjectType(classOf[Array[_]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -210,7 +210,7 @@ object RowEncoder { val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 10ec75eca3..492cc9bf41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -42,16 +42,14 @@ import org.apache.spark.sql.types._ * of calling the function. */ case class StaticInvoke( - staticObject: Any, + staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, propagateNull: Boolean = true) extends Expression { - val objectName = staticObject match { - case c: Class[_] => c.getName - case other => other.getClass.getName.stripSuffix("$") - } + val objectName = staticObject.getName.stripSuffix("$") + override def nullable: Boolean = true override def children: Seq[Expression] = arguments |