aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-10-27 13:28:52 -0700
committerYin Huai <yhuai@databricks.com>2015-10-27 13:28:52 -0700
commit5a5f65905a202e59bc85170b01c57a883718ddf6 (patch)
treedcd1f9958573a0e3b419805609495fc1380b1565 /sql/core
parent3bdbbc6c972567861044dd6a6dc82f35cd12442d (diff)
downloadspark-5a5f65905a202e59bc85170b01c57a883718ddf6.tar.gz
spark-5a5f65905a202e59bc85170b01c57a883718ddf6.tar.bz2
spark-5a5f65905a202e59bc85170b01c57a883718ddf6.zip
[SPARK-11347] [SQL] Support for joinWith in Datasets
This PR adds a new operation `joinWith` to a `Dataset`, which returns a `Tuple` for each pair where a given `condition` evaluates to true. ```scala case class ClassData(a: String, b: Int) val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() > ds1.joinWith(ds2, $"_1" === $"a").collect() res0: Array((ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2))) ``` This operation is similar to the relation `join` function with one important difference in the result schema. Since `joinWith` preserves objects present on either side of the join, the result schema is similarly nested into a tuple under the column names `_1` and `_2`. This type of join can be useful both for preserving type-safety with the original object types as well as working with relational data where either side of the join has column names in common. ## Required Changes to Encoders In the process of working on this patch, several deficiencies to the way that we were handling encoders were discovered. Specifically, it turned out to be very difficult to `rebind` the non-expression based encoders to extract the nested objects from the results of joins (and also typed selects that return tuples). As a result the following changes were made. - `ClassEncoder` has been renamed to `ExpressionEncoder` and has been improved to also handle primitive types. Additionally, it is now possible to take arbitrary expression encoders and rewrite them into a single encoder that returns a tuple. - All internal operations on `Dataset`s now require an `ExpressionEncoder`. If the users tries to pass a non-`ExpressionEncoder` in, an error will be thrown. We can relax this requirement in the future by constructing a wrapper class that uses expressions to project the row to the expected schema, shielding the users code from the required remapping. This will give us a nice balance where we don't force user encoders to understand attribute references and binding, but still allow our native encoder to leverage runtime code generation to construct specific encoders for a given schema that avoid an extra remapping step. - Additionally, the semantics for different types of objects are now better defined. As stated in the `ExpressionEncoder` scaladoc: - Classes will have their sub fields extracted by name using `UnresolvedAttribute` expressions and `UnresolvedExtractValue` expressions. - Tuples will have their subfields extracted by position using `BoundReference` expressions. - Primitives will have their values extracted from the first ordinal with a schema that defaults to the name `value`. - Finally, the binding lifecycle for `Encoders` has now been unified across the codebase. Encoders are now `resolved` to the appropriate schema in the constructor of `Dataset`. This process replaces an unresolved expressions with concrete `AttributeReference` expressions. Binding then happens on demand, when an encoder is going to be used to construct an object. This closely mirrors the lifecycle for standard expressions when executing normal SQL or `DataFrame` queries. Author: Michael Armbrust <michael@databricks.com> Closes #9300 from marmbrus/datasets-tuples.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala190
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala89
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala44
7 files changed, 259 insertions, 99 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 32d9b0b1d9..aa817a037e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -267,7 +267,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+ def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96213c7630..e0ab5f593e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.StructType
@@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType
* @since 1.6.0
*/
@Experimental
-class Dataset[T] private[sql](
+class Dataset[T] private(
@transient val sqlContext: SQLContext,
- @transient val queryExecution: QueryExecution)(
- implicit val encoder: Encoder[T]) extends Serializable {
+ @transient val queryExecution: QueryExecution,
+ unresolvedEncoder: Encoder[T]) extends Serializable {
+
+ /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+ private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
+ case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
+ case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
+ }
private implicit def classTag = encoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
- this(sqlContext, new QueryExecution(sqlContext, plan))
+ this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
/** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
def schema: StructType = encoder.schema
@@ -76,7 +83,9 @@ class Dataset[T] private[sql](
* TODO: document binding rules
* @since 1.6.0
*/
- def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+ def as[U : Encoder]: Dataset[U] = {
+ new Dataset(sqlContext, queryExecution, encoderFor[U])
+ }
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
@@ -103,7 +112,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = implicitly[Encoder[T]]
+ val tEnc = encoderFor[T]
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
@@ -150,9 +159,9 @@ class Dataset[T] private[sql](
sqlContext,
MapPartitions[T, U](
func,
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[T],
+ encoderFor[U],
+ encoderFor[U].schema.toAttributes,
logicalPlan))
}
@@ -209,8 +218,8 @@ class Dataset[T] private[sql](
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
- implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
- implicitly[Encoder[T]].bind(inputPlan.output),
+ encoderFor[K].resolve(withGroupingKey.newColumns),
+ encoderFor[T].bind(inputPlan.output),
executed,
inputPlan.output,
withGroupingKey.newColumns)
@@ -221,6 +230,18 @@ class Dataset[T] private[sql](
* ****************** */
/**
+ * Selects a set of column based expressions.
+ * {{{
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ * @group dfops
+ * @since 1.3.0
+ */
+ // Copied from Dataframe to make sure we don't have invalid overloads.
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+
+ /**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
*
* {{{
@@ -233,88 +254,64 @@ class Dataset[T] private[sql](
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
}
- // Codegen
- // scalastyle:off
-
- /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
- private def genSelect: String = {
- (2 to 5).map { n =>
- val types = (1 to n).map(i =>s"U$i").mkString(", ")
- val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
- val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
- val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
- s"""
- |/**
- | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- | * @since 1.6.0
- | */
- |def select[$types]($args): Dataset[($types)] = {
- | implicit val te = new Tuple${n}Encoder($encoders)
- | new Dataset[($types)](sqlContext,
- | Project(
- | $schema :: Nil,
- | logicalPlan))
- |}
- |
- """.stripMargin
- }.mkString("\n")
+ /**
+ * Internal helper function for building typed selects that return tuples. For simplicity and
+ * code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ */
+ protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+ val unresolvedPlan = Project(aliases, logicalPlan)
+ val execution = new QueryExecution(sqlContext, unresolvedPlan)
+ // Rebind the encoders to the nested schema that will be produced by the select.
+ val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+ case (e: ExpressionEncoder[_], a) if !e.flat =>
+ e.nested(a.toAttribute).resolve(execution.analyzed.output)
+ case (e, a) =>
+ e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
+ }
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
}
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
- implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
- new Dataset[(U1, U2)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
- implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
- new Dataset[(U1, U2, U3)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
- implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
- new Dataset[(U1, U2, U3, U4)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3, U4](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
- implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
- new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
- logicalPlan))
- }
-
- // scalastyle:on
+ def select[U1, U2, U3, U4, U5](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4],
+ c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
/* **************** *
* Set operations *
@@ -360,6 +357,48 @@ class Dataset[T] private[sql](
*/
def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
+ /* ****** *
+ * Joins *
+ * ****** */
+
+ /**
+ * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
+ * true.
+ *
+ * This is similar to the relation `join` function with one important difference in the
+ * result schema. Since `joinWith` preserves objects present on either side of the join, the
+ * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types as well as working with relational data where either side of the join has column
+ * names in common.
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ val left = this.logicalPlan
+ val right = other.logicalPlan
+
+ val leftData = this.encoder match {
+ case e if e.flat => Alias(left.output.head, "_1")()
+ case _ => Alias(CreateStruct(left.output), "_1")()
+ }
+ val rightData = other.encoder match {
+ case e if e.flat => Alias(right.output.head, "_2")()
+ case _ => Alias(CreateStruct(right.output), "_2")()
+ }
+ val leftEncoder =
+ if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
+ val rightEncoder =
+ if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(leftEncoder, rightEncoder)
+
+ withPlan[(T, U)](other) { (left, right) =>
+ Project(
+ leftData :: rightData :: Nil,
+ Join(left, right, Inner, Some(condition.expr)))
+ }
+ }
+
/* ************************** *
* Gather to Driver Actions *
* ************************** */
@@ -380,13 +419,10 @@ class Dataset[T] private[sql](
private[sql] def logicalPlan = queryExecution.analyzed
private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
private[sql] def withPlan[R : Encoder](
other: Dataset[_])(
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](
- sqlContext,
- sqlContext.executePlan(
- f(logicalPlan, other.logicalPlan)))
+ new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5e7198f974..2cb94430e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -491,7 +491,7 @@ class SQLContext private[sql](
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
- val enc = implicitly[Encoder[T]]
+ val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index af8474df0d..f460a86414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
- implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
- implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
- implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
+ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
+ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
+ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
+ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
+ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
+ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
DatasetHolder(_sqlContext.createDataset(s))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2bb3dba5bd..89938471ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
@@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -337,8 +337,8 @@ case class MapPartitions[T, U](
*/
case class AppendColumns[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -363,9 +363,9 @@ case class AppendColumns[T, U](
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08496249c6..aebb390a1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
data: _*)
}
+ test("as tuple") {
+ val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+ checkAnswer(
+ data.as[(String, Int)],
+ ("a", 1), ("b", 2))
+ }
+
test("as case class / collect") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
- test("select 3") {
+ test("select 2") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
ds.select(
expr("_1").as[String],
- expr("_2").as[Int],
- expr("_2 + 1").as[Int]),
- ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+ expr("_2").as[Int]) : Dataset[(String, Int)],
+ ("a", 1), ("b", 2), ("c", 3))
+ }
+
+ test("select 2, primitive and tuple") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("struct(_2, _2)").as[(Int, Int)]),
+ ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+ }
+
+ test("select 2, primitive and class") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
+ }
+
+ test("select 2, primitive and class, fields reordered") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkDecoding(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
}
test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
+ test("joinWith, flat schema") {
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a.value" === $"b.value"),
+ (1, 1), (2, 2))
+ }
+
+ test("joinWith, expression condition") {
+ val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"_1" === $"a"),
+ (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+ }
+
+ test("joinWith tuple with primitive, expression") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"_2"),
+ (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+ }
+
+ test("joinWith class with primitive, toDF") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+ Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+ }
+
+ test("multi-level joinWith") {
+ val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+ val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
+ ((("a", 1), ("a", 1)), ("a", 1)),
+ ((("b", 2), ("b", 2)), ("b", 2)))
+
+ }
+
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupBy(v => (1, v._2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba567512f..73e02eb0d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
abstract class QueryTest extends PlanTest {
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
}
}
- protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer.
+ * - Special handling is done based on whether the query plan should be expected to return
+ * the results in sorted order.
+ * - This function also checks to make sure that the schema for serializing the expected answer
+ * matches that produced by the dataset (i.e. does manual construction of object match
+ * the constructed encoder for cases like joins, etc). Note that this means that it will fail
+ * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
+ * which performs a subset of the checks done by this function.
+ */
+ protected def checkAnswer[T : Encoder](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+ checkDecoding(ds, expectedAnswer: _*)
+ }
+
+ protected def checkDecoding[T](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
+ val decoded = try ds.collect().toSet catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception collecting dataset as objects
+ |${ds.encoder}
+ |${ds.encoder.constructExpression.treeString}
+ |${ds.queryExecution}
+ """.stripMargin, e)
+ }
+
+ if (decoded != expectedAnswer.toSet) {
+ fail(
+ s"""Decoded objects do not match expected objects:
+ |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |${ds.encoder.constructExpression.treeString}
+ """.stripMargin)
+ }
}
/**