aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala43
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala101
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala38
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala217
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala47
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala)29
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala100
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala173
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala28
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala)39
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala190
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala89
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala44
18 files changed, 563 insertions, 615 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c25161ee81..9cbb7c2ffd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,6 +146,10 @@ trait ScalaReflection {
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
+ *
+ * When used on a primitive type, the constructor will instead default to extracting the value
+ * from ordinal 0 (since there are no names to map to). The actual location can be moved by
+ * calling unbind/bind with a new schema.
*/
def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
@@ -159,8 +163,14 @@ trait ScalaReflection {
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
+ /** Returns the current path with a field at ordinal extracted. */
+ def addToPathOrdinal(ordinal: Int, dataType: DataType) =
+ path
+ .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal))
+ .getOrElse(BoundReference(ordinal, dataType, false))
+
/** Returns the current path or throws an error. */
- def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+ def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true))
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
@@ -387,12 +397,17 @@ trait ScalaReflection {
val className: String = t.erasure.typeSymbol.asClass.fullName
val cls = Utils.classForName(className)
- val arguments = params.head.map { p =>
+ val arguments = params.head.zipWithIndex.map { case (p, i) =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- val dataType = dataTypeFor(fieldType)
+ val dataType = schemaFor(fieldType).dataType
- constructorFor(fieldType, Some(addToPath(fieldName)))
+ // For tuples, we based grab the inner fields by ordinal instead of name.
+ if (className startsWith "scala.Tuple") {
+ constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+ } else {
+ constructorFor(fieldType, Some(addToPath(fieldName)))
+ }
}
val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
@@ -413,7 +428,10 @@ trait ScalaReflection {
/** Returns expressions for extracting all the fields from the given type. */
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
ScalaReflectionLock.synchronized {
- extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
+ extractorFor(inputObject, typeTag[T].tpe) match {
+ case s: CreateNamedStruct => s
+ case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil)
+ }
}
}
@@ -602,6 +620,21 @@ trait ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
+ case t if t <:< definitions.IntTpe =>
+ BoundReference(0, IntegerType, false)
+ case t if t <:< definitions.LongTpe =>
+ BoundReference(0, LongType, false)
+ case t if t <:< definitions.DoubleTpe =>
+ BoundReference(0, DoubleType, false)
+ case t if t <:< definitions.FloatTpe =>
+ BoundReference(0, FloatType, false)
+ case t if t <:< definitions.ShortTpe =>
+ BoundReference(0, ShortType, false)
+ case t if t <:< definitions.ByteTpe =>
+ BoundReference(0, ByteType, false)
+ case t if t <:< definitions.BooleanTpe =>
+ BoundReference(0, BooleanType, false)
+
case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
deleted file mode 100644
index b484b8fde6..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
- schema: StructType,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
- clsTag: ClassTag[T])
- extends Encoder[T] {
-
- @transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
- private val inputRow = new GenericMutableRow(1)
-
- @transient
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
- private val dataType = ObjectType(clsTag.runtimeClass)
-
- override def toRow(t: T): InternalRow = {
- inputRow(0) = t
- extractProjection(inputRow)
- }
-
- override def fromRow(row: InternalRow): T = {
- constructProjection(row).get(0, dataType).asInstanceOf[T]
- }
-
- override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
- val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
- val analyzedPlan = SimpleAnalyzer.execute(plan)
- val resolvedExpression = analyzedPlan.expressions.head.children.head
- val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
- copy(constructExpression = boundExpression)
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(oldSchema)
- val attributeToNewPosition = AttributeMap.byIndex(newSchema)
- copy(constructExpression = constructExpression transform {
- case r: BoundReference =>
- r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
- })
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
- var remaining = schema
- copy(constructExpression = constructExpression transform {
- case u: UnresolvedAttribute =>
- val pos = remaining.head
- remaining = remaining.drop(1)
- pos
- })
- }
-
- protected val attrs = extractExpressions.map(_.collect {
- case a: Attribute => s"#${a.exprId}"
- case b: BoundReference => s"[${b.ordinal}]"
- }.headOption.getOrElse(""))
-
-
- protected val schemaString =
- schema
- .zip(attrs)
- .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
-
- override def toString: String = s"class[$schemaString]"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index efb872ddb8..329a132d3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -18,10 +18,9 @@
package org.apache.spark.sql.catalyst.encoders
+
import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
/**
@@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType
* Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
* and reuse internal buffers to improve performance.
*/
-trait Encoder[T] {
+trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
def schema: StructType
/** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
-
- /**
- * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
- * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
- * copy the result before making another call if required.
- */
- def toRow(t: T): InternalRow
-
- /**
- * Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must `bind` an encoder to a specific schema before you can call this function.
- */
- def fromRow(row: InternalRow): T
-
- /**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
- * given schema.
- */
- def bind(schema: Seq[Attribute]): Encoder[T]
-
- /**
- * Binds this encoder to the given schema positionally. In this binding, the first reference to
- * any input is mapped to `schema(0)`, and so on for each input that is encountered.
- */
- def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
-
- /**
- * Given an encoder that has already been bound to a given schema, returns a new encoder that
- * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
- * when you are trying to use an encoder on grouping keys that were orriginally part of a larger
- * row, but now you have projected out only the key expressions.
- */
- def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
new file mode 100644
index 0000000000..c287aebeee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert objects and primitves to and from the
+ * internal row format using catalyst expressions and code generation. By default, the
+ * expressions used to retrieve values from an input row when producing an object will be created as
+ * follows:
+ * - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions
+ * and [[UnresolvedExtractValue]] expressions.
+ * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions.
+ * - Primitives will have their values extracted from the first ordinal with a schema that defaults
+ * to the name `value`.
+ */
+object ExpressionEncoder {
+ def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = {
+ // We convert the not-serializable TypeTag into StructType and ClassTag.
+ val mirror = typeTag[T].mirror
+ val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
+ val constructExpression = ScalaReflection.constructorFor[T]
+
+ new ExpressionEncoder[T](
+ extractExpression.dataType,
+ flat,
+ extractExpression.flatten,
+ constructExpression,
+ ClassTag[T](cls))
+ }
+
+ /**
+ * Given a set of N encoders, constructs a new encoder that produce objects as items in an
+ * N-tuple. Note that these encoders should first be bound correctly to the combined input
+ * schema.
+ */
+ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ val schema =
+ StructType(
+ encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+ val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+ val extractExpressions = encoders.map {
+ case e if e.flat => e.extractExpressions.head
+ case other => CreateStruct(other.extractExpressions)
+ }
+ val constructExpression =
+ NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+
+ new ExpressionEncoder[Any](
+ schema,
+ false,
+ extractExpressions,
+ constructExpression,
+ ClassTag.apply(cls))
+ }
+
+ /** A helper for producing encoders of Tuple2 from other encoders. */
+ def tuple[T1, T2](
+ e1: ExpressionEncoder[T1],
+ e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
+ tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ExpressionEncoder[T](
+ schema: StructType,
+ flat: Boolean,
+ extractExpressions: Seq[Expression],
+ constructExpression: Expression,
+ clsTag: ClassTag[T])
+ extends Encoder[T] {
+
+ if (flat) require(extractExpressions.size == 1)
+
+ @transient
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private val inputRow = new GenericMutableRow(1)
+
+ @transient
+ private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+
+ /**
+ * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
+ * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
+ * copy the result before making another call if required.
+ */
+ def toRow(t: T): InternalRow = {
+ inputRow(0) = t
+ extractProjection(inputRow)
+ }
+
+ /**
+ * Returns an object of type `T`, extracting the required values from the provided row. Note that
+ * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+ * function.
+ */
+ def fromRow(row: InternalRow): T = try {
+ constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+ }
+
+ /**
+ * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+ * given schema.
+ */
+ def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ copy(constructExpression = analyzedPlan.expressions.head.children.head)
+ }
+
+ /**
+ * Returns a copy of this encoder where the expressions used to construct an object from an input
+ * row have been bound to the ordinals of the given schema. Note that you need to first call
+ * resolve before bind.
+ */
+ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
+ }
+
+ /**
+ * Replaces any bound references in the schema with the attributes at the corresponding ordinal
+ * in the provided schema. This can be used to "relocate" a given encoder to pull values from
+ * a different schema than it was initially bound to. It can also be used to assign attributes
+ * to ordinal based extraction (i.e. because the input data was a tuple).
+ */
+ def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(schema)
+ copy(constructExpression = constructExpression transform {
+ case b: BoundReference => positionToAttribute(b.ordinal)
+ })
+ }
+
+ /**
+ * Given an encoder that has already been bound to a given schema, returns a new encoder
+ * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
+ * when you are trying to use an encoder on grouping keys that were originally part of a larger
+ * row, but now you have projected out only the key expressions.
+ */
+ def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(oldSchema)
+ val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+ copy(constructExpression = constructExpression transform {
+ case r: BoundReference =>
+ r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+ })
+ }
+
+ /**
+ * Returns a copy of this encoder where the expressions used to create an object given an
+ * input row have been modified to pull the object out from a nested struct, instead of the
+ * top level fields.
+ */
+ def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
+ copy(constructExpression = constructExpression transform {
+ case u: Attribute if u != input =>
+ UnresolvedExtractValue(input, Literal(u.name))
+ case b: BoundReference if b != input =>
+ GetStructField(
+ input,
+ StructField(s"i[${b.ordinal}]", b.dataType),
+ b.ordinal)
+ })
+ }
+
+ protected val attrs = extractExpressions.flatMap(_.collect {
+ case _: UnresolvedAttribute => ""
+ case a: Attribute => s"#${a.exprId}"
+ case b: BoundReference => s"[${b.ordinal}]"
+ })
+
+ protected val schemaString =
+ schema
+ .zip(attrs)
+ .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+ override def toString: String = s"class[$schemaString]"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
deleted file mode 100644
index 34f5e6c030..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
- * internal binary representation.
- */
-object ProductEncoder {
- def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
- // We convert the not-serializable TypeTag into StructType and ClassTag.
- val mirror = typeTag[T].mirror
- val cls = mirror.runtimeClass(typeTag[T].tpe)
-
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
- val constructExpression = ScalaReflection.constructorFor[T]
-
- new ClassEncoder[T](
- extractExpression.dataType,
- extractExpression.flatten,
- constructExpression,
- ClassTag[T](cls))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e9cc00a2b6..0b42130a01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
* internal binary representation.
*/
object RowEncoder {
- def apply(schema: StructType): ClassEncoder[Row] = {
+ def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val extractExpressions = extractorsFor(inputObject, schema)
val constructExpression = constructorFor(schema)
- new ClassEncoder[Row](
+ new ExpressionEncoder[Row](
schema,
+ flat = false,
extractExpressions.asInstanceOf[CreateStruct].children,
constructExpression,
ClassTag(cls))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 52f8383fac..d4642a5006 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -15,29 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql.catalyst
-import org.apache.spark.SparkFunSuite
-
-class PrimitiveEncoderSuite extends SparkFunSuite {
- test("long encoder") {
- val enc = new LongEncoder()
- val row = enc.toRow(10)
- assert(row.getLong(0) == 10)
- assert(enc.fromRow(row) == 10)
- }
-
- test("int encoder") {
- val enc = new IntEncoder()
- val row = enc.toRow(10)
- assert(row.getInt(0) == 10)
- assert(enc.fromRow(row) == 10)
- }
-
- test("string encoder") {
- val enc = new StringEncoder()
- val row = enc.toRow("test")
- assert(row.getString(0) == "test")
- assert(enc.fromRow(row) == "test")
+package object encoders {
+ private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
+ case e: ExpressionEncoder[A] => e
+ case _ => sys.error(s"Only expression encoders are supported today")
}
}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
deleted file mode 100644
index a93f2d7c61..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.sql.types._
-
-/** An encoder for primitive Long types. */
-case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
- private val row = UnsafeRow.createFromByteArray(64, 1)
-
- override def clsTag: ClassTag[Long] = ClassTag.Long
- override def schema: StructType =
- StructType(StructField(fieldName, LongType) :: Nil)
-
- override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
-
- override def toRow(t: Long): InternalRow = {
- row.setLong(ordinal, t)
- row
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
- override def bind(schema: Seq[Attribute]): Encoder[Long] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
-}
-
-/** An encoder for primitive Integer types. */
-case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
- private val row = UnsafeRow.createFromByteArray(64, 1)
-
- override def clsTag: ClassTag[Int] = ClassTag.Int
- override def schema: StructType =
- StructType(StructField(fieldName, IntegerType) :: Nil)
-
- override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
-
- override def toRow(t: Int): InternalRow = {
- row.setInt(ordinal, t)
- row
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
- override def bind(schema: Seq[Attribute]): Encoder[Int] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
-}
-
-/** An encoder for String types. */
-case class StringEncoder(
- fieldName: String = "value",
- ordinal: Int = 0) extends Encoder[String] {
-
- val record = new SpecificMutableRow(StringType :: Nil)
-
- @transient
- lazy val projection =
- GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
-
- override def schema: StructType =
- StructType(
- StructField("value", StringType, nullable = false) :: Nil)
-
- override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
-
-
- override final def fromRow(row: InternalRow): String = {
- row.getString(ordinal)
- }
-
- override final def toRow(value: String): InternalRow = {
- val utf8String = UTF8String.fromString(value)
- record(0) = utf8String
- // TODO: this is a bit of a hack to produce UnsafeRows
- projection(record)
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
- override def bind(schema: Seq[Attribute]): Encoder[String] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
deleted file mode 100644
index a48eeda7d2..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructField, StructType}
-
-// Most of this file is codegen.
-// scalastyle:off
-
-/**
- * A set of composite encoders that take sub encoders and map each of their objects to a
- * Scala tuple. Note that currently the implementation is fairly limited and only supports going
- * from an internal row to a tuple.
- */
-object TupleEncoder {
-
- /** Code generator for composite tuple encoders. */
- def main(args: Array[String]): Unit = {
- (2 to 5).foreach { i =>
- val types = (1 to i).map(t => s"T$t").mkString(", ")
- val tupleType = s"($types)"
- val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
- val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
- val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
-
- println(
- s"""
- |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
- | val schema = StructType(Array($fields))
- |
- | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
- |
- | def fromRow(row: InternalRow): $tupleType = {
- | ($fromRow)
- | }
- |
- | override def toRow(t: $tupleType): InternalRow =
- | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
- |
- | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
- | this
- | }
- |
- | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
- | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
- |
- |
- | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
- | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
- |}
- """.stripMargin)
- }
- }
-}
-
-class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
-
- def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
-
- def fromRow(row: InternalRow): (T1, T2) = {
- (e1.fromRow(row), e2.fromRow(row))
- }
-
- override def toRow(t: (T1, T2)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
-
- def fromRow(row: InternalRow): (T1, T2, T3) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
-
- def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3, T4)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
-
- def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 21a55a5371..d2d3db0a44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
@@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode {
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
@@ -460,8 +460,8 @@ case class MapPartitions[T, U](
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
- val attrs = implicitly[Encoder[U]].schema.toAttributes
- new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+ val attrs = encoderFor[U].schema.toAttributes
+ new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
}
}
@@ -472,8 +472,8 @@ object AppendColumn {
*/
case class AppendColumn[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
@@ -488,11 +488,11 @@ object MapGroups {
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
- implicitly[Encoder[K]],
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
+ encoderFor[K],
+ encoderFor[T],
+ encoderFor[U],
groupingAttributes,
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[U].schema.toAttributes,
child)
}
}
@@ -504,9 +504,9 @@ object MapGroups {
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 008d0bea8a..a374da4da1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -47,7 +47,16 @@ case class RepeatedData(
case class SpecificCollection(l: List[Int])
-class ProductEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends SparkFunSuite {
+
+ encodeDecodeTest(1)
+ encodeDecodeTest(1L)
+ encodeDecodeTest(1.toDouble)
+ encodeDecodeTest(1.toFloat)
+ encodeDecodeTest(true)
+ encodeDecodeTest(false)
+ encodeDecodeTest(1.toShort)
+ encodeDecodeTest(1.toByte)
encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
@@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite {
{ (l, r) => l._2.toString == r._2.toString }
/** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
- protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) =
+ protected def encodeDecodeTest[T : TypeTag](inputData: T) =
encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
/**
* Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
* matches the original.
*/
- protected def encodeDecodeTestCustom[T <: Product : TypeTag](
+ protected def encodeDecodeTestCustom[T : TypeTag](
inputData: T)(
c: (T, T) => Boolean) = {
- test(s"encode/decode: $inputData") {
- val encoder = try ProductEncoder[T] catch {
+ test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
+ val encoder = try ExpressionEncoder[T]() catch {
case e: Exception =>
fail(s"Exception thrown generating encoder", e)
}
val convertedData = encoder.toRow(inputData)
val schema = encoder.schema.toAttributes
- val boundEncoder = encoder.bind(schema)
+ val boundEncoder = encoder.resolve(schema).bind(schema)
val convertedBack = try boundEncoder.fromRow(convertedData) catch {
case e: Exception =>
fail(
@@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite {
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
- |Construct Expressions:
- |${boundEncoder.constructExpression.treeString}
+ |Encoder:
+ |$boundEncoder
|
""".stripMargin, e)
}
if (!c(inputData, convertedBack)) {
- val types =
- convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+ val types = convertedBack match {
+ case c: Product =>
+ c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+ case other => other.getClass.getName
+ }
+
val encodedData = try {
convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
@@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite {
|${encoder.schema.treeString}
|
|Extract Expressions:
- |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
- |
- |Construct Expressions:
- |${boundEncoder.constructExpression.treeString}
- |
+ |$boundEncoder
""".stripMargin)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 32d9b0b1d9..aa817a037e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -267,7 +267,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+ def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96213c7630..e0ab5f593e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.StructType
@@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType
* @since 1.6.0
*/
@Experimental
-class Dataset[T] private[sql](
+class Dataset[T] private(
@transient val sqlContext: SQLContext,
- @transient val queryExecution: QueryExecution)(
- implicit val encoder: Encoder[T]) extends Serializable {
+ @transient val queryExecution: QueryExecution,
+ unresolvedEncoder: Encoder[T]) extends Serializable {
+
+ /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+ private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
+ case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
+ case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
+ }
private implicit def classTag = encoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
- this(sqlContext, new QueryExecution(sqlContext, plan))
+ this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
/** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
def schema: StructType = encoder.schema
@@ -76,7 +83,9 @@ class Dataset[T] private[sql](
* TODO: document binding rules
* @since 1.6.0
*/
- def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+ def as[U : Encoder]: Dataset[U] = {
+ new Dataset(sqlContext, queryExecution, encoderFor[U])
+ }
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
@@ -103,7 +112,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = implicitly[Encoder[T]]
+ val tEnc = encoderFor[T]
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
@@ -150,9 +159,9 @@ class Dataset[T] private[sql](
sqlContext,
MapPartitions[T, U](
func,
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[T],
+ encoderFor[U],
+ encoderFor[U].schema.toAttributes,
logicalPlan))
}
@@ -209,8 +218,8 @@ class Dataset[T] private[sql](
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
- implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
- implicitly[Encoder[T]].bind(inputPlan.output),
+ encoderFor[K].resolve(withGroupingKey.newColumns),
+ encoderFor[T].bind(inputPlan.output),
executed,
inputPlan.output,
withGroupingKey.newColumns)
@@ -221,6 +230,18 @@ class Dataset[T] private[sql](
* ****************** */
/**
+ * Selects a set of column based expressions.
+ * {{{
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ * @group dfops
+ * @since 1.3.0
+ */
+ // Copied from Dataframe to make sure we don't have invalid overloads.
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+
+ /**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
*
* {{{
@@ -233,88 +254,64 @@ class Dataset[T] private[sql](
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
}
- // Codegen
- // scalastyle:off
-
- /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
- private def genSelect: String = {
- (2 to 5).map { n =>
- val types = (1 to n).map(i =>s"U$i").mkString(", ")
- val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
- val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
- val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
- s"""
- |/**
- | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- | * @since 1.6.0
- | */
- |def select[$types]($args): Dataset[($types)] = {
- | implicit val te = new Tuple${n}Encoder($encoders)
- | new Dataset[($types)](sqlContext,
- | Project(
- | $schema :: Nil,
- | logicalPlan))
- |}
- |
- """.stripMargin
- }.mkString("\n")
+ /**
+ * Internal helper function for building typed selects that return tuples. For simplicity and
+ * code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ */
+ protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+ val unresolvedPlan = Project(aliases, logicalPlan)
+ val execution = new QueryExecution(sqlContext, unresolvedPlan)
+ // Rebind the encoders to the nested schema that will be produced by the select.
+ val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+ case (e: ExpressionEncoder[_], a) if !e.flat =>
+ e.nested(a.toAttribute).resolve(execution.analyzed.output)
+ case (e, a) =>
+ e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
+ }
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
}
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
- implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
- new Dataset[(U1, U2)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
- implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
- new Dataset[(U1, U2, U3)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
- implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
- new Dataset[(U1, U2, U3, U4)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3, U4](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
- implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
- new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
- logicalPlan))
- }
-
- // scalastyle:on
+ def select[U1, U2, U3, U4, U5](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4],
+ c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
/* **************** *
* Set operations *
@@ -360,6 +357,48 @@ class Dataset[T] private[sql](
*/
def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
+ /* ****** *
+ * Joins *
+ * ****** */
+
+ /**
+ * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
+ * true.
+ *
+ * This is similar to the relation `join` function with one important difference in the
+ * result schema. Since `joinWith` preserves objects present on either side of the join, the
+ * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types as well as working with relational data where either side of the join has column
+ * names in common.
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ val left = this.logicalPlan
+ val right = other.logicalPlan
+
+ val leftData = this.encoder match {
+ case e if e.flat => Alias(left.output.head, "_1")()
+ case _ => Alias(CreateStruct(left.output), "_1")()
+ }
+ val rightData = other.encoder match {
+ case e if e.flat => Alias(right.output.head, "_2")()
+ case _ => Alias(CreateStruct(right.output), "_2")()
+ }
+ val leftEncoder =
+ if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
+ val rightEncoder =
+ if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(leftEncoder, rightEncoder)
+
+ withPlan[(T, U)](other) { (left, right) =>
+ Project(
+ leftData :: rightData :: Nil,
+ Join(left, right, Inner, Some(condition.expr)))
+ }
+ }
+
/* ************************** *
* Gather to Driver Actions *
* ************************** */
@@ -380,13 +419,10 @@ class Dataset[T] private[sql](
private[sql] def logicalPlan = queryExecution.analyzed
private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
private[sql] def withPlan[R : Encoder](
other: Dataset[_])(
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](
- sqlContext,
- sqlContext.executePlan(
- f(logicalPlan, other.logicalPlan)))
+ new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5e7198f974..2cb94430e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -491,7 +491,7 @@ class SQLContext private[sql](
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
- val enc = implicitly[Encoder[T]]
+ val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index af8474df0d..f460a86414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
- implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
- implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
- implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
+ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
+ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
+ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
+ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
+ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
+ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
DatasetHolder(_sqlContext.createDataset(s))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2bb3dba5bd..89938471ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
@@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -337,8 +337,8 @@ case class MapPartitions[T, U](
*/
case class AppendColumns[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -363,9 +363,9 @@ case class AppendColumns[T, U](
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08496249c6..aebb390a1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
data: _*)
}
+ test("as tuple") {
+ val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+ checkAnswer(
+ data.as[(String, Int)],
+ ("a", 1), ("b", 2))
+ }
+
test("as case class / collect") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
- test("select 3") {
+ test("select 2") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
ds.select(
expr("_1").as[String],
- expr("_2").as[Int],
- expr("_2 + 1").as[Int]),
- ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+ expr("_2").as[Int]) : Dataset[(String, Int)],
+ ("a", 1), ("b", 2), ("c", 3))
+ }
+
+ test("select 2, primitive and tuple") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("struct(_2, _2)").as[(Int, Int)]),
+ ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+ }
+
+ test("select 2, primitive and class") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
+ }
+
+ test("select 2, primitive and class, fields reordered") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkDecoding(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
}
test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
+ test("joinWith, flat schema") {
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a.value" === $"b.value"),
+ (1, 1), (2, 2))
+ }
+
+ test("joinWith, expression condition") {
+ val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"_1" === $"a"),
+ (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+ }
+
+ test("joinWith tuple with primitive, expression") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"_2"),
+ (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+ }
+
+ test("joinWith class with primitive, toDF") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+ Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+ }
+
+ test("multi-level joinWith") {
+ val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+ val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
+ ((("a", 1), ("a", 1)), ("a", 1)),
+ ((("b", 2), ("b", 2)), ("b", 2)))
+
+ }
+
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupBy(v => (1, v._2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba567512f..73e02eb0d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
abstract class QueryTest extends PlanTest {
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
}
}
- protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer.
+ * - Special handling is done based on whether the query plan should be expected to return
+ * the results in sorted order.
+ * - This function also checks to make sure that the schema for serializing the expected answer
+ * matches that produced by the dataset (i.e. does manual construction of object match
+ * the constructed encoder for cases like joins, etc). Note that this means that it will fail
+ * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
+ * which performs a subset of the checks done by this function.
+ */
+ protected def checkAnswer[T : Encoder](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+ checkDecoding(ds, expectedAnswer: _*)
+ }
+
+ protected def checkDecoding[T](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
+ val decoded = try ds.collect().toSet catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception collecting dataset as objects
+ |${ds.encoder}
+ |${ds.encoder.constructExpression.treeString}
+ |${ds.queryExecution}
+ """.stripMargin, e)
+ }
+
+ if (decoded != expectedAnswer.toSet) {
+ fail(
+ s"""Decoded objects do not match expected objects:
+ |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |${ds.encoder.constructExpression.treeString}
+ """.stripMargin)
+ }
}
/**