From 6959061f02b02afd4cef683b5eea0b7097eedee7 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 26 Jul 2016 15:33:05 +0800
Subject: [SPARK-16706][SQL] support java map in encoder

## What changes were proposed in this pull request?

Finish the TODO: create a new expression `ExternalMapToCatalyst` that iterates the map directly.

## How was this patch tested?

New test in `JavaDatasetSuite`.

Author: Wenchen Fan

Closes #14344 from cloud-fan/java-map.
---
 .../spark/sql/catalyst/JavaTypeInference.scala      |  12 +-
 .../spark/sql/catalyst/ScalaReflection.scala        |  34 ++---
 .../sql/catalyst/expressions/objects/objects.scala  | 158 +++++++++++++++++++-
 .../catalyst/encoders/ExpressionEncoderSuite.scala  |   6 +
 .../org/apache/spark/sql/JavaDatasetSuite.java      |  58 ++++++-
 5 files changed, 236 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index b3a233ae39..e6f61b00eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -395,10 +395,14 @@ object JavaTypeInference {
         toCatalystArray(inputObject, elementType(typeToken))
 
       case _ if mapType.isAssignableFrom(typeToken) =>
-        // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
-        // not guarantee they have same iteration order(which is different from scala map).
-        // A possible solution is creating a new `MapObjects` that can iterate a map directly.
-        throw new UnsupportedOperationException("map type is not supported currently")
+        val (keyType, valueType) = mapKeyValueType(typeToken)
+        ExternalMapToCatalyst(
+          inputObject,
+          ObjectType(keyType.getRawType),
+          serializerFor(_, keyType),
+          ObjectType(valueType.getRawType),
+          serializerFor(_, valueType)
+        )
 
       case other =>
         val properties = getJavaBeanProperties(other)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8affb033d8..76f87f64ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -472,29 +472,17 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< localTypeOf[Map[_, _]] =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
-
-        val keys =
-          Invoke(
-            Invoke(inputObject, "keysIterator",
-              ObjectType(classOf[scala.collection.Iterator[_]])),
-            "toSeq",
-            ObjectType(classOf[scala.collection.Seq[_]]))
-        val convertedKeys = toCatalystArray(keys, keyType)
-
-        val values =
-          Invoke(
-            Invoke(inputObject, "valuesIterator",
-              ObjectType(classOf[scala.collection.Iterator[_]])),
-            "toSeq",
-            ObjectType(classOf[scala.collection.Seq[_]]))
-        val convertedValues = toCatalystArray(values, valueType)
-
-        val Schema(keyDataType, _) = schemaFor(keyType)
-        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-        NewInstance(
-          classOf[ArrayBasedMapData],
-          convertedKeys :: convertedValues :: Nil,
-          dataType = MapType(keyDataType, valueDataType, valueNullable))
+        val keyClsName = getClassNameFromType(keyType)
+        val valueClsName = getClassNameFromType(valueType)
+        val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath
+        val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath
+
+        ExternalMapToCatalyst(
+          inputObject,
+          dataTypeFor(keyType),
+          serializerFor(_, keyType, keyPath),
+          dataTypeFor(valueType),
+          serializerFor(_, valueType, valuePath))
 
       case t if t <:< localTypeOf[String] =>
         StaticInvoke(

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d6863ed2fd..06589411cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -501,6 +501,162 @@ case class MapObjects private(
   }
 }
 
+object ExternalMapToCatalyst {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  def apply(
+      inputMap: Expression,
+      keyType: DataType,
+      keyConverter: Expression => Expression,
+      valueType: DataType,
+      valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+    val id = curId.getAndIncrement()
+    val keyName = "ExternalMapToCatalyst_key" + id
+    val valueName = "ExternalMapToCatalyst_value" + id
+    val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id
+
+    ExternalMapToCatalyst(
+      keyName,
+      keyType,
+      keyConverter(LambdaVariable(keyName, "false", keyType)),
+      valueName,
+      valueIsNull,
+      valueType,
+      valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+      inputMap
+    )
+  }
+}
+
+/**
+ * Converts a Scala/Java map object into catalyst format, by applying the key/value converter
+ * while iterating over the map.
+ *
+ * @param key the name of the map key variable used while iterating over the map, also used as
+ *            input for the `keyConverter`
+ * @param keyType the data type of the map key variable used while iterating over the map, also
+ *                used as input for the `keyConverter`
+ * @param keyConverter A function that takes the `key` as input and converts it to catalyst format.
+ * @param value the name of the map value variable used while iterating over the map, also used as
+ *              input for the `valueConverter`
+ * @param valueIsNull the nullability of the map value variable used while iterating over the map,
+ *                    also used as input for the `valueConverter`
+ * @param valueType the data type of the map value variable used while iterating over the map,
+ *                  also used as input for the `valueConverter`
+ * @param valueConverter A function that takes the `value` as input and converts it to catalyst
+ *                       format.
+ * @param child An expression that, when evaluated, returns the input map object.
+ */
+case class ExternalMapToCatalyst private(
+    key: String,
+    keyType: DataType,
+    keyConverter: Expression,
+    value: String,
+    valueIsNull: String,
+    valueType: DataType,
+    valueConverter: Expression,
+    child: Expression)
+  extends UnaryExpression with NonSQLExpression {
+
+  override def foldable: Boolean = false
+
+  override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val inputMap = child.genCode(ctx)
+    val genKeyConverter = keyConverter.genCode(ctx)
+    val genValueConverter = valueConverter.genCode(ctx)
+    val length = ctx.freshName("length")
+    val index = ctx.freshName("index")
+    val convertedKeys = ctx.freshName("convertedKeys")
+    val convertedValues = ctx.freshName("convertedValues")
+    val entry = ctx.freshName("entry")
+    val entries = ctx.freshName("entries")
+
+    val (defineEntries, defineKeyValue) = child.dataType match {
+      case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
+        val javaIteratorCls = classOf[java.util.Iterator[_]].getName
+        val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName
+
+        val defineEntries =
+          s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();"
+
+        val defineKeyValue =
+          s"""
+            final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
+            ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey();
+            ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+          """
+
+        defineEntries -> defineKeyValue
+
+      case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
+        val scalaIteratorCls = classOf[Iterator[_]].getName
+        val scalaMapEntryCls = classOf[Tuple2[_, _]].getName
+
+        val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();"
+
+        val defineKeyValue =
+          s"""
+            final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
+            ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1();
+            ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2();
+          """
+
+        defineEntries -> defineKeyValue
+    }
+
+    val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
+      s"boolean $valueIsNull = false;"
+    } else {
+      s"boolean $valueIsNull = $value == null;"
+    }
+
+    val arrayCls = classOf[GenericArrayData].getName
+    val mapCls = classOf[ArrayBasedMapData].getName
+    val convertedKeyType = ctx.boxedType(keyConverter.dataType)
+    val convertedValueType = ctx.boxedType(valueConverter.dataType)
+    val code =
+      s"""
+        ${inputMap.code}
+        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+        if (!${inputMap.isNull}) {
+          final int $length = ${inputMap.value}.size();
+          final Object[] $convertedKeys = new Object[$length];
+          final Object[] $convertedValues = new Object[$length];
+          int $index = 0;
+          $defineEntries
+          while($entries.hasNext()) {
+            $defineKeyValue
+            $valueNullCheck
+
+            ${genKeyConverter.code}
+            if (${genKeyConverter.isNull}) {
+              throw new RuntimeException("Cannot use null as map key!");
+            } else {
+              $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value};
+            }
+
+            ${genValueConverter.code}
+            if (${genValueConverter.isNull}) {
+              $convertedValues[$index] = null;
+            } else {
+              $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value};
+            }
+
+            $index++;
+          }
+
+          ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues));
+        }
+      """
+    ev.copy(code = code, isNull = inputMap.isNull)
+  }
+}
+
 /**
  * Constructs a new external row, using the result of evaluating the specified expressions
  * as content.

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index a1f9259f13..4df9062018 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -328,6 +328,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
     }
   }
 
+  test("null check for map key") {
+    val encoder = ExpressionEncoder[Map[String, Int]]()
+    val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
+    assert(e.getMessage.contains("Cannot use null as map key"))
+  }
+
   private def encodeDecodeTest[T : ExpressionEncoder](
       input: T,
       testName: String): Unit = {

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a711811f41..96e8fb0668 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -497,6 +497,8 @@ public class JavaDatasetSuite implements Serializable {
     private String[] d;
     private List<String> e;
     private List<Long> f;
+    private Map<Integer, String> g;
+    private Map<List<Long>, Map<String, String>> h;
 
     public boolean isA() {
       return a;
@@ -546,6 +548,22 @@
       this.f = f;
     }
 
+    public Map<Integer, String> getG() {
+      return g;
+    }
+
+    public void setG(Map<Integer, String> g) {
+      this.g = g;
+    }
+
+    public Map<List<Long>, Map<String, String>> getH() {
+      return h;
+    }
+
+    public void setH(Map<List<Long>, Map<String, String>> h) {
+      this.h = h;
+    }
+
     @Override
     public boolean equals(Object o) {
       if (this == o) return true;
@@ -558,7 +576,10 @@
       if (!Arrays.equals(c, that.c)) return false;
       if (!Arrays.equals(d, that.d)) return false;
       if (!e.equals(that.e)) return false;
-      return f.equals(that.f);
+      if (!f.equals(that.f)) return false;
+      if (!g.equals(that.g)) return false;
+      return h.equals(that.h);
+
     }
 
     @Override
@@ -569,6 +590,8 @@
       result = 31 * result + Arrays.hashCode(d);
       result = 31 * result + e.hashCode();
       result = 31 * result + f.hashCode();
+      result = 31 * result + g.hashCode();
+      result = 31 * result + h.hashCode();
       return result;
     }
   }
@@ -648,6 +671,17 @@
     obj1.setD(new String[]{"hello", null});
     obj1.setE(Arrays.asList("a", "b"));
     obj1.setF(Arrays.asList(100L, null, 200L));
+    Map<Integer, String> map1 = new HashMap<Integer, String>();
+    map1.put(1, "a");
+    map1.put(2, "b");
+    obj1.setG(map1);
+    Map<String, String> nestedMap1 = new HashMap<String, String>();
+    nestedMap1.put("x", "1");
+    nestedMap1.put("y", "2");
+    Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
+    complexMap1.put(Arrays.asList(1L, 2L), nestedMap1);
+    obj1.setH(complexMap1);
+
     SimpleJavaBean obj2 = new SimpleJavaBean();
     obj2.setA(false);
     obj2.setB(30);
     obj2.setC(new byte[]{3, 4});
     obj2.setD(new String[]{null, "world"});
     obj2.setE(Arrays.asList("x", "y"));
     obj2.setF(Arrays.asList(300L, null, 400L));
+    Map<Integer, String> map2 = new HashMap<Integer, String>();
+    map2.put(3, "c");
+    map2.put(4, "d");
+    obj2.setG(map2);
+    Map<String, String> nestedMap2 = new HashMap<String, String>();
+    nestedMap2.put("q", "1");
+    nestedMap2.put("w", "2");
+    Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
+    complexMap2.put(Arrays.asList(3L, 4L), nestedMap2);
+    obj2.setH(complexMap2);
 
     List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
     Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
@@ -673,21 +717,27 @@
       new byte[]{1, 2},
       new String[]{"hello", null},
       Arrays.asList("a", "b"),
-      Arrays.asList(100L, null, 200L)});
+      Arrays.asList(100L, null, 200L),
+      map1,
+      complexMap1});
     Row row2 = new GenericRow(new Object[]{
       false,
       30,
       new byte[]{3, 4},
       new String[]{null, "world"},
       Arrays.asList("x", "y"),
-      Arrays.asList(300L, null, 400L)});
+      Arrays.asList(300L, null, 400L),
+      map2,
+      complexMap2});
     StructType schema = new StructType()
       .add("a", BooleanType, false)
      .add("b", IntegerType, false)
      .add("c", BinaryType)
      .add("d", createArrayType(StringType))
      .add("e", createArrayType(StringType))
-      .add("f", createArrayType(LongType));
+      .add("f", createArrayType(LongType))
+      .add("g", createMapType(IntegerType, StringType))
+      .add("h", createMapType(createArrayType(LongType), createMapType(StringType, StringType)));
     Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
       .as(Encoders.bean(SimpleJavaBean.class));
     Assert.assertEquals(data, ds3.collectAsList());
-- 
cgit v1.2.3
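
A minimal usage sketch, not part of the patch: the bean, session setup, and map contents below are illustrative only. It shows what this change enables for `Encoders.bean` — before SPARK-16706 the serializer path for a `java.util.Map` bean property threw `UnsupportedOperationException("map type is not supported currently")`; with `ExternalMapToCatalyst` the field is encoded as a Catalyst `MapType`.

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class MapBeanExample {
  // Hypothetical bean with a map-typed property, analogous to SimpleJavaBean.g above.
  public static class ScoreBean implements Serializable {
    private Map<String, Integer> scores;
    public Map<String, Integer> getScores() { return scores; }
    public void setScores(Map<String, Integer> scores) { this.scores = scores; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local")
      .appName("map-encoder-example")
      .getOrCreate();

    ScoreBean bean = new ScoreBean();
    Map<String, Integer> scores = new HashMap<>();
    scores.put("math", 90);
    bean.setScores(scores);

    // The bean encoder now serializes the map field via ExternalMapToCatalyst.
    Dataset<ScoreBean> ds = spark.createDataset(Arrays.asList(bean), Encoders.bean(ScoreBean.class));
    ds.show();

    spark.stop();
  }
}
```

The same expression also handles `scala.collection.Map` fields through the Scala-map branch of `ExternalMapToCatalyst.doGenCode`.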