path: root/sql
author    Wenchen Fan <wenchen@databricks.com>  2016-07-26 15:33:05 +0800
committer Cheng Lian <lian@databricks.com>      2016-07-26 15:33:05 +0800
commit    6959061f02b02afd4cef683b5eea0b7097eedee7 (patch)
tree      65f9e5ea5ba0866e59d55c9d799423e78bcd1f1c /sql
parent    7b06a8948fc16d3c14e240fdd632b79ce1651008 (diff)
[SPARK-16706][SQL] support java map in encoder
## What changes were proposed in this pull request?

Finish the TODO: create a new expression `ExternalMapToCatalyst` that iterates the map directly.

## How was this patch tested?

New test in `JavaDatasetSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #14344 from cloud-fan/java-map.
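As context for the diff below, a minimal usage sketch follows; the `MapBean` class, the local `SparkSession` settings, and the sample data are illustrative only and not part of the patch. After this change, `Encoders.bean` accepts bean properties typed as `java.util.Map`, which previously failed with `UnsupportedOperationException("map type is not supported currently")`.

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class MapBeanExample {

  // Illustrative bean with a java.util.Map property, analogous to the fields
  // added to SimpleJavaBean in JavaDatasetSuite by this patch.
  public static class MapBean implements Serializable {
    private Map<Integer, String> labels;

    public Map<Integer, String> getLabels() { return labels; }

    public void setLabels(Map<Integer, String> labels) { this.labels = labels; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .appName("java-map-encoder-example")
        .getOrCreate();

    MapBean bean = new MapBean();
    Map<Integer, String> labels = new HashMap<>();
    labels.put(1, "a");
    labels.put(2, "b");
    bean.setLabels(labels);

    // Before this commit, deriving an encoder for a bean with a Map property threw
    // UnsupportedOperationException; ExternalMapToCatalyst now converts the map entry by entry.
    Dataset<MapBean> ds =
        spark.createDataset(Arrays.asList(bean), Encoders.bean(MapBean.class));

    // Round-trips back to the original map: {1=a, 2=b}
    System.out.println(ds.collectAsList().get(0).getLabels());

    spark.stop();
  }
}
```

Note also that the generated code rejects null map keys at runtime ("Cannot use null as map key!"), which the new `ExpressionEncoderSuite` test exercises for Scala maps.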
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala              |  12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                |  34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala    | 158
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala |   6
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                         |  58
5 files changed, 236 insertions, 32 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index b3a233ae39..e6f61b00eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -395,10 +395,14 @@ object JavaTypeInference {
toCatalystArray(inputObject, elementType(typeToken))
case _ if mapType.isAssignableFrom(typeToken) =>
- // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
- // not guarantee they have same iteration order(which is different from scala map).
- // A possible solution is creating a new `MapObjects` that can iterate a map directly.
- throw new UnsupportedOperationException("map type is not supported currently")
+ val (keyType, valueType) = mapKeyValueType(typeToken)
+ ExternalMapToCatalyst(
+ inputObject,
+ ObjectType(keyType.getRawType),
+ serializerFor(_, keyType),
+ ObjectType(valueType.getRawType),
+ serializerFor(_, valueType)
+ )
case other =>
val properties = getJavaBeanProperties(other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8affb033d8..76f87f64ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -472,29 +472,17 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val keys =
- Invoke(
- Invoke(inputObject, "keysIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = toCatalystArray(keys, keyType)
-
- val values =
- Invoke(
- Invoke(inputObject, "valuesIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = toCatalystArray(values, valueType)
-
- val Schema(keyDataType, _) = schemaFor(keyType)
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- NewInstance(
- classOf[ArrayBasedMapData],
- convertedKeys :: convertedValues :: Nil,
- dataType = MapType(keyDataType, valueDataType, valueNullable))
+ val keyClsName = getClassNameFromType(keyType)
+ val valueClsName = getClassNameFromType(valueType)
+ val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath
+ val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath
+
+ ExternalMapToCatalyst(
+ inputObject,
+ dataTypeFor(keyType),
+ serializerFor(_, keyType, keyPath),
+ dataTypeFor(valueType),
+ serializerFor(_, valueType, valuePath))
case t if t <:< localTypeOf[String] =>
StaticInvoke(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d6863ed2fd..06589411cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
/**
@@ -501,6 +501,162 @@ case class MapObjects private(
}
}
+object ExternalMapToCatalyst {
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+ def apply(
+ inputMap: Expression,
+ keyType: DataType,
+ keyConverter: Expression => Expression,
+ valueType: DataType,
+ valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+ val id = curId.getAndIncrement()
+ val keyName = "ExternalMapToCatalyst_key" + id
+ val valueName = "ExternalMapToCatalyst_value" + id
+ val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id
+
+ ExternalMapToCatalyst(
+ keyName,
+ keyType,
+ keyConverter(LambdaVariable(keyName, "false", keyType)),
+ valueName,
+ valueIsNull,
+ valueType,
+ valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+ inputMap
+ )
+ }
+}
+
+/**
+ * Converts a Scala/Java map object into catalyst format, by applying the key/value converters
+ * while iterating over the map.
+ *
+ * @param key the name of the map key variable used when iterating over the map, and used as
+ *            input for the `keyConverter`
+ * @param keyType the data type of the map key variable used when iterating over the map, and
+ *                used as input for the `keyConverter`
+ * @param keyConverter A function that takes the `key` as input, and converts it to catalyst format.
+ * @param value the name of the map value variable used when iterating over the map, and used as
+ *              input for the `valueConverter`
+ * @param valueIsNull the nullability of the map value variable used when iterating over the map,
+ *                    and used as input for the `valueConverter`
+ * @param valueType the data type of the map value variable used when iterating over the map, and
+ *                  used as input for the `valueConverter`
+ * @param valueConverter A function that takes the `value` as input, and converts it to catalyst
+ *                       format.
+ * @param child An expression that when evaluated returns the input map object.
+ */
+case class ExternalMapToCatalyst private(
+ key: String,
+ keyType: DataType,
+ keyConverter: Expression,
+ value: String,
+ valueIsNull: String,
+ valueType: DataType,
+ valueConverter: Expression,
+ child: Expression)
+ extends UnaryExpression with NonSQLExpression {
+
+ override def foldable: Boolean = false
+
+ override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType)
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val inputMap = child.genCode(ctx)
+ val genKeyConverter = keyConverter.genCode(ctx)
+ val genValueConverter = valueConverter.genCode(ctx)
+ val length = ctx.freshName("length")
+ val index = ctx.freshName("index")
+ val convertedKeys = ctx.freshName("convertedKeys")
+ val convertedValues = ctx.freshName("convertedValues")
+ val entry = ctx.freshName("entry")
+ val entries = ctx.freshName("entries")
+
+ val (defineEntries, defineKeyValue) = child.dataType match {
+ case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
+ val javaIteratorCls = classOf[java.util.Iterator[_]].getName
+ val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName
+
+ val defineEntries =
+ s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();"
+
+ val defineKeyValue =
+ s"""
+ final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
+ ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey();
+ ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+ """
+
+ defineEntries -> defineKeyValue
+
+ case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
+ val scalaIteratorCls = classOf[Iterator[_]].getName
+ val scalaMapEntryCls = classOf[Tuple2[_, _]].getName
+
+ val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();"
+
+ val defineKeyValue =
+ s"""
+ final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
+ ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1();
+ ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2();
+ """
+
+ defineEntries -> defineKeyValue
+ }
+
+ val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
+ s"boolean $valueIsNull = false;"
+ } else {
+ s"boolean $valueIsNull = $value == null;"
+ }
+
+ val arrayCls = classOf[GenericArrayData].getName
+ val mapCls = classOf[ArrayBasedMapData].getName
+ val convertedKeyType = ctx.boxedType(keyConverter.dataType)
+ val convertedValueType = ctx.boxedType(valueConverter.dataType)
+ val code =
+ s"""
+ ${inputMap.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${inputMap.isNull}) {
+ final int $length = ${inputMap.value}.size();
+ final Object[] $convertedKeys = new Object[$length];
+ final Object[] $convertedValues = new Object[$length];
+ int $index = 0;
+ $defineEntries
+ while($entries.hasNext()) {
+ $defineKeyValue
+ $valueNullCheck
+
+ ${genKeyConverter.code}
+ if (${genKeyConverter.isNull}) {
+ throw new RuntimeException("Cannot use null as map key!");
+ } else {
+ $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value};
+ }
+
+ ${genValueConverter.code}
+ if (${genValueConverter.isNull}) {
+ $convertedValues[$index] = null;
+ } else {
+ $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value};
+ }
+
+ $index++;
+ }
+
+ ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues));
+ }
+ """
+ ev.copy(code = code, isNull = inputMap.isNull)
+ }
+}
+
/**
* Constructs a new external row, using the result of evaluating the specified expressions
* as content.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index a1f9259f13..4df9062018 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -328,6 +328,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
}
+ test("null check for map key") {
+ val encoder = ExpressionEncoder[Map[String, Int]]()
+ val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
+ assert(e.getMessage.contains("Cannot use null as map key"))
+ }
+
private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
testName: String): Unit = {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a711811f41..96e8fb0668 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -497,6 +497,8 @@ public class JavaDatasetSuite implements Serializable {
private String[] d;
private List<String> e;
private List<Long> f;
+ private Map<Integer, String> g;
+ private Map<List<Long>, Map<String, String>> h;
public boolean isA() {
return a;
@@ -546,6 +548,22 @@ public class JavaDatasetSuite implements Serializable {
this.f = f;
}
+ public Map<Integer, String> getG() {
+ return g;
+ }
+
+ public void setG(Map<Integer, String> g) {
+ this.g = g;
+ }
+
+ public Map<List<Long>, Map<String, String>> getH() {
+ return h;
+ }
+
+ public void setH(Map<List<Long>, Map<String, String>> h) {
+ this.h = h;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
@@ -558,7 +576,10 @@ public class JavaDatasetSuite implements Serializable {
if (!Arrays.equals(c, that.c)) return false;
if (!Arrays.equals(d, that.d)) return false;
if (!e.equals(that.e)) return false;
- return f.equals(that.f);
+ if (!f.equals(that.f)) return false;
+ if (!g.equals(that.g)) return false;
+ return h.equals(that.h);
+
}
@Override
@@ -569,6 +590,8 @@ public class JavaDatasetSuite implements Serializable {
result = 31 * result + Arrays.hashCode(d);
result = 31 * result + e.hashCode();
result = 31 * result + f.hashCode();
+ result = 31 * result + g.hashCode();
+ result = 31 * result + h.hashCode();
return result;
}
}
@@ -648,6 +671,17 @@ public class JavaDatasetSuite implements Serializable {
obj1.setD(new String[]{"hello", null});
obj1.setE(Arrays.asList("a", "b"));
obj1.setF(Arrays.asList(100L, null, 200L));
+ Map<Integer, String> map1 = new HashMap<Integer, String>();
+ map1.put(1, "a");
+ map1.put(2, "b");
+ obj1.setG(map1);
+ Map<String, String> nestedMap1 = new HashMap<String, String>();
+ nestedMap1.put("x", "1");
+ nestedMap1.put("y", "2");
+ Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
+ complexMap1.put(Arrays.asList(1L, 2L), nestedMap1);
+ obj1.setH(complexMap1);
+
SimpleJavaBean obj2 = new SimpleJavaBean();
obj2.setA(false);
obj2.setB(30);
@@ -655,6 +689,16 @@ public class JavaDatasetSuite implements Serializable {
obj2.setD(new String[]{null, "world"});
obj2.setE(Arrays.asList("x", "y"));
obj2.setF(Arrays.asList(300L, null, 400L));
+ Map<Integer, String> map2 = new HashMap<Integer, String>();
+ map2.put(3, "c");
+ map2.put(4, "d");
+ obj2.setG(map2);
+ Map<String, String> nestedMap2 = new HashMap<String, String>();
+ nestedMap2.put("q", "1");
+ nestedMap2.put("w", "2");
+ Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
+ complexMap2.put(Arrays.asList(3L, 4L), nestedMap2);
+ obj2.setH(complexMap2);
List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
@@ -673,21 +717,27 @@ public class JavaDatasetSuite implements Serializable {
new byte[]{1, 2},
new String[]{"hello", null},
Arrays.asList("a", "b"),
- Arrays.asList(100L, null, 200L)});
+ Arrays.asList(100L, null, 200L),
+ map1,
+ complexMap1});
Row row2 = new GenericRow(new Object[]{
false,
30,
new byte[]{3, 4},
new String[]{null, "world"},
Arrays.asList("x", "y"),
- Arrays.asList(300L, null, 400L)});
+ Arrays.asList(300L, null, 400L),
+ map2,
+ complexMap2});
StructType schema = new StructType()
.add("a", BooleanType, false)
.add("b", IntegerType, false)
.add("c", BinaryType)
.add("d", createArrayType(StringType))
.add("e", createArrayType(StringType))
- .add("f", createArrayType(LongType));
+ .add("f", createArrayType(LongType))
+ .add("g", createMapType(IntegerType, StringType))
+ .add("h",createMapType(createArrayType(LongType), createMapType(StringType, StringType)));
Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
.as(Encoders.bean(SimpleJavaBean.class));
Assert.assertEquals(data, ds3.collectAsList());