path: root/sql/catalyst
author     Wenchen Fan <wenchen@databricks.com>        2015-12-01 10:35:12 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-12-01 10:35:12 -0800
commit     fd95eeaf491809c6bb0f83d46b37b5e2eebbcbca (patch)
tree       61ce648ac4d3f3a8c1e3065bad8c6832513fcbe4  /sql/catalyst
parent     9df24624afedd993a39ab46c8211ae153aedef1a (diff)
[SPARK-11954][SQL] Encoder for JavaBeans
Create a Java version of `constructorFor` and `extractorFor` in `JavaTypeInference`.

Author: Wenchen Fan <wenchen@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #9937 from cloud-fan/pojo.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala                                  18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala              313
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala      21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala             42
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala                  27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala           5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala            3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala             25

8 files changed, 438 insertions(+), 16 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 03aa25eda8..c40061ae0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -98,6 +98,24 @@ object Encoders {
def STRING: Encoder[java.lang.String] = ExpressionEncoder()
/**
+ * Creates an encoder for a Java Bean of type T.
+ *
+ * T must be publicly accessible.
+ *
+ * Supported types for Java Bean fields:
+ * - primitive types: boolean, int, double, etc.
+ * - boxed types: Boolean, Integer, Double, etc.
+ * - String
+ * - java.math.BigDecimal
+ * - time related: java.sql.Date, java.sql.Timestamp
+ * - collection types: only array and java.util.List currently, map support is in progress
+ * - nested Java Beans.
+ *
+ * @since 1.6.0
+ */
+ def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
+
+ /**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
* This encoder maps T into a single byte array (binary) field.
*
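For context, here is a minimal usage sketch of the new `Encoders.bean` entry point. The `Person` class, the sketch object, and its field names are illustrative assumptions, not part of this patch; any publicly accessible class that exposes supported field types through getter/setter pairs works the same way.

import org.apache.spark.sql.{Encoder, Encoders}

// Hypothetical Java Bean: a public class whose fields are exposed through
// getter/setter pairs of supported types (boxed Integer, String).
class Person {
  private var name: String = _
  private var age: java.lang.Integer = _
  def getName: String = name
  def setName(v: String): Unit = { name = v }
  def getAge: java.lang.Integer = age
  def setAge(v: java.lang.Integer): Unit = { age = v }
}

object BeanEncoderSketch {
  def main(args: Array[String]): Unit = {
    // The new factory method builds an expression-based encoder from the bean class.
    val personEncoder: Encoder[Person] = Encoders.bean(classOf[Person])
    println(personEncoder.schema)   // a StructType covering the `age` and `name` properties
  }
}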
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7d4cfbe6fa..c8ee87e881 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -17,14 +17,20 @@
package org.apache.spark.sql.catalyst
-import java.beans.Introspector
+import java.beans.{PropertyDescriptor, Introspector}
import java.lang.{Iterable => JIterable}
-import java.util.{Iterator => JIterator, Map => JMap}
+import java.util.{Iterator => JIterator, Map => JMap, List => JList}
import scala.language.existentials
import com.google.common.reflect.TypeToken
+
import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
+import org.apache.spark.unsafe.types.UTF8String
+
/**
* Type-inference utilities for POJOs and Java collections.
@@ -33,13 +39,14 @@ object JavaTypeInference {
private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
+ private val listType = TypeToken.of(classOf[JList[_]])
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
/**
- * Infers the corresponding SQL data type of a JavaClean class.
+ * Infers the corresponding SQL data type of a JavaBean class.
* @param beanClass Java type
* @return (SQL data type, nullable)
*/
@@ -58,6 +65,8 @@ object JavaTypeInference {
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+ case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true)
+
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
@@ -87,15 +96,14 @@ object JavaTypeInference {
(ArrayType(dataType, nullable), true)
case _ if mapType.isAssignableFrom(typeToken) =>
- val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
- val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
- val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
- val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
+ val (keyType, valueType) = mapKeyValueType(typeToken)
val (keyDataType, _) = inferDataType(keyType)
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)
case _ =>
+ // TODO: we should only collect properties that have both a getter and a setter. However, some tests
+ // pass in Scala case classes as Java Bean classes, which don't have getters and setters.
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
@@ -107,11 +115,294 @@ object JavaTypeInference {
}
}
+ private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+ val beanInfo = Introspector.getBeanInfo(beanClass)
+ beanInfo.getPropertyDescriptors
+ .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
+ }
+
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
- val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
- val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
- val itemType = iteratorType.resolveType(nextReturnType)
- itemType
+ val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]])
+ val iteratorType = iterableSuperType.resolveType(iteratorReturnType)
+ iteratorType.resolveType(nextReturnType)
+ }
+
+ private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = {
+ val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
+ val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]])
+ val keyType = elementType(mapSuperType.resolveType(keySetReturnType))
+ val valueType = elementType(mapSuperType.resolveType(valuesReturnType))
+ keyType -> valueType
+ }
+
+ /**
+ * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping
+ * to a native type, an ObjectType is returned.
+ *
+ * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type
+ * system. As a result, ObjectType will be returned for things like boxed Integers.
+ */
+ private def inferExternalType(cls: Class[_]): DataType = cls match {
+ case c if c == java.lang.Boolean.TYPE => BooleanType
+ case c if c == java.lang.Byte.TYPE => ByteType
+ case c if c == java.lang.Short.TYPE => ShortType
+ case c if c == java.lang.Integer.TYPE => IntegerType
+ case c if c == java.lang.Long.TYPE => LongType
+ case c if c == java.lang.Float.TYPE => FloatType
+ case c if c == java.lang.Double.TYPE => DoubleType
+ case c if c == classOf[Array[Byte]] => BinaryType
+ case _ => ObjectType(cls)
+ }
+
+ /**
+ * Returns an expression that can be used to construct an object of java bean `T` given an input
+ * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * of the same name as the constructor arguments. Nested classes will have their fields accessed
+ * using UnresolvedExtractValue.
+ */
+ def constructorFor(beanClass: Class[_]): Expression = {
+ constructorFor(TypeToken.of(beanClass), None)
+ }
+
+ private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+ /** Returns the current path with a sub-field extracted. */
+ def addToPath(part: String): Expression = path
+ .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+ .getOrElse(UnresolvedAttribute(part))
+
+ /** Returns the current path or `BoundReference`. */
+ def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
+
+ typeToken.getRawType match {
+ case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
+
+ case c if c == classOf[java.lang.Short] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ case c if c == classOf[java.lang.Integer] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ case c if c == classOf[java.lang.Long] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ case c if c == classOf[java.lang.Double] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ case c if c == classOf[java.lang.Byte] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ case c if c == classOf[java.lang.Float] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+ case c if c == classOf[java.lang.Boolean] =>
+ NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
+
+ case c if c == classOf[java.sql.Date] =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(c),
+ "toJavaDate",
+ getPath :: Nil,
+ propagateNull = true)
+
+ case c if c == classOf[java.sql.Timestamp] =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(c),
+ "toJavaTimestamp",
+ getPath :: Nil,
+ propagateNull = true)
+
+ case c if c == classOf[java.lang.String] =>
+ Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+ case c if c == classOf[java.math.BigDecimal] =>
+ Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+ case c if c.isArray =>
+ val elementType = c.getComponentType
+ val primitiveMethod = elementType match {
+ case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray")
+ case c if c == java.lang.Byte.TYPE => Some("toByteArray")
+ case c if c == java.lang.Short.TYPE => Some("toShortArray")
+ case c if c == java.lang.Integer.TYPE => Some("toIntArray")
+ case c if c == java.lang.Long.TYPE => Some("toLongArray")
+ case c if c == java.lang.Float.TYPE => Some("toFloatArray")
+ case c if c == java.lang.Double.TYPE => Some("toDoubleArray")
+ case _ => None
+ }
+
+ primitiveMethod.map { method =>
+ Invoke(getPath, method, ObjectType(c))
+ }.getOrElse {
+ Invoke(
+ MapObjects(
+ p => constructorFor(typeToken.getComponentType, Some(p)),
+ getPath,
+ inferDataType(elementType)._1),
+ "array",
+ ObjectType(c))
+ }
+
+ case c if listType.isAssignableFrom(typeToken) =>
+ val et = elementType(typeToken)
+ val array =
+ Invoke(
+ MapObjects(
+ p => constructorFor(et, Some(p)),
+ getPath,
+ inferDataType(et)._1),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
+
+ case _ if mapType.isAssignableFrom(typeToken) =>
+ val (keyType, valueType) = mapKeyValueType(typeToken)
+ val keyDataType = inferDataType(keyType)._1
+ val valueDataType = inferDataType(valueType)._1
+
+ val keyData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(keyType, Some(p)),
+ Invoke(getPath, "keyArray", ArrayType(keyDataType)),
+ keyDataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ val valueData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(valueType, Some(p)),
+ Invoke(getPath, "valueArray", ArrayType(valueDataType)),
+ valueDataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ ArrayBasedMapData,
+ ObjectType(classOf[JMap[_, _]]),
+ "toJavaMap",
+ keyData :: valueData :: Nil)
+
+ case other =>
+ val properties = getJavaBeanProperties(other)
+ assert(properties.length > 0)
+
+ val setters = properties.map { p =>
+ val fieldName = p.getName
+ val fieldType = typeToken.method(p.getReadMethod).getReturnType
+ p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName)))
+ }.toMap
+
+ val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
+ val result = InitializeJavaBean(newInstance, setters)
+
+ if (path.nonEmpty) {
+ expressions.If(
+ IsNull(getPath),
+ expressions.Literal.create(null, ObjectType(other)),
+ result
+ )
+ } else {
+ result
+ }
+ }
+ }
+
+ /**
+ * Returns expressions for extracting all the fields from the given type.
+ */
+ def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
+ val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
+ extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+ }
+
+ private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+
+ def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
+ val (dataType, nullable) = inferDataType(elementType)
+ if (ScalaReflection.isNativeType(dataType)) {
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(dataType, nullable))
+ } else {
+ MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
+ }
+ }
+
+ if (!inputObject.dataType.isInstanceOf[ObjectType]) {
+ inputObject
+ } else {
+ typeToken.getRawType match {
+ case c if c == classOf[String] =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil)
+
+ case c if c == classOf[java.sql.Timestamp] =>
+ StaticInvoke(
+ DateTimeUtils,
+ TimestampType,
+ "fromJavaTimestamp",
+ inputObject :: Nil)
+
+ case c if c == classOf[java.sql.Date] =>
+ StaticInvoke(
+ DateTimeUtils,
+ DateType,
+ "fromJavaDate",
+ inputObject :: Nil)
+
+ case c if c == classOf[java.math.BigDecimal] =>
+ StaticInvoke(
+ Decimal,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil)
+
+ case c if c == classOf[java.lang.Boolean] =>
+ Invoke(inputObject, "booleanValue", BooleanType)
+ case c if c == classOf[java.lang.Byte] =>
+ Invoke(inputObject, "byteValue", ByteType)
+ case c if c == classOf[java.lang.Short] =>
+ Invoke(inputObject, "shortValue", ShortType)
+ case c if c == classOf[java.lang.Integer] =>
+ Invoke(inputObject, "intValue", IntegerType)
+ case c if c == classOf[java.lang.Long] =>
+ Invoke(inputObject, "longValue", LongType)
+ case c if c == classOf[java.lang.Float] =>
+ Invoke(inputObject, "floatValue", FloatType)
+ case c if c == classOf[java.lang.Double] =>
+ Invoke(inputObject, "doubleValue", DoubleType)
+
+ case _ if typeToken.isArray =>
+ toCatalystArray(inputObject, typeToken.getComponentType)
+
+ case _ if listType.isAssignableFrom(typeToken) =>
+ toCatalystArray(inputObject, elementType(typeToken))
+
+ case _ if mapType.isAssignableFrom(typeToken) =>
+ // TODO: for a Java map, if we get the keys and values via `keySet` and `values`, we cannot
+ // guarantee they have the same iteration order (unlike a Scala map).
+ // A possible solution is creating a new `MapObjects` that can iterate a map directly.
+ throw new UnsupportedOperationException("map type is not supported currently")
+
+ case other =>
+ val properties = getJavaBeanProperties(other)
+ if (properties.length > 0) {
+ CreateNamedStruct(properties.flatMap { p =>
+ val fieldName = p.getName
+ val fieldType = typeToken.method(p.getReadMethod).getReturnType
+ val fieldValue = Invoke(
+ inputObject,
+ p.getReadMethod.getName,
+ inferExternalType(fieldType.getRawType))
+ expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+ })
+ } else {
+ throw new UnsupportedOperationException(s"no encoder found for ${other.getName}")
+ }
+ }
+ }
}
}
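To make the inference side concrete, here is a small sketch of what `inferDataType` yields for a simple bean. The `Point` class and the sketch object are assumptions for illustration only, and the snippet is placed in the `org.apache.spark.sql.catalyst` package on the assumption that `JavaTypeInference` is only visible inside the sql packages.

package org.apache.spark.sql.catalyst

// Hypothetical bean with one primitive and one reference property.
class Point {
  private var x: Int = 0
  private var label: String = _
  def getX: Int = x
  def setX(v: Int): Unit = { x = v }
  def getLabel: String = label
  def setLabel(v: String): Unit = { label = v }
}

object JavaTypeInferenceSketch {
  def main(args: Array[String]): Unit = {
    // Per the rules above: the primitive `int` property maps to a non-nullable
    // IntegerType, the String property to a nullable StringType, and the bean
    // itself to a nullable StructType.
    val (dataType, nullable) = JavaTypeInference.inferDataType(classOf[Point])
    println(s"$dataType, nullable = $nullable")
    // Roughly: StructType(StructField(label,StringType,true),
    //                     StructField(x,IntegerType,false)), nullable = true
  }
}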
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 06ffe86455..3e8420ecb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -29,8 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection}
import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
/**
@@ -68,6 +67,22 @@ object ExpressionEncoder {
ClassTag[T](cls))
}
+ // TODO: improve error message for java bean encoder.
+ def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
+ val schema = JavaTypeInference.inferDataType(beanClass)._1
+ assert(schema.isInstanceOf[StructType])
+
+ val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
+ val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
+
+ new ExpressionEncoder[T](
+ schema.asInstanceOf[StructType],
+ flat = false,
+ toRowExpression.flatten,
+ fromRowExpression,
+ ClassTag[T](beanClass))
+ }
+
/**
* Given a set of N encoders, constructs a new encoder that produce objects as items in an
* N-tuple. Note that these encoders should be unresolved so that information about
@@ -216,7 +231,7 @@ case class ExpressionEncoder[T](
*/
def assertUnresolved(): Unit = {
(fromRowExpression +: toRowExpressions).foreach(_.foreach {
- case a: AttributeReference =>
+ case a: AttributeReference if a.name != "loopVar" =>
sys.error(s"Unresolved encoder expected, but $a was found.")
case _ =>
})
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 62d09f0f55..e6ab9a31be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
* is expressed using catalyst expressions.
*
- * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData
+ * The following collection ObjectTypes are currently supported:
+ * Seq, Array, ArrayData, java.util.List
*
* @param function A function that returns an expression, given an attribute that can be used
* to access the current value. This is done as a lambda function so that
@@ -386,6 +387,8 @@ case class MapObjects(
(".size()", (i: String) => s".apply($i)", false)
case ObjectType(cls) if cls.isArray =>
(".length", (i: String) => s"[$i]", false)
+ case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ (".size()", (i: String) => s".get($i)", false)
case ArrayType(t, _) =>
val (sqlType, primitiveElement) = t match {
case m: MapType => (m, false)
@@ -596,3 +599,40 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
override def dataType: DataType = ObjectType(tag.runtimeClass)
}
+
+/**
+ * Initialize a Java Bean instance by setting its field values via setters.
+ */
+case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression])
+ extends Expression {
+
+ override def nullable: Boolean = beanInstance.nullable
+ override def children: Seq[Expression] = beanInstance +: setters.values.toSeq
+ override def dataType: DataType = beanInstance.dataType
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val instanceGen = beanInstance.gen(ctx)
+
+ val initialize = setters.map {
+ case (setterMethod, fieldValue) =>
+ val fieldGen = fieldValue.gen(ctx)
+ s"""
+ ${fieldGen.code}
+ ${instanceGen.value}.$setterMethod(${fieldGen.value});
+ """
+ }
+
+ ev.isNull = instanceGen.isNull
+ ev.value = instanceGen.value
+
+ s"""
+ ${instanceGen.code}
+ if (!${instanceGen.isNull}) {
+ ${initialize.mkString("\n")}
+ }
+ """
+ }
+}
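A hand-written analogue may help when reading the generated code above: `InitializeJavaBean` evaluates the bean-instance expression and, if it is non-null, invokes each setter with the value of its field expression. The `PointBean` class and the helper below are illustrative assumptions, not the actual codegen output.

// Minimal hypothetical bean used only for illustration.
class PointBean {
  private var x: Int = 0
  def getX: Int = x
  def setX(v: Int): Unit = { x = v }
}

object InitializeJavaBeanSketch {
  // What the generated code does, written by hand for a single setter.
  def initPoint(xValue: Int): PointBean = {
    val bean = new PointBean()   // beanInstance, e.g. NewInstance(PointBean, ...)
    if (bean != null) {          // mirrors the `if (!instanceGen.isNull)` guard
      bean.setX(xValue)          // setters("setX") applied to its field value
    }
    bean
  }
}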
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 35f087bacc..f1cea07976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.trees
+import scala.collection.Map
+
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.types.{StructType, DataType}
@@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case nonChild: AnyRef => nonChild
case null => null
}
+ case m: Map[_, _] => m.mapValues {
+ case arg: TreeNode[_] if containsChild(arg) =>
+ val newChild = remainingNewChildren.remove(0)
+ val oldChild = remainingOldChildren.remove(0)
+ if (newChild fastEquals oldChild) {
+ oldChild
+ } else {
+ changed = true
+ newChild
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
+ }.view.force // `mapValues` is lazy and we need to force it to materialize
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = remainingNewChildren.remove(0)
val oldChild = remainingOldChildren.remove(0)
@@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
} else {
Some(arg)
}
- case m: Map[_, _] => m
+ case m: Map[_, _] => m.mapValues {
+ case arg: TreeNode[_] if containsChild(arg) =>
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case other => other
+ }.view.force // `mapValues` is lazy and we need to force it to materialize
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
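The `.view.force` calls above are needed because `mapValues` is lazy in the Scala 2.10/2.11 collections Spark uses here: it returns a view that re-applies the function on every access, so the side effects inside these closures (consuming the new-children buffers, setting the `changed` flag) would run lazily and repeatedly rather than exactly once. A REPL-style sketch of the behaviour being guarded against (variable names are illustrative):

var evaluations = 0
val base = Map("a" -> 1, "b" -> 2)
val lazyView = base.mapValues { v => evaluations += 1; v + 1 }   // nothing evaluated yet
lazyView("a")
lazyView("a")
println(evaluations)   // 2: the function re-runs on every lookup
val materialized = base.mapValues { v => v + 1 }.view.force      // as in the patch: force the
                                                                 // view so it is evaluated once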
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
index 70b028d2b3..d85b72ed83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
@@ -70,4 +70,9 @@ object ArrayBasedMapData {
def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
keys.zip(values).toMap
}
+
+ def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = {
+ import scala.collection.JavaConverters._
+ keys.zip(values).toMap.asJava
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 96588bb5dc..2b8cdc1e23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.util
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, Decimal}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class GenericArrayData(val array: Array[Any]) extends ArrayData {
def this(seq: Seq[Any]) = this(seq.toArray)
+ def this(list: java.util.List[Any]) = this(list.asScala)
// TODO: This is boxing. We should specialize.
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 8fff39906b..965bdb1515 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil
}
+case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
+ override def children: Seq[Expression] = map.values.toSeq
+ override def nullable: Boolean = true
+ override def dataType: NullType = NullType
+ override lazy val resolved = true
+}
+
class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite {
val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
assert(expected === actual)
}
+
+ test("expressions inside a map") {
+ val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2)))
+
+ {
+ val actual = expression.transform {
+ case Literal(i: Int, _) => Literal(i + 1)
+ }
+ val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
+ assert(actual === expected)
+ }
+
+ {
+ val actual = expression.withNewChildren(Seq(Literal(2), Literal(3)))
+ val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3)))
+ assert(actual === expected)
+ }
+ }
}