author    Wenchen Fan <wenchen@databricks.com>        2015-10-21 11:06:34 -0700
committer Michael Armbrust <michael@databricks.com>   2015-10-21 11:06:34 -0700
commit    ccf536f903ef1f81fb3e1b6ce781d5e40d0ae3e0 (patch)
tree      eb1a3bedb22b84209fe9e7a2888f58eb53d717a0 /sql
parent    f62e3260889d67256d335fd0dd38f114ae4e3eca (diff)
[SPARK-11216] [SQL] add encoder/decoder for external row
Implement encode/decode for external row based on `ClassEncoder`.

TODO:
* code cleanup
* ~~fix corner cases~~
* refactor the encoder interface
* improve test for product codegen, to cover more corner cases

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9184 from cloud-fan/encoder.
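In short, the new RowEncoder converts an external Row into Spark's internal representation and back. A minimal usage sketch, mirroring the round-trip in RowEncoderSuite below (the schema and values here are illustrative, not from the patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    // Illustrative schema; any of the types handled by RowEncoder would do.
    val schema = new StructType().add("id", IntegerType).add("name", StringType)
    val encoder = RowEncoder(schema)

    val input = Row(1, "hello")
    val internalRow = encoder.toRow(input)          // external Row -> InternalRow
    val roundTripped = encoder.fromRow(internalRow) // InternalRow -> external Row
    assert(input == roundTripped)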
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala          |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala    |  75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala         |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala  |  46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala      | 234
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala      |  46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala           |   4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala               |   4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala |  96
9 files changed, 459 insertions, 54 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8edd6498e5..27c96f4122 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
/**
* A default version of ScalaReflection that uses the runtime universe.
@@ -142,7 +142,7 @@ trait ScalaReflection {
}
/**
- * Returns an expression that can be used to construct an object of type `T` given a an input
+ * Returns an expression that can be used to construct an object of type `T` given an input
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
new file mode 100644
index 0000000000..f3a1063871
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types.{ObjectType, StructType}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ClassEncoder[T](
+ schema: StructType,
+ extractExpressions: Seq[Expression],
+ constructExpression: Expression,
+ clsTag: ClassTag[T])
+ extends Encoder[T] {
+
+ private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private val inputRow = new GenericMutableRow(1)
+
+ private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+ private val dataType = ObjectType(clsTag.runtimeClass)
+
+ override def toRow(t: T): InternalRow = {
+ if (t == null) {
+ null
+ } else {
+ inputRow(0) = t
+ extractProjection(inputRow)
+ }
+ }
+
+ override def fromRow(row: InternalRow): T = {
+ if (row eq null) {
+ null.asInstanceOf[T]
+ } else {
+ constructProjection(row).get(0, dataType).asInstanceOf[T]
+ }
+ }
+
+ override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
+ val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ val resolvedExpression = analyzedPlan.expressions.head.children.head
+ val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
+
+ copy(constructExpression = boundExpression)
+ }
+}
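For encoders whose constructExpression still contains unresolved attributes (such as those produced by ProductEncoder, below), fromRow requires a prior bind to the input schema. A hedged sketch of that flow, assuming ProductEncoder.apply[T] and StructType.toAttributes behave as in this patch; the case class is hypothetical:

    case class Person(name: String, age: Int)

    val encoder = ProductEncoder[Person]                  // a ClassEncoder[Person]
    val internal = encoder.toRow(Person("Ann", 30))       // runs extractExpressions
    val bound = encoder.bind(encoder.schema.toAttributes) // resolve + bind references
    val back: Person = bound.fromRow(internal)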
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 3618247d5d..bdb1c0959d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -46,7 +46,7 @@ trait Encoder[T] {
/**
* Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must bind` and encoder to a specific schema before you can call this function.
+ * you must bind the encoder to a specific schema before you can call this function.
*/
def fromRow(row: InternalRow): T
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index b0381880c3..4f7ce455ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -17,15 +17,11 @@
package org.apache.spark.sql.catalyst.encoders
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
-import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{ObjectType, StructType}
/**
@@ -44,44 +40,6 @@ object ProductEncoder {
val constructExpression = ScalaReflection.constructorFor[T]
new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
}
-}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
- schema: StructType,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
- clsTag: ClassTag[T])
- extends Encoder[T] {
- private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
- private val inputRow = new GenericMutableRow(1)
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
- private val dataType = ObjectType(clsTag.runtimeClass)
-
- override def toRow(t: T): InternalRow = {
- inputRow(0) = t
- extractProjection(inputRow)
- }
-
- override def fromRow(row: InternalRow): T = {
- constructProjection(row).get(0, dataType).asInstanceOf[T]
- }
-
- override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
- val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
- val analyzedPlan = SimpleAnalyzer.execute(plan)
- val resolvedExpression = analyzedPlan.expressions.head.children.head
- val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
- copy(constructExpression = boundExpression)
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
new file mode 100644
index 0000000000..3e74aabd07
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.collection.Map
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+object RowEncoder {
+
+ def apply(schema: StructType): ClassEncoder[Row] = {
+ val cls = classOf[Row]
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val extractExpressions = extractorsFor(inputObject, schema)
+ val constructExpression = constructorFor(schema)
+ new ClassEncoder[Row](
+ schema,
+ extractExpressions.asInstanceOf[CreateStruct].children,
+ constructExpression,
+ ClassTag(cls))
+ }
+
+ private def extractorsFor(
+ inputObject: Expression,
+ inputType: DataType): Expression = inputType match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | BinaryType => inputObject
+
+ case TimestampType =>
+ StaticInvoke(
+ DateTimeUtils,
+ TimestampType,
+ "fromJavaTimestamp",
+ inputObject :: Nil)
+
+ case DateType =>
+ StaticInvoke(
+ DateTimeUtils,
+ DateType,
+ "fromJavaDate",
+ inputObject :: Nil)
+
+ case _: DecimalType =>
+ StaticInvoke(
+ Decimal,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil)
+
+ case StringType =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil)
+
+ case t @ ArrayType(et, _) => et match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+ NewInstance(
+ classOf[GenericArrayData],
+ inputObject :: Nil,
+ dataType = t)
+ case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et))
+ }
+
+ case t @ MapType(kt, vt, valueNullable) =>
+ val keys =
+ Invoke(
+ Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedKeys = extractorsFor(keys, ArrayType(kt, false))
+
+ val values =
+ Invoke(
+ Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))
+
+ NewInstance(
+ classOf[ArrayBasedMapData],
+ convertedKeys :: convertedValues :: Nil,
+ dataType = t)
+
+ case StructType(fields) =>
+ val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ If(
+ Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
+ Literal.create(null, f.dataType),
+ extractorsFor(
+ Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+ f.dataType))
+ }
+ CreateStruct(convertedFields)
+ }
+
+ private def externalDataTypeFor(dt: DataType): DataType = dt match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | BinaryType => dt
+ case TimestampType => ObjectType(classOf[java.sql.Timestamp])
+ case DateType => ObjectType(classOf[java.sql.Date])
+ case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
+ case StringType => ObjectType(classOf[java.lang.String])
+ case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
+ case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
+ case _: StructType => ObjectType(classOf[Row])
+ }
+
+ private def constructorFor(schema: StructType): Expression = {
+ val fields = schema.zipWithIndex.map { case (f, i) =>
+ val field = BoundReference(i, f.dataType, f.nullable)
+ If(
+ IsNull(field),
+ Literal.create(null, externalDataTypeFor(f.dataType)),
+ constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
+ )
+ }
+ CreateRow(fields)
+ }
+
+ private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | BinaryType => input
+
+ case TimestampType =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(classOf[java.sql.Timestamp]),
+ "toJavaTimestamp",
+ input :: Nil)
+
+ case DateType =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(classOf[java.sql.Date]),
+ "toJavaDate",
+ input :: Nil)
+
+ case _: DecimalType =>
+ Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+ case StringType =>
+ Invoke(input, "toString", ObjectType(classOf[String]))
+
+ case ArrayType(et, nullable) =>
+ val arrayData =
+ Invoke(
+ MapObjects(constructorFor(_, et), input, et),
+ "array",
+ ObjectType(classOf[Array[_]]))
+ StaticInvoke(
+ scala.collection.mutable.WrappedArray,
+ ObjectType(classOf[Seq[_]]),
+ "make",
+ arrayData :: Nil)
+
+ case MapType(kt, vt, valueNullable) =>
+ val keyArrayType = ArrayType(kt, false)
+ val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)
+
+ val valueArrayType = ArrayType(vt, valueNullable)
+ val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)
+
+ StaticInvoke(
+ ArrayBasedMapData,
+ ObjectType(classOf[Map[_, _]]),
+ "toScalaMap",
+ keyData :: valueData :: Nil)
+
+ case StructType(fields) =>
+ val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ If(
+ Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
+ Literal.create(null, externalDataTypeFor(f.dataType)),
+ constructorFor(getField(input, i, f.dataType), f.dataType))
+ }
+ CreateRow(convertedFields)
+ }
+
+ private def getField(
+ row: Expression,
+ ordinal: Int,
+ dataType: DataType): Expression = dataType match {
+ case BooleanType =>
+ Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
+ case ByteType =>
+ Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
+ case ShortType =>
+ Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
+ case IntegerType | DateType =>
+ Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
+ case LongType | TimestampType =>
+ Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
+ case FloatType =>
+ Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
+ case DoubleType =>
+ Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
+ case t: DecimalType =>
+ Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
+ case StringType =>
+ Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
+ case BinaryType =>
+ Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
+ case CalendarIntervalType =>
+ Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
+ case t: StructType =>
+ Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
+ case _: ArrayType =>
+ Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
+ case _: MapType =>
+ Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
+ }
+}
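As the conversions above show, external values use JVM types (java.sql.Timestamp, java.math.BigDecimal, Seq, Map, Row) while the internal row stores Catalyst representations. A small illustrative round-trip for a timestamp column (the field name is hypothetical):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType().add("ts", TimestampType)
    val encoder = RowEncoder(schema)

    // toRow goes through DateTimeUtils.fromJavaTimestamp (stored as a long);
    // fromRow comes back through DateTimeUtils.toJavaTimestamp.
    val input = Row(java.sql.Timestamp.valueOf("2015-10-21 11:06:34"))
    assert(encoder.fromRow(encoder.toRow(input)) == input)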
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e8c1c93cf5..8fc00ad1bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import scala.language.existentials
-import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
@@ -364,6 +365,10 @@ case class MapObjects(
(".numElements()", (i: String) => s".getShort($i)", true)
case ArrayType(BooleanType, _) =>
(".numElements()", (i: String) => s".getBoolean($i)", true)
+ case ArrayType(StringType, _) =>
+ (".numElements()", (i: String) => s".getUTF8String($i)", false)
+ case ArrayType(_: MapType, _) =>
+ (".numElements()", (i: String) => s".getMap($i)", false)
}
override def nullable: Boolean = true
@@ -398,7 +403,7 @@ case class MapObjects(
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
- val convertedType = ctx.javaType(boundFunction.dataType)
+ val convertedType = ctx.boxedType(boundFunction.dataType)
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -434,9 +439,13 @@ case class MapObjects(
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck
- ${genFunction.code}
+ if ($loopIsNull) {
+ $convertedArray[$loopIndex] = null;
+ } else {
+ ${genFunction.code}
+ $convertedArray[$loopIndex] = ${genFunction.value};
+ }
- $convertedArray[$loopIndex] = ($convertedType)${genFunction.value};
$loopIndex += 1;
}
@@ -446,3 +455,32 @@ case class MapObjects(
"""
}
}
+
+case class CreateRow(children: Seq[Expression]) extends Expression {
+ override def dataType: DataType = ObjectType(classOf[Row])
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val rowClass = classOf[GenericRow].getName
+ val values = ctx.freshName("values")
+ s"""
+ boolean ${ev.isNull} = false;
+ final Object[] $values = new Object[${children.size}];
+ """ +
+ children.zipWithIndex.map { case (e, i) =>
+ val eval = e.gen(ctx)
+ eval.code + s"""
+ if (${eval.isNull}) {
+ $values[$i] = null;
+ } else {
+ $values[$i] = ${eval.value};
+ }
+ """
+ }.mkString("\n") +
+ s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+ }
+}
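CreateRow deliberately supports only code-generated evaluation; the Java it emits is equivalent in spirit to this interpreted Scala sketch (illustrative only, not part of the patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRow}

    def createRowInterpreted(children: Seq[Expression], input: InternalRow): Row = {
      val values = new Array[Any](children.size)
      var i = 0
      while (i < children.size) {
        values(i) = children(i).eval(input) // a null child yields a null slot
        i += 1
      }
      new GenericRow(values)
    }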
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
index 5f22e59d5f..e5ffe32217 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -66,4 +66,8 @@ object ArrayBasedMapData {
def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
keys.zip(values).toMap
}
+
+ def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+ keys.zip(values).toMap
+ }
}
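The new Seq-based overload behaves exactly like the Array-based one; for example:

    ArrayBasedMapData.toScalaMap(Seq("a", "b"), Seq(1, 2)) // Map(a -> 1, b -> 2)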
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index e48395028e..7614f055e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -148,7 +148,7 @@ object RandomDataGenerator {
() => BigDecimal.apply(
rand.nextLong() % math.pow(10, precision).toLong,
scale,
- new MathContext(precision)))
+ new MathContext(precision)).bigDecimal)
case DoubleType => randomNumeric[Double](
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
@@ -166,7 +166,7 @@ object RandomDataGenerator {
case NullType => Some(() => null)
case ArrayType(elementType, containsNull) => {
forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
- elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
+ elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
}
}
case MapType(keyType, valueType, valueContainsNull) => {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
new file mode 100644
index 0000000000..6041b62b74
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class RowEncoderSuite extends SparkFunSuite {
+
+ private val structOfString = new StructType().add("str", StringType)
+ private val arrayOfString = ArrayType(StringType)
+ private val mapOfString = MapType(StringType, StringType)
+
+ encodeDecodeTest(
+ new StructType()
+ .add("boolean", BooleanType)
+ .add("byte", ByteType)
+ .add("short", ShortType)
+ .add("int", IntegerType)
+ .add("long", LongType)
+ .add("float", FloatType)
+ .add("double", DoubleType)
+ .add("decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("string", StringType)
+ .add("binary", BinaryType)
+ .add("date", DateType)
+ .add("timestamp", TimestampType))
+
+ encodeDecodeTest(
+ new StructType()
+ .add("arrayOfString", arrayOfString)
+ .add("arrayOfArrayOfString", ArrayType(arrayOfString))
+ .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
+ .add("arrayOfMap", ArrayType(mapOfString))
+ .add("arrayOfStruct", ArrayType(structOfString)))
+
+ encodeDecodeTest(
+ new StructType()
+ .add("mapOfIntAndString", MapType(IntegerType, StringType))
+ .add("mapOfStringAndArray", MapType(StringType, arrayOfString))
+ .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
+ .add("mapOfArray", MapType(arrayOfString, arrayOfString))
+ .add("mapOfStringAndStruct", MapType(StringType, structOfString))
+ .add("mapOfStructAndString", MapType(structOfString, StringType))
+ .add("mapOfStruct", MapType(structOfString, structOfString)))
+
+ encodeDecodeTest(
+ new StructType()
+ .add("structOfString", structOfString)
+ .add("structOfStructOfString", new StructType().add("struct", structOfString))
+ .add("structOfArray", new StructType().add("array", arrayOfString))
+ .add("structOfMap", new StructType().add("map", mapOfString))
+ .add("structOfArrayAndMap",
+ new StructType().add("array", arrayOfString).add("map", mapOfString)))
+
+ private def encodeDecodeTest(schema: StructType): Unit = {
+ test(s"encode/decode: ${schema.simpleString}") {
+ val encoder = RowEncoder(schema)
+ val inputGenerator = RandomDataGenerator.forType(schema).get
+
+ var input: Row = null
+ try {
+ for (_ <- 1 to 5) {
+ input = inputGenerator.apply().asInstanceOf[Row]
+ val row = encoder.toRow(input)
+ val convertedBack = encoder.fromRow(row)
+ assert(input == convertedBack)
+ }
+ } catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |schema: ${schema.simpleString}
+ |input: ${input}
+ """.stripMargin, e)
+ }
+ }
+ }
+}