author       Michael Armbrust <michael@databricks.com>    2015-10-27 13:28:52 -0700
committer    Yin Huai <yhuai@databricks.com>    2015-10-27 13:28:52 -0700
commit       5a5f65905a202e59bc85170b01c57a883718ddf6 (patch)
tree         dcd1f9958573a0e3b419805609495fc1380b1565
parent       3bdbbc6c972567861044dd6a6dc82f35cd12442d (diff)
download     spark-5a5f65905a202e59bc85170b01c57a883718ddf6.tar.gz
             spark-5a5f65905a202e59bc85170b01c57a883718ddf6.tar.bz2
             spark-5a5f65905a202e59bc85170b01c57a883718ddf6.zip
[SPARK-11347] [SQL] Support for joinWith in Datasets
This PR adds a new operation `joinWith` to a `Dataset`, which returns a `Tuple` for each pair where a given `condition` evaluates to true.

```scala
case class ClassData(a: String, b: Int)
val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()

> ds1.joinWith(ds2, $"_1" === $"a").collect()
res0: Array((ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
```

This operation is similar to the relational `join` function, with one important difference in the result schema. Since `joinWith` preserves objects present on either side of the join, the result schema is similarly nested into a tuple under the column names `_1` and `_2`. This type of join can be useful both for preserving type-safety with the original object types and for working with relational data where either side of the join has column names in common.

## Required Changes to Encoders

In the process of working on this patch, several deficiencies in the way we were handling encoders were discovered. Specifically, it turned out to be very difficult to `rebind` the non-expression-based encoders to extract the nested objects from the results of joins (and also typed selects that return tuples). As a result, the following changes were made.

 - `ClassEncoder` has been renamed to `ExpressionEncoder` and has been improved to also handle primitive types. Additionally, it is now possible to take arbitrary expression encoders and rewrite them into a single encoder that returns a tuple.
 - All internal operations on `Dataset`s now require an `ExpressionEncoder`. If the user tries to pass a non-`ExpressionEncoder` in, an error will be thrown. We can relax this requirement in the future by constructing a wrapper class that uses expressions to project the row to the expected schema, shielding the user's code from the required remapping. This will give us a nice balance where we don't force user encoders to understand attribute references and binding, but still allow our native encoder to leverage runtime code generation to construct specific encoders for a given schema that avoid an extra remapping step.
 - Additionally, the semantics for different types of objects are now better defined. As stated in the `ExpressionEncoder` scaladoc:
   - Classes will have their subfields extracted by name using `UnresolvedAttribute` and `UnresolvedExtractValue` expressions.
   - Tuples will have their subfields extracted by position using `BoundReference` expressions.
   - Primitives will have their values extracted from the first ordinal, with a schema that defaults to the name `value`.
 - Finally, the binding lifecycle for `Encoder`s has now been unified across the codebase. Encoders are now `resolved` to the appropriate schema in the constructor of `Dataset`. This process replaces unresolved expressions with concrete `AttributeReference` expressions. Binding then happens on demand, when an encoder is going to be used to construct an object. This closely mirrors the lifecycle for standard expressions when executing normal SQL or `DataFrame` queries.

Author: Michael Armbrust <michael@databricks.com>

Closes #9300 from marmbrus/datasets-tuples.
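For reference, a minimal round-trip sketch of the new resolve-then-bind lifecycle (not part of the patch; it mirrors what the updated `ExpressionEncoderSuite` below does, and assumes `ClassData` is a top-level case class with the catalyst classes on the classpath):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class ClassData(a: String, b: Int)

// Build an encoder via reflection, encode an object into an InternalRow, then
// resolve and bind the construction expressions before decoding it back.
val encoder = ExpressionEncoder[ClassData]()
val row     = encoder.toRow(ClassData("a", 1))
val attrs   = encoder.schema.toAttributes
val decoded = encoder.resolve(attrs).bind(attrs).fromRow(row)
// decoded == ClassData("a", 1)
```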
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala  43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala  101
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala  38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala  217
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala  47
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala)  29
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala  100
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala  173
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  28
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala)  39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala  190
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala  89
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala  44
18 files changed, 563 insertions, 615 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c25161ee81..9cbb7c2ffd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -146,6 +146,10 @@ trait ScalaReflection {
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
+ *
+ * When used on a primitive type, the constructor will instead default to extracting the value
+ * from ordinal 0 (since there are no names to map to). The actual location can be moved by
+ * calling unbind/bind with a new schema.
*/
def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
@@ -159,8 +163,14 @@ trait ScalaReflection {
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
+ /** Returns the current path with a field at ordinal extracted. */
+ def addToPathOrdinal(ordinal: Int, dataType: DataType) =
+ path
+ .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal))
+ .getOrElse(BoundReference(ordinal, dataType, false))
+
/** Returns the current path or throws an error. */
- def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+ def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true))
tpe match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
@@ -387,12 +397,17 @@ trait ScalaReflection {
val className: String = t.erasure.typeSymbol.asClass.fullName
val cls = Utils.classForName(className)
- val arguments = params.head.map { p =>
+ val arguments = params.head.zipWithIndex.map { case (p, i) =>
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- val dataType = dataTypeFor(fieldType)
+ val dataType = schemaFor(fieldType).dataType
- constructorFor(fieldType, Some(addToPath(fieldName)))
+ // For tuples, we grab the inner fields by ordinal instead of name.
+ if (className startsWith "scala.Tuple") {
+ constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+ } else {
+ constructorFor(fieldType, Some(addToPath(fieldName)))
+ }
}
val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
@@ -413,7 +428,10 @@ trait ScalaReflection {
/** Returns expressions for extracting all the fields from the given type. */
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
ScalaReflectionLock.synchronized {
- extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct]
+ extractorFor(inputObject, typeTag[T].tpe) match {
+ case s: CreateNamedStruct => s
+ case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil)
+ }
}
}
@@ -602,6 +620,21 @@ trait ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
+ case t if t <:< definitions.IntTpe =>
+ BoundReference(0, IntegerType, false)
+ case t if t <:< definitions.LongTpe =>
+ BoundReference(0, LongType, false)
+ case t if t <:< definitions.DoubleTpe =>
+ BoundReference(0, DoubleType, false)
+ case t if t <:< definitions.FloatTpe =>
+ BoundReference(0, FloatType, false)
+ case t if t <:< definitions.ShortTpe =>
+ BoundReference(0, ShortType, false)
+ case t if t <:< definitions.ByteTpe =>
+ BoundReference(0, ByteType, false)
+ case t if t <:< definitions.BooleanTpe =>
+ BoundReference(0, BooleanType, false)
+
case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
}
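A rough sketch (not from the patch) of how the reflection entry points above behave for the three cases described in the scaladoc; the expression shapes in the comments are paraphrased, and `ClassData` is assumed to be a top-level case class:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

case class ClassData(a: String, b: Int)

// Case classes: fields are looked up by name (UnresolvedAttribute / UnresolvedExtractValue).
val byName = ScalaReflection.constructorFor[ClassData]

// Tuples: inner fields are extracted by ordinal (GetStructField / BoundReference).
val byOrdinal = ScalaReflection.constructorFor[(String, Int)]

// Primitives: the value is read from ordinal 0, since there is no field name to map to.
val fromOrdinalZero = ScalaReflection.constructorFor[Int]
```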
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
deleted file mode 100644
index b484b8fde6..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
- schema: StructType,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
- clsTag: ClassTag[T])
- extends Encoder[T] {
-
- @transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
- private val inputRow = new GenericMutableRow(1)
-
- @transient
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
- private val dataType = ObjectType(clsTag.runtimeClass)
-
- override def toRow(t: T): InternalRow = {
- inputRow(0) = t
- extractProjection(inputRow)
- }
-
- override def fromRow(row: InternalRow): T = {
- constructProjection(row).get(0, dataType).asInstanceOf[T]
- }
-
- override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
- val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
- val analyzedPlan = SimpleAnalyzer.execute(plan)
- val resolvedExpression = analyzedPlan.expressions.head.children.head
- val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-
- copy(constructExpression = boundExpression)
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(oldSchema)
- val attributeToNewPosition = AttributeMap.byIndex(newSchema)
- copy(constructExpression = constructExpression transform {
- case r: BoundReference =>
- r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
- })
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = {
- var remaining = schema
- copy(constructExpression = constructExpression transform {
- case u: UnresolvedAttribute =>
- val pos = remaining.head
- remaining = remaining.drop(1)
- pos
- })
- }
-
- protected val attrs = extractExpressions.map(_.collect {
- case a: Attribute => s"#${a.exprId}"
- case b: BoundReference => s"[${b.ordinal}]"
- }.headOption.getOrElse(""))
-
-
- protected val schemaString =
- schema
- .zip(attrs)
- .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
-
- override def toString: String = s"class[$schemaString]"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index efb872ddb8..329a132d3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -18,10 +18,9 @@
package org.apache.spark.sql.catalyst.encoders
+
import scala.reflect.ClassTag
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
/**
@@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType
* Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
* and reuse internal buffers to improve performance.
*/
-trait Encoder[T] {
+trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
def schema: StructType
/** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
-
- /**
- * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
- * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
- * copy the result before making another call if required.
- */
- def toRow(t: T): InternalRow
-
- /**
- * Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must `bind` an encoder to a specific schema before you can call this function.
- */
- def fromRow(row: InternalRow): T
-
- /**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
- * given schema.
- */
- def bind(schema: Seq[Attribute]): Encoder[T]
-
- /**
- * Binds this encoder to the given schema positionally. In this binding, the first reference to
- * any input is mapped to `schema(0)`, and so on for each input that is encountered.
- */
- def bindOrdinals(schema: Seq[Attribute]): Encoder[T]
-
- /**
- * Given an encoder that has already been bound to a given schema, returns a new encoder that
- * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
- * when you are trying to use an encoder on grouping keys that were orriginally part of a larger
- * row, but now you have projected out only the key expressions.
- */
- def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
new file mode 100644
index 0000000000..c287aebeee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+
+/**
+ * A factory for constructing encoders that convert objects and primitives to and from the
+ * internal row format using catalyst expressions and code generation. By default, the
+ * expressions used to retrieve values from an input row when producing an object will be created as
+ * follows:
+ * - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions
+ * and [[UnresolvedExtractValue]] expressions.
+ * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions.
+ * - Primitives will have their values extracted from the first ordinal with a schema that defaults
+ * to the name `value`.
+ */
+object ExpressionEncoder {
+ def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = {
+ // We convert the not-serializable TypeTag into StructType and ClassTag.
+ val mirror = typeTag[T].mirror
+ val cls = mirror.runtimeClass(typeTag[T].tpe)
+
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
+ val constructExpression = ScalaReflection.constructorFor[T]
+
+ new ExpressionEncoder[T](
+ extractExpression.dataType,
+ flat,
+ extractExpression.flatten,
+ constructExpression,
+ ClassTag[T](cls))
+ }
+
+ /**
+ * Given a set of N encoders, constructs a new encoder that produces objects as items in an
+ * N-tuple. Note that these encoders should first be bound correctly to the combined input
+ * schema.
+ */
+ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ val schema =
+ StructType(
+ encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+ val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+ val extractExpressions = encoders.map {
+ case e if e.flat => e.extractExpressions.head
+ case other => CreateStruct(other.extractExpressions)
+ }
+ val constructExpression =
+ NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+
+ new ExpressionEncoder[Any](
+ schema,
+ false,
+ extractExpressions,
+ constructExpression,
+ ClassTag.apply(cls))
+ }
+
+ /** A helper for producing encoders of Tuple2 from other encoders. */
+ def tuple[T1, T2](
+ e1: ExpressionEncoder[T1],
+ e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
+ tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+}
+
+/**
+ * A generic encoder for JVM objects.
+ *
+ * @param schema The schema after converting `T` to a Spark SQL row.
+ * @param extractExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object.
+ * @param clsTag A classtag for `T`.
+ */
+case class ExpressionEncoder[T](
+ schema: StructType,
+ flat: Boolean,
+ extractExpressions: Seq[Expression],
+ constructExpression: Expression,
+ clsTag: ClassTag[T])
+ extends Encoder[T] {
+
+ if (flat) require(extractExpressions.size == 1)
+
+ @transient
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private val inputRow = new GenericMutableRow(1)
+
+ @transient
+ private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+
+ /**
+ * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
+ * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
+ * copy the result before making another call if required.
+ */
+ def toRow(t: T): InternalRow = {
+ inputRow(0) = t
+ extractProjection(inputRow)
+ }
+
+ /**
+ * Returns an object of type `T`, extracting the required values from the provided row. Note that
+ * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+ * function.
+ */
+ def fromRow(row: InternalRow): T = try {
+ constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+ }
+
+ /**
+ * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+ * given schema.
+ */
+ def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ copy(constructExpression = analyzedPlan.expressions.head.children.head)
+ }
+
+ /**
+ * Returns a copy of this encoder where the expressions used to construct an object from an input
+ * row have been bound to the ordinals of the given schema. Note that you need to first call
+ * resolve before bind.
+ */
+ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
+ }
+
+ /**
+ * Replaces any bound references in the schema with the attributes at the corresponding ordinal
+ * in the provided schema. This can be used to "relocate" a given encoder to pull values from
+ * a different schema than it was initially bound to. It can also be used to assign attributes
+ * to ordinal based extraction (i.e. because the input data was a tuple).
+ */
+ def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(schema)
+ copy(constructExpression = constructExpression transform {
+ case b: BoundReference => positionToAttribute(b.ordinal)
+ })
+ }
+
+ /**
+ * Given an encoder that has already been bound to a given schema, returns a new encoder
+ * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
+ * when you are trying to use an encoder on grouping keys that were originally part of a larger
+ * row, but now you have projected out only the key expressions.
+ */
+ def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
+ val positionToAttribute = AttributeMap.toIndex(oldSchema)
+ val attributeToNewPosition = AttributeMap.byIndex(newSchema)
+ copy(constructExpression = constructExpression transform {
+ case r: BoundReference =>
+ r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
+ })
+ }
+
+ /**
+ * Returns a copy of this encoder where the expressions used to create an object given an
+ * input row have been modified to pull the object out from a nested struct, instead of the
+ * top level fields.
+ */
+ def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
+ copy(constructExpression = constructExpression transform {
+ case u: Attribute if u != input =>
+ UnresolvedExtractValue(input, Literal(u.name))
+ case b: BoundReference if b != input =>
+ GetStructField(
+ input,
+ StructField(s"i[${b.ordinal}]", b.dataType),
+ b.ordinal)
+ })
+ }
+
+ protected val attrs = extractExpressions.flatMap(_.collect {
+ case _: UnresolvedAttribute => ""
+ case a: Attribute => s"#${a.exprId}"
+ case b: BoundReference => s"[${b.ordinal}]"
+ })
+
+ protected val schemaString =
+ schema
+ .zip(attrs)
+ .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
+
+ override def toString: String = s"class[$schemaString]"
+}
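As an illustration of the `tuple` factory above (a sketch, not part of the diff), two encoders can be combined into one that decodes each row into a `Tuple2`; per the scaladoc, the component encoders are assumed to be resolved and bound against the combined input schema before `fromRow` is used:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class ClassData(a: String, b: Int)

// An assumed flat encoder for Int plus a struct-backed encoder for a case class...
val intEncoder  = ExpressionEncoder[Int](flat = true)
val dataEncoder = ExpressionEncoder[ClassData]()

// ...combined so that decoded objects come back as the two sides of a Tuple2.
val pairEncoder: ExpressionEncoder[(Int, ClassData)] =
  ExpressionEncoder.tuple(intEncoder, dataEncoder)
```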
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
deleted file mode 100644
index 34f5e6c030..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{ObjectType, StructType}
-
-/**
- * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL
- * internal binary representation.
- */
-object ProductEncoder {
- def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
- // We convert the not-serializable TypeTag into StructType and ClassTag.
- val mirror = typeTag[T].mirror
- val cls = mirror.runtimeClass(typeTag[T].tpe)
-
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
- val constructExpression = ScalaReflection.constructorFor[T]
-
- new ClassEncoder[T](
- extractExpression.dataType,
- extractExpression.flatten,
- constructExpression,
- ClassTag[T](cls))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e9cc00a2b6..0b42130a01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
* internal binary representation.
*/
object RowEncoder {
- def apply(schema: StructType): ClassEncoder[Row] = {
+ def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val extractExpressions = extractorsFor(inputObject, schema)
val constructExpression = constructorFor(schema)
- new ClassEncoder[Row](
+ new ExpressionEncoder[Row](
schema,
+ flat = false,
extractExpressions.asInstanceOf[CreateStruct].children,
constructExpression,
ClassTag(cls))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 52f8383fac..d4642a5006 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -15,29 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql.catalyst
-import org.apache.spark.SparkFunSuite
-
-class PrimitiveEncoderSuite extends SparkFunSuite {
- test("long encoder") {
- val enc = new LongEncoder()
- val row = enc.toRow(10)
- assert(row.getLong(0) == 10)
- assert(enc.fromRow(row) == 10)
- }
-
- test("int encoder") {
- val enc = new IntEncoder()
- val row = enc.toRow(10)
- assert(row.getInt(0) == 10)
- assert(enc.fromRow(row) == 10)
- }
-
- test("string encoder") {
- val enc = new StringEncoder()
- val row = enc.toRow("test")
- assert(row.getString(0) == "test")
- assert(enc.fromRow(row) == "test")
+package object encoders {
+ private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
+ case e: ExpressionEncoder[A] => e
+ case _ => sys.error(s"Only expression encoders are supported today")
}
}
+
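For context, a hypothetical helper (not in the patch) showing how internal code now uses `encoderFor`: it narrows the implicit `Encoder[T]` down to an `ExpressionEncoder`, failing fast for anything else, exactly as `AppendColumn` and `MapGroups` do below. The object name and package placement are illustrative only, since `encoderFor` is `private[sql]`:

```scala
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.Attribute

// Hypothetical helper mirroring the logical operators' usage of encoderFor.
private[sql] object EncoderSchemas {
  def attributesOf[T : Encoder]: Seq[Attribute] = encoderFor[T].schema.toAttributes
}
```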
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
deleted file mode 100644
index a93f2d7c61..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.sql.types._
-
-/** An encoder for primitive Long types. */
-case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] {
- private val row = UnsafeRow.createFromByteArray(64, 1)
-
- override def clsTag: ClassTag[Long] = ClassTag.Long
- override def schema: StructType =
- StructType(StructField(fieldName, LongType) :: Nil)
-
- override def fromRow(row: InternalRow): Long = row.getLong(ordinal)
-
- override def toRow(t: Long): InternalRow = {
- row.setLong(ordinal, t)
- row
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this
- override def bind(schema: Seq[Attribute]): Encoder[Long] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this
-}
-
-/** An encoder for primitive Integer types. */
-case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] {
- private val row = UnsafeRow.createFromByteArray(64, 1)
-
- override def clsTag: ClassTag[Int] = ClassTag.Int
- override def schema: StructType =
- StructType(StructField(fieldName, IntegerType) :: Nil)
-
- override def fromRow(row: InternalRow): Int = row.getInt(ordinal)
-
- override def toRow(t: Int): InternalRow = {
- row.setInt(ordinal, t)
- row
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this
- override def bind(schema: Seq[Attribute]): Encoder[Int] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this
-}
-
-/** An encoder for String types. */
-case class StringEncoder(
- fieldName: String = "value",
- ordinal: Int = 0) extends Encoder[String] {
-
- val record = new SpecificMutableRow(StringType :: Nil)
-
- @transient
- lazy val projection =
- GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil)
-
- override def schema: StructType =
- StructType(
- StructField("value", StringType, nullable = false) :: Nil)
-
- override def clsTag: ClassTag[String] = scala.reflect.classTag[String]
-
-
- override final def fromRow(row: InternalRow): String = {
- row.getString(ordinal)
- }
-
- override final def toRow(value: String): InternalRow = {
- val utf8String = UTF8String.fromString(value)
- record(0) = utf8String
- // TODO: this is a bit of a hack to produce UnsafeRows
- projection(record)
- }
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this
- override def bind(schema: Seq[Attribute]): Encoder[String] = this
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
deleted file mode 100644
index a48eeda7d2..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala
+++ /dev/null
@@ -1,173 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.types.{StructField, StructType}
-
-// Most of this file is codegen.
-// scalastyle:off
-
-/**
- * A set of composite encoders that take sub encoders and map each of their objects to a
- * Scala tuple. Note that currently the implementation is fairly limited and only supports going
- * from an internal row to a tuple.
- */
-object TupleEncoder {
-
- /** Code generator for composite tuple encoders. */
- def main(args: Array[String]): Unit = {
- (2 to 5).foreach { i =>
- val types = (1 to i).map(t => s"T$t").mkString(", ")
- val tupleType = s"($types)"
- val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ")
- val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ")
- val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ")
-
- println(
- s"""
- |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] {
- | val schema = StructType(Array($fields))
- |
- | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType]
- |
- | def fromRow(row: InternalRow): $tupleType = {
- | ($fromRow)
- | }
- |
- | override def toRow(t: $tupleType): InternalRow =
- | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
- |
- | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = {
- | this
- | }
- |
- | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] =
- | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
- |
- |
- | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] =
- | throw new UnsupportedOperationException("Tuple Encoders only support bind.")
- |}
- """.stripMargin)
- }
- }
-}
-
-class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema)))
-
- def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)]
-
- def fromRow(row: InternalRow): (T1, T2) = {
- (e1.fromRow(row), e2.fromRow(row))
- }
-
- override def toRow(t: (T1, T2)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)]
-
- def fromRow(row: InternalRow): (T1, T2, T3) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)]
-
- def fromRow(row: InternalRow): (T1, T2, T3, T4) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3, T4)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
-
-
-class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] {
- val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema)))
-
- def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)]
-
- def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = {
- (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row))
- }
-
- override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow =
- throw new UnsupportedOperationException("Tuple Encoders only support fromRow.")
-
- override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = {
- this
- }
-
- override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-
-
- override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] =
- throw new UnsupportedOperationException("Tuple Encoders only support bind.")
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 21a55a5371..d2d3db0a44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
import org.apache.spark.sql.catalyst.plans._
@@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode {
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def missingInput: AttributeSet = AttributeSet.empty
@@ -460,8 +460,8 @@ case class MapPartitions[T, U](
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
- val attrs = implicitly[Encoder[U]].schema.toAttributes
- new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child)
+ val attrs = encoderFor[U].schema.toAttributes
+ new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
}
}
@@ -472,8 +472,8 @@ object AppendColumn {
*/
case class AppendColumn[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
@@ -488,11 +488,11 @@ object MapGroups {
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
- implicitly[Encoder[K]],
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
+ encoderFor[K],
+ encoderFor[T],
+ encoderFor[U],
groupingAttributes,
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[U].schema.toAttributes,
child)
}
}
@@ -504,9 +504,9 @@ object MapGroups {
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 008d0bea8a..a374da4da1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -47,7 +47,16 @@ case class RepeatedData(
case class SpecificCollection(l: List[Int])
-class ProductEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends SparkFunSuite {
+
+ encodeDecodeTest(1)
+ encodeDecodeTest(1L)
+ encodeDecodeTest(1.toDouble)
+ encodeDecodeTest(1.toFloat)
+ encodeDecodeTest(true)
+ encodeDecodeTest(false)
+ encodeDecodeTest(1.toShort)
+ encodeDecodeTest(1.toByte)
encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
@@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite {
{ (l, r) => l._2.toString == r._2.toString }
/** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
- protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) =
+ protected def encodeDecodeTest[T : TypeTag](inputData: T) =
encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
/**
* Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
* matches the original.
*/
- protected def encodeDecodeTestCustom[T <: Product : TypeTag](
+ protected def encodeDecodeTestCustom[T : TypeTag](
inputData: T)(
c: (T, T) => Boolean) = {
- test(s"encode/decode: $inputData") {
- val encoder = try ProductEncoder[T] catch {
+ test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
+ val encoder = try ExpressionEncoder[T]() catch {
case e: Exception =>
fail(s"Exception thrown generating encoder", e)
}
val convertedData = encoder.toRow(inputData)
val schema = encoder.schema.toAttributes
- val boundEncoder = encoder.bind(schema)
+ val boundEncoder = encoder.resolve(schema).bind(schema)
val convertedBack = try boundEncoder.fromRow(convertedData) catch {
case e: Exception =>
fail(
@@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite {
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
- |Construct Expressions:
- |${boundEncoder.constructExpression.treeString}
+ |Encoder:
+ |$boundEncoder
|
""".stripMargin, e)
}
if (!c(inputData, convertedBack)) {
- val types =
- convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+ val types = convertedBack match {
+ case c: Product =>
+ c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+ case other => other.getClass.getName
+ }
+
val encodedData = try {
convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
@@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite {
|${encoder.schema.treeString}
|
|Extract Expressions:
- |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
- |
- |Construct Expressions:
- |${boundEncoder.constructExpression.treeString}
- |
+ |$boundEncoder
""".stripMargin)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 32d9b0b1d9..aa817a037e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -267,7 +267,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+ def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96213c7630..e0ab5f593e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.StructType
@@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType
* @since 1.6.0
*/
@Experimental
-class Dataset[T] private[sql](
+class Dataset[T] private(
@transient val sqlContext: SQLContext,
- @transient val queryExecution: QueryExecution)(
- implicit val encoder: Encoder[T]) extends Serializable {
+ @transient val queryExecution: QueryExecution,
+ unresolvedEncoder: Encoder[T]) extends Serializable {
+
+ /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+ private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
+ case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
+ case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
+ }
private implicit def classTag = encoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
- this(sqlContext, new QueryExecution(sqlContext, plan))
+ this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
/** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
def schema: StructType = encoder.schema
@@ -76,7 +83,9 @@ class Dataset[T] private[sql](
* TODO: document binding rules
* @since 1.6.0
*/
- def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+ def as[U : Encoder]: Dataset[U] = {
+ new Dataset(sqlContext, queryExecution, encoderFor[U])
+ }
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
@@ -103,7 +112,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = implicitly[Encoder[T]]
+ val tEnc = encoderFor[T]
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
@@ -150,9 +159,9 @@ class Dataset[T] private[sql](
sqlContext,
MapPartitions[T, U](
func,
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[T],
+ encoderFor[U],
+ encoderFor[U].schema.toAttributes,
logicalPlan))
}
@@ -209,8 +218,8 @@ class Dataset[T] private[sql](
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
- implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
- implicitly[Encoder[T]].bind(inputPlan.output),
+ encoderFor[K].resolve(withGroupingKey.newColumns),
+ encoderFor[T].bind(inputPlan.output),
executed,
inputPlan.output,
withGroupingKey.newColumns)
@@ -221,6 +230,18 @@ class Dataset[T] private[sql](
* ****************** */
/**
+ * Selects a set of column based expressions.
+ * {{{
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ * @group dfops
+ * @since 1.3.0
+ */
+ // Copied from DataFrame to make sure we don't have invalid overloads.
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+
+ /**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
*
* {{{
@@ -233,88 +254,64 @@ class Dataset[T] private[sql](
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
}
- // Codegen
- // scalastyle:off
-
- /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
- private def genSelect: String = {
- (2 to 5).map { n =>
- val types = (1 to n).map(i =>s"U$i").mkString(", ")
- val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
- val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
- val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
- s"""
- |/**
- | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- | * @since 1.6.0
- | */
- |def select[$types]($args): Dataset[($types)] = {
- | implicit val te = new Tuple${n}Encoder($encoders)
- | new Dataset[($types)](sqlContext,
- | Project(
- | $schema :: Nil,
- | logicalPlan))
- |}
- |
- """.stripMargin
- }.mkString("\n")
+ /**
+ * Internal helper function for building typed selects that return tuples. For simplicity and
+ * code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ */
+ protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+ val unresolvedPlan = Project(aliases, logicalPlan)
+ val execution = new QueryExecution(sqlContext, unresolvedPlan)
+ // Rebind the encoders to the nested schema that will be produced by the select.
+ val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+ case (e: ExpressionEncoder[_], a) if !e.flat =>
+ e.nested(a.toAttribute).resolve(execution.analyzed.output)
+ case (e, a) =>
+ e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
+ }
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
}
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
- implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
- new Dataset[(U1, U2)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
- implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
- new Dataset[(U1, U2, U3)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
- implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
- new Dataset[(U1, U2, U3, U4)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3, U4](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
- implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
- new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
- logicalPlan))
- }
-
- // scalastyle:on
+ def select[U1, U2, U3, U4, U5](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4],
+ c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
/* **************** *
* Set operations *
@@ -360,6 +357,48 @@ class Dataset[T] private[sql](
*/
def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
+ /* ****** *
+ * Joins *
+ * ****** */
+
+ /**
+ * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
+ * true.
+ *
+ * This is similar to the relation `join` function with one important difference in the
+ * result schema. Since `joinWith` preserves objects present on either side of the join, the
+ * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types as well as working with relational data where either side of the join has column
+ * names in common.
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ val left = this.logicalPlan
+ val right = other.logicalPlan
+
+ val leftData = this.encoder match {
+ case e if e.flat => Alias(left.output.head, "_1")()
+ case _ => Alias(CreateStruct(left.output), "_1")()
+ }
+ val rightData = other.encoder match {
+ case e if e.flat => Alias(right.output.head, "_2")()
+ case _ => Alias(CreateStruct(right.output), "_2")()
+ }
+ val leftEncoder =
+ if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
+ val rightEncoder =
+ if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(leftEncoder, rightEncoder)
+
+ withPlan[(T, U)](other) { (left, right) =>
+ Project(
+ leftData :: rightData :: Nil,
+ Join(left, right, Inner, Some(condition.expr)))
+ }
+ }
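A hedged sketch of `joinWith` in the flat-encoder case, mirroring the `joinWith, flat schema` test added below. Assumptions: a `sqlContext` with its implicits in scope; the variable names are illustrative.

```scala
// Sketch, not part of the patch. With flat encoders, each side's single
// output column ("value") is aliased to _1 / _2, so rows decode back to Ints.
import org.apache.spark.sql.Dataset
import sqlContext.implicits._

val left  = Seq(1, 2, 3).toDS().as("a")
val right = Seq(1, 2).toDS().as("b")

val joined: Dataset[(Int, Int)] =
  left.joinWith(right, $"a.value" === $"b.value")

// joined.collect() would yield Array((1, 1), (2, 2)).
```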
+
/* ************************** *
* Gather to Driver Actions *
* ************************** */
@@ -380,13 +419,10 @@ class Dataset[T] private[sql](
private[sql] def logicalPlan = queryExecution.analyzed
private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
private[sql] def withPlan[R : Encoder](
other: Dataset[_])(
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](
- sqlContext,
- sqlContext.executePlan(
- f(logicalPlan, other.logicalPlan)))
+ new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5e7198f974..2cb94430e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -491,7 +491,7 @@ class SQLContext private[sql](
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
- val enc = implicitly[Encoder[T]]
+ val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
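For context, a small sketch of calling `createDataset` directly; the `Person` case class and its values are hypothetical, and the encoder is expected to come from the `SQLImplicits` changes below.

```scala
// Sketch, not part of the patch: `Person` is a hypothetical case class.
case class Person(name: String, age: Int)

import sqlContext.implicits._ // supplies an ExpressionEncoder[Person]

// encoderFor resolves the implicit Encoder to an ExpressionEncoder; toRow then
// converts each object into an InternalRow backing the LocalRelation above.
val people = sqlContext.createDataset(Seq(Person("ann", 30), Person("bo", 25)))
// people: org.apache.spark.sql.Dataset[Person]
```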
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index af8474df0d..f460a86414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
- implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
- implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
- implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
+ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
+ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
+ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
+ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
+ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
+ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
DatasetHolder(_sqlContext.createDataset(s))
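An illustrative sketch of the new flat primitive encoders working together with `localSeqToDatasetHolder`; the `value` column name in the comment is the default for flat encoders described in this commit. As before, a `sqlContext` with implicits imported is assumed.

```scala
// Sketch, not part of the patch: relies on the implicit flat encoders above.
import org.apache.spark.sql.Dataset
import sqlContext.implicits._

val longs: Dataset[Long] = Seq(1L, 2L, 3L).toDS()
val flags: Dataset[Boolean] = Seq(true, false).toDS()
val words: Dataset[String] = Seq("a", "b", "c").toDS()

// A flat encoder stores each value in a single column, so longs.schema would
// be a single field named "value" of LongType.
```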
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2bb3dba5bd..89938471ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
@@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -337,8 +337,8 @@ case class MapPartitions[T, U](
*/
case class AppendColumns[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -363,9 +363,9 @@ case class AppendColumns[T, U](
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08496249c6..aebb390a1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
data: _*)
}
+ test("as tuple") {
+ val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+ checkAnswer(
+ data.as[(String, Int)],
+ ("a", 1), ("b", 2))
+ }
+
test("as case class / collect") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
- test("select 3") {
+ test("select 2") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
ds.select(
expr("_1").as[String],
- expr("_2").as[Int],
- expr("_2 + 1").as[Int]),
- ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+ expr("_2").as[Int]) : Dataset[(String, Int)],
+ ("a", 1), ("b", 2), ("c", 3))
+ }
+
+ test("select 2, primitive and tuple") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("struct(_2, _2)").as[(Int, Int)]),
+ ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+ }
+
+ test("select 2, primitive and class") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
+ }
+
+ test("select 2, primitive and class, fields reordered") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkDecoding(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
}
test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
+ test("joinWith, flat schema") {
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a.value" === $"b.value"),
+ (1, 1), (2, 2))
+ }
+
+ test("joinWith, expression condition") {
+ val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"_1" === $"a"),
+ (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+ }
+
+ test("joinWith tuple with primitive, expression") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"_2"),
+ (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+ }
+
+ test("joinWith class with primitive, toDF") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+ Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+ }
+
+ test("multi-level joinWith") {
+ val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+ val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
+ ((("a", 1), ("a", 1)), ("a", 1)),
+ ((("b", 2), ("b", 2)), ("b", 2)))
+
+ }
+
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupBy(v => (1, v._2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba567512f..73e02eb0d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
abstract class QueryTest extends PlanTest {
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
}
}
- protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer.
+ * - Special handling is done based on whether the query plan should be expected to return
+ * the results in sorted order.
+ * - This function also checks to make sure that the schema for serializing the expected answer
+ * matches that produced by the dataset (i.e. whether manually constructing the objects matches
+ * the constructed encoder for cases like joins, etc). Note that this means the check will fail
+ * when fields are reordered. For such tests, use `checkDecoding` instead,
+ * which performs a subset of the checks done by this function.
+ */
+ protected def checkAnswer[T : Encoder](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+ checkDecoding(ds, expectedAnswer: _*)
+ }
+
+ protected def checkDecoding[T](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
+ val decoded = try ds.collect().toSet catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception collecting dataset as objects
+ |${ds.encoder}
+ |${ds.encoder.constructExpression.treeString}
+ |${ds.queryExecution}
+ """.stripMargin, e)
+ }
+
+ if (decoded != expectedAnswer.toSet) {
+ fail(
+ s"""Decoded objects do not match expected objects:
+ |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |Actual: ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |${ds.encoder.constructExpression.treeString}
+ """.stripMargin)
+ }
}
/**