From 6959061f02b02afd4cef683b5eea0b7097eedee7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Jul 2016 15:33:05 +0800 Subject: [SPARK-16706][SQL] support java map in encoder ## What changes were proposed in this pull request? finish the TODO, create a new expression `ExternalMapToCatalyst` to iterate the map directly. ## How was this patch tested? new test in `JavaDatasetSuite` Author: Wenchen Fan Closes #14344 from cloud-fan/java-map. --- .../spark/sql/catalyst/JavaTypeInference.scala | 12 +- .../spark/sql/catalyst/ScalaReflection.scala | 34 ++--- .../sql/catalyst/expressions/objects/objects.scala | 158 ++++++++++++++++++++- 3 files changed, 176 insertions(+), 28 deletions(-) (limited to 'sql/catalyst/src/main') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index b3a233ae39..e6f61b00eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -395,10 +395,14 @@ object JavaTypeInference { toCatalystArray(inputObject, elementType(typeToken)) case _ if mapType.isAssignableFrom(typeToken) => - // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can - // not guarantee they have same iteration order(which is different from scala map). - // A possible solution is creating a new `MapObjects` that can iterate a map directly. - throw new UnsupportedOperationException("map type is not supported currently") + val (keyType, valueType) = mapKeyValueType(typeToken) + ExternalMapToCatalyst( + inputObject, + ObjectType(keyType.getRawType), + serializerFor(_, keyType), + ObjectType(valueType.getRawType), + serializerFor(_, valueType) + ) case other => val properties = getJavaBeanProperties(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8affb033d8..76f87f64ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -472,29 +472,17 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Map[_, _]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) + val keyClsName = getClassNameFromType(keyType) + val valueClsName = getClassNameFromType(valueType) + val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath + val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath + + ExternalMapToCatalyst( + inputObject, + dataTypeFor(keyType), + serializerFor(_, keyType, keyPath), + dataTypeFor(valueType), + serializerFor(_, valueType, valuePath)) case t if t <:< localTypeOf[String] => StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d6863ed2fd..06589411cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ /** @@ -501,6 +501,162 @@ case class MapObjects private( } } +object ExternalMapToCatalyst { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + def apply( + inputMap: Expression, + keyType: DataType, + keyConverter: Expression => Expression, + valueType: DataType, + valueConverter: Expression => Expression): ExternalMapToCatalyst = { + val id = curId.getAndIncrement() + val keyName = "ExternalMapToCatalyst_key" + id + val valueName = "ExternalMapToCatalyst_value" + id + val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id + + ExternalMapToCatalyst( + keyName, + keyType, + keyConverter(LambdaVariable(keyName, "false", keyType)), + valueName, + valueIsNull, + valueType, + valueConverter(LambdaVariable(valueName, valueIsNull, valueType)), + inputMap + ) + } +} + +/** + * Converts a Scala/Java map object into catalyst format, by applying the key/value converter when + * iterate the map. + * + * @param key the name of the map key variable that used when iterate the map, and used as input for + * the `keyConverter` + * @param keyType the data type of the map key variable that used when iterate the map, and used as + * input for the `keyConverter` + * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. + * @param value the name of the map value variable that used when iterate the map, and used as input + * for the `valueConverter` + * @param valueIsNull the nullability of the map value variable that used when iterate the map, and + * used as input for the `valueConverter` + * @param valueType the data type of the map value variable that used when iterate the map, and + * used as input for the `valueConverter` + * @param valueConverter A function that take the `value` as input, and converts it to catalyst + * format. + * @param child An expression that when evaluated returns the input map object. + */ +case class ExternalMapToCatalyst private( + key: String, + keyType: DataType, + keyConverter: Expression, + value: String, + valueIsNull: String, + valueType: DataType, + valueConverter: Expression, + child: Expression) + extends UnaryExpression with NonSQLExpression { + + override def foldable: Boolean = false + + override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val inputMap = child.genCode(ctx) + val genKeyConverter = keyConverter.genCode(ctx) + val genValueConverter = valueConverter.genCode(ctx) + val length = ctx.freshName("length") + val index = ctx.freshName("index") + val convertedKeys = ctx.freshName("convertedKeys") + val convertedValues = ctx.freshName("convertedValues") + val entry = ctx.freshName("entry") + val entries = ctx.freshName("entries") + + val (defineEntries, defineKeyValue) = child.dataType match { + case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => + val javaIteratorCls = classOf[java.util.Iterator[_]].getName + val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName + + val defineEntries = + s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();" + + val defineKeyValue = + s""" + final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); + ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + """ + + defineEntries -> defineKeyValue + + case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) => + val scalaIteratorCls = classOf[Iterator[_]].getName + val scalaMapEntryCls = classOf[Tuple2[_, _]].getName + + val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();" + + val defineKeyValue = + s""" + final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); + ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); + ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2(); + """ + + defineEntries -> defineKeyValue + } + + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { + s"boolean $valueIsNull = false;" + } else { + s"boolean $valueIsNull = $value == null;" + } + + val arrayCls = classOf[GenericArrayData].getName + val mapCls = classOf[ArrayBasedMapData].getName + val convertedKeyType = ctx.boxedType(keyConverter.dataType) + val convertedValueType = ctx.boxedType(valueConverter.dataType) + val code = + s""" + ${inputMap.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${inputMap.isNull}) { + final int $length = ${inputMap.value}.size(); + final Object[] $convertedKeys = new Object[$length]; + final Object[] $convertedValues = new Object[$length]; + int $index = 0; + $defineEntries + while($entries.hasNext()) { + $defineKeyValue + $valueNullCheck + + ${genKeyConverter.code} + if (${genKeyConverter.isNull}) { + throw new RuntimeException("Cannot use null as map key!"); + } else { + $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value}; + } + + ${genValueConverter.code} + if (${genValueConverter.isNull}) { + $convertedValues[$index] = null; + } else { + $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value}; + } + + $index++; + } + + ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues)); + } + """ + ev.copy(code = code, isNull = inputMap.isNull) + } +} + /** * Constructs a new external row, using the result of evaluating the specified expressions * as content. -- cgit v1.2.3