path: root/sql/core
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                  | 190
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala               |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala             |  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala |  16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala             |  89
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                |  44
7 files changed, 259 insertions(+), 99 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 32d9b0b1d9..aa817a037e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -267,7 +267,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+ def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
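A usage sketch of the changed DataFrame.as[U] (not part of the patch): it assumes a spark-shell style session where sqlContext is in scope and Person is a hypothetical top-level case class. Because the Dataset is now built from the logical plan, the target encoder is resolved against the analyzed output schema.

    // Hypothetical illustration; Person and sqlContext are assumptions.
    import sqlContext.implicits._

    case class Person(name: String, age: Int)

    val df = Seq(Person("a", 1), Person("b", 2)).toDF()
    // as[U] constructs the Dataset from logicalPlan, so the Person encoder
    // is resolved against the analyzed output rather than reusing the
    // DataFrame's existing QueryExecution.
    val ds: Dataset[Person] = df.as[Person]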
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96213c7630..e0ab5f593e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.StructType
@@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType
* @since 1.6.0
*/
@Experimental
-class Dataset[T] private[sql](
+class Dataset[T] private(
@transient val sqlContext: SQLContext,
- @transient val queryExecution: QueryExecution)(
- implicit val encoder: Encoder[T]) extends Serializable {
+ @transient val queryExecution: QueryExecution,
+ unresolvedEncoder: Encoder[T]) extends Serializable {
+
+ /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+ private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
+ case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
+ case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
+ }
private implicit def classTag = encoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
- this(sqlContext, new QueryExecution(sqlContext, plan))
+ this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
/** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
def schema: StructType = encoder.schema
@@ -76,7 +83,9 @@ class Dataset[T] private[sql](
* TODO: document binding rules
* @since 1.6.0
*/
- def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+ def as[U : Encoder]: Dataset[U] = {
+ new Dataset(sqlContext, queryExecution, encoderFor[U])
+ }
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
@@ -103,7 +112,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = implicitly[Encoder[T]]
+ val tEnc = encoderFor[T]
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
@@ -150,9 +159,9 @@ class Dataset[T] private[sql](
sqlContext,
MapPartitions[T, U](
func,
- implicitly[Encoder[T]],
- implicitly[Encoder[U]],
- implicitly[Encoder[U]].schema.toAttributes,
+ encoderFor[T],
+ encoderFor[U],
+ encoderFor[U].schema.toAttributes,
logicalPlan))
}
@@ -209,8 +218,8 @@ class Dataset[T] private[sql](
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
- implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
- implicitly[Encoder[T]].bind(inputPlan.output),
+ encoderFor[K].resolve(withGroupingKey.newColumns),
+ encoderFor[T].bind(inputPlan.output),
executed,
inputPlan.output,
withGroupingKey.newColumns)
@@ -221,6 +230,18 @@ class Dataset[T] private[sql](
* ****************** */
/**
+ * Selects a set of column-based expressions.
+ * {{{
+ * df.select($"colA", $"colB" + 1)
+ * }}}
+ * @group dfops
+ * @since 1.3.0
+ */
+ // Copied from DataFrame to make sure we don't have invalid overloads.
+ @scala.annotation.varargs
+ def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+
+ /**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
*
* {{{
@@ -233,88 +254,64 @@ class Dataset[T] private[sql](
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
}
- // Codegen
- // scalastyle:off
-
- /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
- private def genSelect: String = {
- (2 to 5).map { n =>
- val types = (1 to n).map(i =>s"U$i").mkString(", ")
- val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
- val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
- val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
- s"""
- |/**
- | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- | * @since 1.6.0
- | */
- |def select[$types]($args): Dataset[($types)] = {
- | implicit val te = new Tuple${n}Encoder($encoders)
- | new Dataset[($types)](sqlContext,
- | Project(
- | $schema :: Nil,
- | logicalPlan))
- |}
- |
- """.stripMargin
- }.mkString("\n")
+ /**
+ * Internal helper function for building typed selects that return tuples. For simplicity and
+ * code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user-facing interface.
+ */
+ protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+ val unresolvedPlan = Project(aliases, logicalPlan)
+ val execution = new QueryExecution(sqlContext, unresolvedPlan)
+ // Rebind the encoders to the nested schema that will be produced by the select.
+ val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+ case (e: ExpressionEncoder[_], a) if !e.flat =>
+ e.nested(a.toAttribute).resolve(execution.analyzed.output)
+ case (e, a) =>
+ e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
+ }
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
}
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
- implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
- new Dataset[(U1, U2)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
- implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
- new Dataset[(U1, U2, U3)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
- implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
- new Dataset[(U1, U2, U3, U4)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
- logicalPlan))
- }
-
-
+ def select[U1, U2, U3, U4](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
- def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
- implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
- new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
- Project(
- Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
- logicalPlan))
- }
-
- // scalastyle:on
+ def select[U1, U2, U3, U4, U5](
+ c1: TypedColumn[U1],
+ c2: TypedColumn[U2],
+ c3: TypedColumn[U3],
+ c4: TypedColumn[U4],
+ c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
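A hedged sketch of the typed select path that selectUntyped enables, assuming a session with sqlContext.implicits._ and org.apache.spark.sql.functions.expr imported:

    import org.apache.spark.sql.functions.expr
    import sqlContext.implicits._

    val ds = Seq(("a", 1), ("b", 2)).toDS()
    // Each TypedColumn carries its own encoder; selectUntyped aliases the
    // expressions as _1, _2, ..., rebinds the encoders to the analyzed
    // output, and tuples them together.
    val pairs: Dataset[(String, Int)] =
      ds.select(expr("_1").as[String], expr("_2 + 1").as[Int])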
/* **************** *
* Set operations *
@@ -360,6 +357,48 @@ class Dataset[T] private[sql](
*/
def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
+ /* ****** *
+ * Joins *
+ * ****** */
+
+ /**
+ * Joins this [[Dataset]] with another [[Dataset]], returning a [[Tuple2]] for each pair
+ * where `condition` evaluates to true.
+ *
+ * This is similar to the relational `join` function with one important difference in the
+ * result schema. Since `joinWith` preserves objects present on either side of the join, the
+ * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types and for working with relational data where either side of the join has column
+ * names in common.
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ val left = this.logicalPlan
+ val right = other.logicalPlan
+
+ val leftData = this.encoder match {
+ case e if e.flat => Alias(left.output.head, "_1")()
+ case _ => Alias(CreateStruct(left.output), "_1")()
+ }
+ val rightData = other.encoder match {
+ case e if e.flat => Alias(right.output.head, "_2")()
+ case _ => Alias(CreateStruct(right.output), "_2")()
+ }
+ val leftEncoder =
+ if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
+ val rightEncoder =
+ if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(leftEncoder, rightEncoder)
+
+ withPlan[(T, U)](other) { (left, right) =>
+ Project(
+ leftData :: rightData :: Nil,
+ Join(left, right, Inner, Some(condition.expr)))
+ }
+ }
+
/* ************************** *
* Gather to Driver Actions *
* ************************** */
@@ -380,13 +419,10 @@ class Dataset[T] private[sql](
private[sql] def logicalPlan = queryExecution.analyzed
private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
private[sql] def withPlan[R : Encoder](
other: Dataset[_])(
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](
- sqlContext,
- sqlContext.executePlan(
- f(logicalPlan, other.logicalPlan)))
+ new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
}
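A usage sketch of the new joinWith API (assumes the same test-style implicits, with sqlContext in scope):

    import sqlContext.implicits._

    val left = Seq(("a", 1), ("b", 2)).toDS().as("l")
    val right = Seq(1, 2).toDS().as("r")

    // Unlike DataFrame.join, each side is kept as a whole object, so the
    // result is a Dataset of pairs nested under the column names _1 and _2.
    val joined: Dataset[((String, Int), Int)] =
      left.joinWith(right, $"l._2" === $"r.value")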
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5e7198f974..2cb94430e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -491,7 +491,7 @@ class SQLContext private[sql](
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
- val enc = implicitly[Encoder[T]]
+ val enc = encoderFor[T]
val attributes = enc.schema.toAttributes
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
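As an illustrative sketch (assuming the implicit encoders from sqlContext.implicits._ are in scope), createDataset eagerly serializes each local element through the resolved ExpressionEncoder before building the LocalRelation:

    import sqlContext.implicits._

    // Each Int is converted via encoderFor[Int].toRow before the
    // LocalRelation is built, so the local data is stored as InternalRows.
    val ds: Dataset[Int] = sqlContext.createDataset(Seq(1, 2, 3))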
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index af8474df0d..f460a86414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
- implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
- implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
- implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
+ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
+ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
+ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
+ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
+ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
+ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
DatasetHolder(_sqlContext.createDataset(s))
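A short sketch of what the added flat encoders enable, assuming a session where _sqlContext is bound (e.g. spark-shell):

    import sqlContext.implicits._

    // Primitive Datasets now work directly because Double, Float, Byte,
    // Short and Boolean all have implicit flat ExpressionEncoders.
    val doubles: Dataset[Double] = Seq(1.0, 2.5).toDS()
    val flags: Dataset[Boolean] = Seq(true, false).toDS()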
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2bb3dba5bd..89938471ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
@@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
*/
case class MapPartitions[T, U](
func: Iterator[T] => Iterator[U],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -337,8 +337,8 @@ case class MapPartitions[T, U](
*/
case class AppendColumns[T, U](
func: T => U,
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
@@ -363,9 +363,9 @@ case class AppendColumns[T, U](
*/
case class MapGroups[K, T, U](
func: (K, Iterator[T]) => Iterator[U],
- kEncoder: Encoder[K],
- tEncoder: Encoder[T],
- uEncoder: Encoder[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
+ uEncoder: ExpressionEncoder[U],
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
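A hedged sketch of code that exercises these physical nodes; the exact plan shape is an assumption, but mapPartitions on a Dataset plans a MapPartitions node carrying the ExpressionEncoders shown above:

    import sqlContext.implicits._

    val ds = Seq(1, 2, 3).toDS()
    // Plans a MapPartitions[Int, Int] node with ExpressionEncoder[Int]
    // on both the input and the output side.
    val doubled: Dataset[Int] = ds.mapPartitions(iter => iter.map(_ * 2))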
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08496249c6..aebb390a1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
data: _*)
}
+ test("as tuple") {
+ val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+ checkAnswer(
+ data.as[(String, Int)],
+ ("a", 1), ("b", 2))
+ }
+
test("as case class / collect") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
- test("select 3") {
+ test("select 2") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
ds.select(
expr("_1").as[String],
- expr("_2").as[Int],
- expr("_2 + 1").as[Int]),
- ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+ expr("_2").as[Int]) : Dataset[(String, Int)],
+ ("a", 1), ("b", 2), ("c", 3))
+ }
+
+ test("select 2, primitive and tuple") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("struct(_2, _2)").as[(Int, Int)]),
+ ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+ }
+
+ test("select 2, primitive and class") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
+ }
+
+ test("select 2, primitive and class, fields reordered") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkDecoding(
+ ds.select(
+ expr("_1").as[String],
+ expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+ ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
}
test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
+ test("joinWith, flat schema") {
+ val ds1 = Seq(1, 2, 3).toDS().as("a")
+ val ds2 = Seq(1, 2).toDS().as("b")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a.value" === $"b.value"),
+ (1, 1), (2, 2))
+ }
+
+ test("joinWith, expression condition") {
+ val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"_1" === $"a"),
+ (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+ }
+
+ test("joinWith tuple with primitive, expression") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"_2"),
+ (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+ }
+
+ test("joinWith class with primitive, toDF") {
+ val ds1 = Seq(1, 1, 2).toDS()
+ val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+ Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+ }
+
+ test("multi-level joinWith") {
+ val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+ val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+ val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+ checkAnswer(
+ ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
+ ((("a", 1), ("a", 1)), ("a", 1)),
+ ((("b", 2), ("b", 2)), ("b", 2)))
+
+ }
+
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupBy(v => (1, v._2))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba567512f..73e02eb0d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
abstract class QueryTest extends PlanTest {
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
}
}
- protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ /**
+ * Evaluates a dataset to make sure that the result of calling collect matches the given
+ * expected answer.
+ * - Special handling is done based on whether the query plan should be expected to return
+ * the results in sorted order.
+ * - This function also checks to make sure that the schema for serializing the expected answer
+ * matches that produced by the dataset (i.e. whether manual construction of the objects
+ * matches the constructed encoder, for cases like joins, etc.). Note that this means it will
+ * fail for cases where reordering is done on fields. For such tests, use `checkDecoding`
+ * instead, which performs a subset of the checks done by this function.
+ */
+ protected def checkAnswer[T : Encoder](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+ checkDecoding(ds, expectedAnswer: _*)
+ }
+
+ protected def checkDecoding[T](
+ ds: => Dataset[T],
+ expectedAnswer: T*): Unit = {
+ val decoded = try ds.collect().toSet catch {
+ case e: Exception =>
+ fail(
+ s"""
+ |Exception collecting dataset as objects
+ |${ds.encoder}
+ |${ds.encoder.constructExpression.treeString}
+ |${ds.queryExecution}
+ """.stripMargin, e)
+ }
+
+ if (decoded != expectedAnswer.toSet) {
+ fail(
+ s"""Decoded objects do not match expected objects:
+ |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |Actual:   ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
+ |${ds.encoder.constructExpression.treeString}
+ """.stripMargin)
+ }
}
/**