aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-07-26 15:33:05 +0800
committerCheng Lian <lian@databricks.com>2016-07-26 15:33:05 +0800
commit6959061f02b02afd4cef683b5eea0b7097eedee7 (patch)
tree65f9e5ea5ba0866e59d55c9d799423e78bcd1f1c /sql/catalyst/src/main
parent7b06a8948fc16d3c14e240fdd632b79ce1651008 (diff)
downloadspark-6959061f02b02afd4cef683b5eea0b7097eedee7.tar.gz
spark-6959061f02b02afd4cef683b5eea0b7097eedee7.tar.bz2
spark-6959061f02b02afd4cef683b5eea0b7097eedee7.zip
[SPARK-16706][SQL] support java map in encoder
## What changes were proposed in this pull request? Finish the TODO by creating a new expression, `ExternalMapToCatalyst`, that iterates over the map directly. ## How was this patch tested? New test in `JavaDatasetSuite`. Author: Wenchen Fan <wenchen@databricks.com> Closes #14344 from cloud-fan/java-map.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala34
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala158
3 files changed, 176 insertions, 28 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index b3a233ae39..e6f61b00eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -395,10 +395,14 @@ object JavaTypeInference {
toCatalystArray(inputObject, elementType(typeToken))
case _ if mapType.isAssignableFrom(typeToken) =>
- // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
- // not guarantee they have same iteration order(which is different from scala map).
- // A possible solution is creating a new `MapObjects` that can iterate a map directly.
- throw new UnsupportedOperationException("map type is not supported currently")
+ val (keyType, valueType) = mapKeyValueType(typeToken)
+ ExternalMapToCatalyst(
+ inputObject,
+ ObjectType(keyType.getRawType),
+ serializerFor(_, keyType),
+ ObjectType(valueType.getRawType),
+ serializerFor(_, valueType)
+ )
case other =>
val properties = getJavaBeanProperties(other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8affb033d8..76f87f64ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -472,29 +472,17 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val keys =
- Invoke(
- Invoke(inputObject, "keysIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = toCatalystArray(keys, keyType)
-
- val values =
- Invoke(
- Invoke(inputObject, "valuesIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = toCatalystArray(values, valueType)
-
- val Schema(keyDataType, _) = schemaFor(keyType)
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- NewInstance(
- classOf[ArrayBasedMapData],
- convertedKeys :: convertedValues :: Nil,
- dataType = MapType(keyDataType, valueDataType, valueNullable))
+ val keyClsName = getClassNameFromType(keyType)
+ val valueClsName = getClassNameFromType(valueType)
+ val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath
+ val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath
+
+ ExternalMapToCatalyst(
+ inputObject,
+ dataTypeFor(keyType),
+ serializerFor(_, keyType, keyPath),
+ dataTypeFor(valueType),
+ serializerFor(_, valueType, valuePath))
case t if t <:< localTypeOf[String] =>
StaticInvoke(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d6863ed2fd..06589411cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
/**
@@ -501,6 +501,162 @@ case class MapObjects private(
}
}
+object ExternalMapToCatalyst { // factory: names the loop variables, then builds the expression
+ private val curId = new java.util.concurrent.atomic.AtomicInteger() // source of per-call ids
+
+ def apply(
+ inputMap: Expression,
+ keyType: DataType,
+ keyConverter: Expression => Expression,
+ valueType: DataType,
+ valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+ val id = curId.getAndIncrement() // appended to every variable name generated below
+ val keyName = "ExternalMapToCatalyst_key" + id
+ val valueName = "ExternalMapToCatalyst_value" + id
+ val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id
+
+ ExternalMapToCatalyst(
+ keyName,
+ keyType,
+ keyConverter(LambdaVariable(keyName, "false", keyType)), // "false" where value uses valueIsNull
+ valueName,
+ valueIsNull,
+ valueType,
+ valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+ inputMap
+ )
+ }
+}
+
+/**
+ * Converts a Scala/Java map object into catalyst format by applying the key/value converters to
+ * every entry while iterating over the map; a key that converts to null raises an error.
+ *
+ * @param key the name of the map key variable that is declared while iterating over the map and
+ * used as input for the `keyConverter`
+ * @param keyType the data type of the map key variable that is declared while iterating over the
+ * map and used as input for the `keyConverter`
+ * @param keyConverter An expression that reads the `key` variable and converts it to catalyst format.
+ * @param value the name of the map value variable that is declared while iterating over the map
+ * and used as input for the `valueConverter`
+ * @param valueIsNull the name of the boolean variable that tells whether the current map value is
+ * null, declared while iterating over the map and used as input for the `valueConverter`
+ * @param valueType the data type of the map value variable that is declared while iterating over
+ * the map and used as input for the `valueConverter`
+ * @param valueConverter An expression that reads the `value` variable and converts it to catalyst
+ * format.
+ * @param child An expression that, when evaluated, returns the input map object.
+ */
+case class ExternalMapToCatalyst private(
+ key: String,
+ keyType: DataType,
+ keyConverter: Expression,
+ value: String,
+ valueIsNull: String,
+ valueType: DataType,
+ valueConverter: Expression,
+ child: Expression)
+ extends UnaryExpression with NonSQLExpression {
+
+ override def foldable: Boolean = false
+
+ override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType)
+
+ override def eval(input: InternalRow): Any = // no interpreted path: this always throws
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val inputMap = child.genCode(ctx)
+ val genKeyConverter = keyConverter.genCode(ctx)
+ val genValueConverter = valueConverter.genCode(ctx)
+ val length = ctx.freshName("length")
+ val index = ctx.freshName("index")
+ val convertedKeys = ctx.freshName("convertedKeys")
+ val convertedValues = ctx.freshName("convertedValues")
+ val entry = ctx.freshName("entry")
+ val entries = ctx.freshName("entries")
+
+ val (defineEntries, defineKeyValue) = child.dataType match { // NOTE(review): no default case
+ case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
+ val javaIteratorCls = classOf[java.util.Iterator[_]].getName
+ val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName
+
+ val defineEntries =
+ s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();"
+
+ val defineKeyValue =
+ s"""
+ final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
+ ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey();
+ ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+ """
+
+ defineEntries -> defineKeyValue
+
+ case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
+ val scalaIteratorCls = classOf[Iterator[_]].getName
+ val scalaMapEntryCls = classOf[Tuple2[_, _]].getName
+
+ val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();"
+
+ val defineKeyValue =
+ s"""
+ final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
+ ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1();
+ ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2();
+ """
+
+ defineEntries -> defineKeyValue
+ }
+
+ val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { // primitive: flag is always false
+ s"boolean $valueIsNull = false;"
+ } else {
+ s"boolean $valueIsNull = $value == null;"
+ }
+
+ val arrayCls = classOf[GenericArrayData].getName
+ val mapCls = classOf[ArrayBasedMapData].getName
+ val convertedKeyType = ctx.boxedType(keyConverter.dataType)
+ val convertedValueType = ctx.boxedType(valueConverter.dataType)
+ val code =
+ s"""
+ ${inputMap.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${inputMap.isNull}) {
+ final int $length = ${inputMap.value}.size();
+ final Object[] $convertedKeys = new Object[$length];
+ final Object[] $convertedValues = new Object[$length];
+ int $index = 0;
+ $defineEntries
+ while($entries.hasNext()) {
+ $defineKeyValue
+ $valueNullCheck
+
+ ${genKeyConverter.code}
+ if (${genKeyConverter.isNull}) {
+ throw new RuntimeException("Cannot use null as map key!");
+ } else {
+ $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value};
+ }
+
+ ${genValueConverter.code}
+ if (${genValueConverter.isNull}) {
+ $convertedValues[$index] = null;
+ } else {
+ $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value};
+ }
+
+ $index++;
+ }
+
+ ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues));
+ }
+ """
+ ev.copy(code = code, isNull = inputMap.isNull) // result is null exactly when the input map is
+ }
+}
+
/**
* Constructs a new external row, using the result of evaluating the specified expressions
* as content.