path: root/sql/core
author     Michael Armbrust <michael@databricks.com>    2015-10-22 15:20:17 -0700
committer  Reynold Xin <rxin@databricks.com>            2015-10-22 15:20:17 -0700
commit     53e83a3a77cafc2ccd0764ecdb8b3ba735bc51fc (patch)
tree       9e10bf6e96c5faaf51d52790acdd9adc71145b54 /sql/core
parent     188ea348fdcf877d86f3c433cd15f6468fe3b42a (diff)
[SPARK-11116][SQL] First Draft of Dataset API
*This PR adds a new experimental API to Spark, tentatively named Datasets.* A `Dataset` is a strongly-typed collection of objects that can be transformed in parallel using functional or relational operations. Example usage is as follows:

### Functional
```scala
scala> val ds: Dataset[Int] = Seq(1, 2, 3).toDS()
scala> ds.filter(_ % 1 == 0).collect()
res1: Array[Int] = Array(1, 2, 3)
```

### Relational
```scala
scala> ds.toDF().show()
+-----+
|value|
+-----+
|    1|
|    2|
|    3|
+-----+

scala> ds.select(expr("value + 1").as[Int]).collect()
res11: Array[Int] = Array(2, 3, 4)
```

## Comparison to RDDs
A `Dataset` differs from an `RDD` in the following ways:
 - The creation of a `Dataset` requires the presence of an explicit `Encoder` that can be used to serialize the object into a binary format. Encoders are also capable of mapping the schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime reflection based serialization.
 - Internally, a `Dataset` is represented by a Catalyst logical plan and the data is stored in the encoded form. This representation allows for additional logical operations and enables many operations (sorting, shuffling, etc.) to be performed without deserializing to an object.

A `Dataset` can be converted to an `RDD` by calling the `.rdd` method.

## Comparison to DataFrames
A `Dataset` can be thought of as a specialized DataFrame, where the elements map to a specific JVM object type, instead of to a generic `Row` container. A DataFrame can be transformed into a specific Dataset by calling `df.as[ElementType]`. Similarly, you can transform a strongly-typed `Dataset` into a generic DataFrame by calling `ds.toDF()`.

## Implementation Status and TODOs
This is a rough cut at the least controversial parts of the API. The primary purpose here is to get something committed so that we can better parallelize further work and get early feedback on the API. The following is being deferred to future PRs:
 - Joins and Aggregations (prototype here: https://github.com/apache/spark/commit/f11f91e6f08c8cf389b8388b626cd29eec32d937)
 - Support for Java

Additionally, the responsibility for binding an encoder to a given schema is currently handled in a fairly ad-hoc fashion. This is an internal detail, and what we are doing today works for the cases we care about. However, as we add more APIs we'll probably need to do this in a more principled way (i.e. separate resolution from binding as we do in DataFrames).

## COMPATIBILITY NOTE
Long term we plan to make `DataFrame` extend `Dataset[Row]`. However, making this change to the class hierarchy would break the function signatures for the existing functional operations (map, flatMap, etc). As such, this class should be considered a preview of the final API. Changes will be made to the interface after Spark 1.6.

Author: Michael Armbrust <michael@databricks.com>

Closes #9190 from marmbrus/dataset-infra.
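As a companion to the description above, the DataFrame/Dataset round trip can be sketched as follows. This is a minimal, illustrative example (not part of the patch); it assumes a `SQLContext` named `sqlContext` with its implicits imported and a hypothetical `Person` case class.
```scala
import sqlContext.implicits._

case class Person(name: String, age: Int)

val df: DataFrame = Seq(Person("Alice", 30), Person("Bob", 25)).toDF()
val ds: Dataset[Person] = df.as[Person]   // bind an Encoder[Person] to get typed objects
val rows: DataFrame = ds.toDF()           // drop back down to generic Row objects
```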
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                        15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                     11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                      392
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala                 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                68
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                    12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala                  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala    141
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala      8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala      79
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala         103
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                  124
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                       8
13 files changed, 1006 insertions, 1 deletion
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 37d559c8e4..de11a1699a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql
+
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -36,6 +38,11 @@ private[sql] object Column {
def unapply(col: Column): Option[Expression] = Some(col.expr)
}
+/**
+ * A [[Column]] where an [[Encoder]] has been given for the expected return type.
+ * @since 1.6.0
+ */
+class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr)
/**
* :: Experimental ::
@@ -70,6 +77,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
override def hashCode: Int = this.expr.hashCode
/**
+ * Provides a type hint about the expected return value of this column. This information can
+ * be used by operations such as `select` on a [[Dataset]] to automatically convert the
+ * results into the correct JVM types.
+ * @since 1.6.0
+ */
+ def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr)
+
+ /**
* Extracts a value or values from a complex type.
* The following types of extraction are supported:
* - Given an Array, an integer ordinal can be used to retrieve a single value.
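To illustrate how the new `TypedColumn` and `Column.as[T]` fit together, here is a small sketch (not part of the patch) of a typed `select` on a `Dataset`, assuming the implicits from `SQLImplicits` are in scope:
```scala
import org.apache.spark.sql.functions.expr
import sqlContext.implicits._

val ds: Dataset[Int] = Seq(1, 2, 3).toDS()
// expr(...) builds an untyped Column; .as[Int] attaches an Encoder so the result
// can be decoded back into Scala Ints instead of generic Rows.
val incremented: Array[Int] = ds.select(expr("value + 1").as[Int]).collect()
// incremented == Array(2, 3, 4)
```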
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 2f10aa9f3c..bf25bcde20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
@@ -259,6 +260,16 @@ class DataFrame private[sql](
def toDF(): DataFrame = this
/**
+ * :: Experimental ::
+ * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the
+ * specified type, `U`.
+ * @group basic
+ * @since 1.6.0
+ */
+ @Experimental
+ def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution)
+
+ /**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
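As a usage note for the new `DataFrame.as[U]`, the encoder binds columns by name, so the DataFrame's column order does not have to match the case class constructor. A sketch mirroring the "reordered fields by name" test added later in this patch:
```scala
import sqlContext.implicits._

case class ClassData(a: String, b: Int)

// The tuples put the Int first, but naming the columns "b" and "a" lets the
// encoder line the fields up by name when converting to ClassData.
val ds: Dataset[ClassData] = Seq((1, "a"), (2, "b")).toDF("b", "a").as[ClassData]
// ds.collect() == Array(ClassData("a", 1), ClassData("b", 2))
```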
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
new file mode 100644
index 0000000000..96213c7630
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel
+ * using functional or relational operations.
+ *
+ * A [[Dataset]] differs from an [[RDD]] in the following ways:
+ * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored
+ * in the encoded form. This representation allows for additional logical operations and
+ * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to
+ * an object.
+ * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be
+ * used to serialize the object into a binary format. Encoders are also capable of mapping the
+ * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime
+ * reflection based serialization. Operations that change the type of object stored in the
+ * dataset also need an encoder for the new type.
+ *
+ * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific
+ * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into
+ * a specific Dataset by calling `df.as[ElementType]`. Similarly, you can transform a strongly-typed
+ * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`.
+ *
+ * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However,
+ * making this change to the class hierarchy would break the function signatures for the existing
+ * functional operations (map, flatMap, etc). As such, this class should be considered a preview
+ * of the final API. Changes will be made to the interface after Spark 1.6.
+ *
+ * @since 1.6.0
+ */
+@Experimental
+class Dataset[T] private[sql](
+ @transient val sqlContext: SQLContext,
+ @transient val queryExecution: QueryExecution)(
+ implicit val encoder: Encoder[T]) extends Serializable {
+
+ private implicit def classTag = encoder.clsTag
+
+ private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
+ this(sqlContext, new QueryExecution(sqlContext, plan))
+
+ /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
+ def schema: StructType = encoder.schema
+
+ /* ************* *
+ * Conversions *
+ * ************* */
+
+ /**
+ * Returns a new `Dataset` where each record has been mapped to the specified type.
+ * TODO: should bind here...
+ * TODO: document binding rules
+ * @since 1.6.0
+ */
+ def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]])
+
+ /**
+ * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
+ * the same name after two Datasets have been joined.
+ */
+ def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _))
+
+ /**
+ * Converts this strongly typed collection of data to a generic DataFrame. In contrast to the
+ * strongly typed objects that Dataset operations work on, a DataFrame returns generic [[Row]]
+ * objects that allow fields to be accessed by ordinal or name.
+ */
+ def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
+
+
+ /**
+ * Returns this Dataset.
+ * @since 1.6.0
+ */
+ def toDS(): Dataset[T] = this
+
+ /**
+ * Converts this Dataset to an RDD.
+ * @since 1.6.0
+ */
+ def rdd: RDD[T] = {
+ val tEnc = implicitly[Encoder[T]]
+ val input = queryExecution.analyzed.output
+ queryExecution.toRdd.mapPartitions { iter =>
+ val bound = tEnc.bind(input)
+ iter.map(bound.fromRow)
+ }
+ }
+
+ /* *********************** *
+ * Functional Operations *
+ * *********************** */
+
+ /**
+ * Concise syntax for chaining custom transformations.
+ * {{{
+ * def featurize(ds: Dataset[T]) = ...
+ *
+ * dataset
+ * .transform(featurize)
+ * .transform(...)
+ * }}}
+ *
+ * @since 1.6.0
+ */
+ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+
+ /**
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+
+ /**
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+
+ /**
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+ * @since 1.6.0
+ */
+ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+ new Dataset(
+ sqlContext,
+ MapPartitions[T, U](
+ func,
+ implicitly[Encoder[T]],
+ implicitly[Encoder[U]],
+ implicitly[Encoder[U]].schema.toAttributes,
+ logicalPlan))
+ }
+
+ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+ mapPartitions(_.flatMap(func))
+
+ /* ************** *
+ * Side effects *
+ * ************** */
+
+ /**
+ * Runs `func` on each element of this Dataset.
+ * @since 1.6.0
+ */
+ def foreach(func: T => Unit): Unit = rdd.foreach(func)
+
+ /**
+ * Runs `func` on each partition of this Dataset.
+ * @since 1.6.0
+ */
+ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
+
+ /* ************* *
+ * Aggregation *
+ * ************* */
+
+ /**
+ * Reduces the elements of this Dataset using the specified binary function. The given function
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: (T, T) => T): T = rdd.reduce(func)
+
+ /**
+ * Aggregates the elements of each partition, and then the results for all the partitions, using a
+ * given associative and commutative function and a neutral "zero value".
+ *
+ * This behaves somewhat differently than the fold operations implemented for non-distributed
+ * collections in functional languages like Scala. This fold operation may be applied to
+ * partitions individually, and then those results will be folded into the final result.
+ * If op is not commutative, then the result may differ from that of a fold applied to a
+ * non-distributed collection.
+ * @since 1.6.0
+ */
+ def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
+
+ /**
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * @since 1.6.0
+ */
+ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
+ val inputPlan = queryExecution.analyzed
+ val withGroupingKey = AppendColumn(func, inputPlan)
+ val executed = sqlContext.executePlan(withGroupingKey)
+
+ new GroupedDataset(
+ implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns),
+ implicitly[Encoder[T]].bind(inputPlan.output),
+ executed,
+ inputPlan.output,
+ withGroupingKey.newColumns)
+ }
+
+ /* ****************** *
+ * Typed Relational *
+ * ****************** */
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
+ *
+ * {{{
+ * val ds = Seq(1, 2, 3).toDS()
+ * val newDS = ds.select(expr("value + 1").as[Int])
+ * }}}
+ * @since 1.6.0
+ */
+ def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = {
+ new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
+ }
+
+ // Codegen
+ // scalastyle:off
+
+ /** sbt scalaShell; println(Seq(1).toDS().genSelect) */
+ private def genSelect: String = {
+ (2 to 5).map { n =>
+ val types = (1 to n).map(i => s"U$i").mkString(", ")
+ val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ")
+ val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ")
+ val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ")
+ s"""
+ |/**
+ | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ | * @since 1.6.0
+ | */
+ |def select[$types]($args): Dataset[($types)] = {
+ | implicit val te = new Tuple${n}Encoder($encoders)
+ | new Dataset[($types)](sqlContext,
+ | Project(
+ | $schema :: Nil,
+ | logicalPlan))
+ |}
+ |
+ """.stripMargin
+ }.mkString("\n")
+ }
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = {
+ implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder)
+ new Dataset[(U1, U2)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil,
+ logicalPlan))
+ }
+
+
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = {
+ implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder)
+ new Dataset[(U1, U2, U3)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil,
+ logicalPlan))
+ }
+
+
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = {
+ implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder)
+ new Dataset[(U1, U2, U3, U4)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil,
+ logicalPlan))
+ }
+
+
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = {
+ implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder)
+ new Dataset[(U1, U2, U3, U4, U5)](sqlContext,
+ Project(
+ Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil,
+ logicalPlan))
+ }
+
+ // scalastyle:on
+
+ /* **************** *
+ * Set operations *
+ * **************** */
+
+ /**
+ * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]].
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ * @since 1.6.0
+ */
+ def distinct: Dataset[T] = withPlan(Distinct)
+
+ /**
+ * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also
+ * present in `other`.
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ * @since 1.6.0
+ */
+ def intersect(other: Dataset[T]): Dataset[T] =
+ withPlan[T](other)(Intersect)
+
+ /**
+ * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
+ * combined.
+ *
+ * Note that this function is not a typical set union operation, in that it does not eliminate
+ * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
+ * @since 1.6.0
+ */
+ def union(other: Dataset[T]): Dataset[T] =
+ withPlan[T](other)(Union)
+
+ /**
+ * Returns a new [[Dataset]] where any elements present in `other` have been removed.
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ * @since 1.6.0
+ */
+ def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
+
+ /* ************************** *
+ * Gather to Driver Actions *
+ * ************************** */
+
+ /** Returns the first element in this [[Dataset]]. */
+ def first(): T = rdd.first()
+
+ /** Collects the elements to an Array. */
+ def collect(): Array[T] = rdd.collect()
+
+ /** Returns the first `num` elements of this [[Dataset]] as an Array. */
+ def take(num: Int): Array[T] = rdd.take(num)
+
+ /* ******************** *
+ * Internal Functions *
+ * ******************** */
+
+ private[sql] def logicalPlan = queryExecution.analyzed
+
+ private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)))
+
+ private[sql] def withPlan[R : Encoder](
+ other: Dataset[_])(
+ f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
+ new Dataset[R](
+ sqlContext,
+ sqlContext.executePlan(
+ f(logicalPlan, other.logicalPlan)))
+}
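Putting the operations defined above together, a short sketch of a typed pipeline (illustrative only; `addOne` is a hypothetical helper and the implicits are assumed to be in scope):
```scala
import sqlContext.implicits._

// A custom transformation to chain via `transform`.
def addOne(ds: Dataset[Int]): Dataset[Int] = ds.map(_ + 1)

val evens: Array[Int] = Seq(1, 2, 3, 4).toDS()
  .transform(addOne)    // 2, 3, 4, 5
  .filter(_ % 2 == 0)   // keep the even values
  .collect()            // Array(2, 4)
```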
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
new file mode 100644
index 0000000000..17817cbcc5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -0,0 +1,30 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+/**
+ * A container for a [[Dataset]], used for implicit conversions.
+ *
+ * @since 1.6.0
+ */
+private[sql] case class DatasetHolder[T](df: Dataset[T]) {
+
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDS(): Dataset[T] = df
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
new file mode 100644
index 0000000000..89a16dd8b0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.QueryExecution
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should not
+ * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing
+ * [[Dataset]].
+ */
+class GroupedDataset[K, T] private[sql](
+ private val kEncoder: Encoder[K],
+ private val tEncoder: Encoder[T],
+ queryExecution: QueryExecution,
+ private val dataAttributes: Seq[Attribute],
+ private val groupingAttributes: Seq[Attribute]) extends Serializable {
+
+ private implicit def kEnc = kEncoder
+ private implicit def tEnc = tEncoder
+ private def logicalPlan = queryExecution.analyzed
+ private def sqlContext = queryExecution.sqlContext
+
+ /**
+ * Returns a [[Dataset]] that contains each unique key.
+ */
+ def keys: Dataset[K] = {
+ new Dataset[K](
+ sqlContext,
+ Distinct(
+ Project(groupingAttributes, logicalPlan)))
+ }
+
+ /**
+ * Applies the given function to each group of data. For each unique group, the function will
+ * be passed the group key and an iterator that contains all of the elements in the group. The
+ * function can return an iterator containing elements of an arbitrary type which will be returned
+ * as a new [[Dataset]].
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+ * constraints of their cluster.
+ */
+ def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+ new Dataset[U](
+ sqlContext,
+ MapGroups(f, groupingAttributes, logicalPlan))
+ }
+}
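A sketch of the `groupBy` / `mapGroups` flow that `GroupedDataset` enables, following the pattern used in the test suites added later in this patch:
```scala
import sqlContext.implicits._

val ds = Seq(1, 2, 3, 4, 5).toDS()
val grouped: GroupedDataset[Int, Int] = ds.groupBy(_ % 2)

grouped.keys.collect()        // Array(0, 1), in some order
val counts = grouped.mapGroups { case (key, iter) =>
  Iterator((key, iter.size))  // emit one (key, count) pair per group
}
// counts.collect() contains (0, 2) and (1, 3)
```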
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index a107639947..5e7198f974 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -21,6 +21,7 @@ import java.beans.{BeanInfo, Introspector}
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
+
import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
@@ -33,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
@@ -487,6 +489,16 @@ class SQLContext private[sql](
DataFrame(this, logicalPlan)
}
+
+ def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
+ val enc = implicitly[Encoder[T]]
+ val attributes = enc.schema.toAttributes
+ val encoded = data.map(d => enc.toRow(d).copy())
+ val plan = new LocalRelation(attributes, encoded)
+
+ new Dataset[T](this, plan)
+ }
+
/**
* Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
* converted to Catalyst rows.
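For illustration, `createDataset` is the entry point that the `toDS()` implicit added in `SQLImplicits` ultimately calls; a minimal sketch of invoking it directly (assumes the encoders from `SQLImplicits` are imported):
```scala
import sqlContext.implicits._   // provides Encoder[String], Encoder[Int], etc.

val names: Dataset[String] = sqlContext.createDataset(Seq("alice", "bob"))
val lengths: Dataset[Int] = names.map(_.length)
// lengths.collect() == Array(5, 3)
```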
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bf03c61088..af8474df0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -30,9 +34,19 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
*/
-private[sql] abstract class SQLImplicits {
+abstract class SQLImplicits {
protected def _sqlContext: SQLContext
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+
+ implicit def newIntEncoder: Encoder[Int] = new IntEncoder()
+ implicit def newLongEncoder: Encoder[Long] = new LongEncoder()
+ implicit def newStringEncoder: Encoder[String] = new StringEncoder()
+
+ implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = {
+ DatasetHolder(_sqlContext.createDataset(s))
+ }
+
/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
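To make the role of these implicits concrete, a small sketch of which encoder each conversion picks up (illustrative; assumes `import sqlContext.implicits._`):
```scala
import sqlContext.implicits._

val ints: Dataset[Int] = Seq(1, 2, 3).toDS()              // via newIntEncoder
val strings: Dataset[String] = Seq("a", "b").toDS()       // via newStringEncoder
val pairs: Dataset[(String, Int)] = Seq(("a", 1)).toDS()  // via newProductEncoder (tuples are Products)
```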
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
new file mode 100644
index 0000000000..10742cf734
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression}
+
+object GroupedIterator {
+ def apply(
+ input: Iterator[InternalRow],
+ keyExpressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
+ if (input.hasNext) {
+ new GroupedIterator(input, keyExpressions, inputSchema)
+ } else {
+ Iterator.empty
+ }
+ }
+}
+
+/**
+ * Iterates over a presorted set of rows, chunking it up by the grouping expression. Each call to
+ * next will return a pair containing the current group and an iterator that will return all the
+ * elements of that group. Iterators for each group are lazily constructed by extracting rows
+ * from the input iterator. As such, full groups are never materialized by this class.
+ *
+ * Example input:
+ * {{{
+ * Input: [a, 1], [b, 2], [b, 3]
+ * Grouping: x#1
+ * InputSchema: x#1, y#2
+ * }}}
+ *
+ * Result:
+ * {{{
+ * First call to next(): ([a], Iterator([a, 1]))
+ * Second call to next(): ([b], Iterator([b, 2], [b, 3]))
+ * }}}
+ *
+ * Note, the class does not handle the case of an empty input for simplicity of implementation.
+ * Use the factory to construct a new instance.
+ *
+ * @param input An iterator of rows. This iterator must be ordered by the groupingExpressions or
+ * it is possible for the same group to appear more than once.
+ * @param groupingExpressions The set of expressions used to do grouping. The result of evaluating
+ * these expressions will be returned as the first part of each call
+ * to `next()`.
+ * @param inputSchema The schema of the rows in the `input` iterator.
+ */
+class GroupedIterator private(
+ input: Iterator[InternalRow],
+ groupingExpressions: Seq[Expression],
+ inputSchema: Seq[Attribute])
+ extends Iterator[(InternalRow, Iterator[InternalRow])] {
+
+ /** Compares two input rows and returns 0 if they are in the same group. */
+ val sortOrder = groupingExpressions.map(SortOrder(_, Ascending))
+ val keyOrdering = GenerateOrdering.generate(sortOrder, inputSchema)
+
+ /** Creates a row containing only the key for a given input row. */
+ val keyProjection = GenerateUnsafeProjection.generate(groupingExpressions, inputSchema)
+
+ /**
+ * Holds null or the row that will be returned on next call to `next()` in the inner iterator.
+ */
+ var currentRow = input.next()
+
+ /** Holds a copy of an input row that is in the current group. */
+ var currentGroup = currentRow.copy()
+ var currentIterator: Iterator[InternalRow] = null
+ assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+
+ // Return true if we already have the next iterator or fetching a new iterator is successful.
+ def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
+
+ def next(): (InternalRow, Iterator[InternalRow]) = {
+ assert(hasNext) // Ensure we have fetched the next iterator.
+ val ret = (keyProjection(currentGroup), currentIterator)
+ currentIterator = null
+ ret
+ }
+
+ def fetchNextGroupIterator(): Boolean = {
+ if (currentRow != null || input.hasNext) {
+ val inputIterator = new Iterator[InternalRow] {
+ // Return true if we have a row and it is in the current group, or if fetching a new row is
+ // successful.
+ def hasNext = {
+ (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
+ fetchNextRowInGroup()
+ }
+
+ def fetchNextRowInGroup(): Boolean = {
+ if (currentRow != null || input.hasNext) {
+ currentRow = input.next()
+ if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+ // The row is in the current group. Continue the inner iterator.
+ true
+ } else {
+ // We got a row, but it's not in the right group. End this inner iterator and prepare
+ // for the next group.
+ currentIterator = null
+ currentGroup = currentRow.copy()
+ false
+ }
+ } else {
+ // There is no more input so we are done.
+ false
+ }
+ }
+
+ def next(): InternalRow = {
+ assert(hasNext) // Ensure we have fetched the next row.
+ val res = currentRow
+ currentRow = null
+ res
+ }
+ }
+ currentIterator = inputIterator
+ true
+ } else {
+ false
+ }
+ }
+}
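As a rough analogy only (the real class operates on `InternalRow` with generated orderings and unsafe projections), the grouping contract over a presorted iterator can be sketched in plain Scala. Unlike `GroupedIterator`, this sketch materializes each group eagerly, so it illustrates the input/output shape rather than the lazy chunking:
```scala
// Groups a presorted iterator by a key function, yielding (key, group) pairs.
// Purely illustrative; not the implementation used above.
def groupSorted[K, T](input: Iterator[T])(key: T => K): Iterator[(K, Seq[T])] = {
  val buffered = input.buffered
  new Iterator[(K, Seq[T])] {
    def hasNext: Boolean = buffered.hasNext
    def next(): (K, Seq[T]) = {
      val k = key(buffered.head)
      val group = scala.collection.mutable.ArrayBuffer[T]()
      while (buffered.hasNext && key(buffered.head) == k) group += buffered.next()
      (k, group.toSeq)
    }
  }
}

// groupSorted(Iterator(("a", 1), ("b", 2), ("b", 3)))(_._1).toList
//   == List(("a", Seq(("a", 1))), ("b", Seq(("b", 2), ("b", 3))))
```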
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 79bd1a4180..637deff4e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -372,6 +372,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
+
+ case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
+ execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
+ case logical.AppendColumn(f, tEnc, uEnc, newCol, child) =>
+ execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
+ case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
+ execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
+
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index dc38fe59fe..2bb3dba5bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.MutablePair
@@ -311,3 +313,80 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
protected override def doExecute(): RDD[InternalRow] = child.execute()
}
+
+/**
+ * Applies the given function to each partition of input rows and encodes the results.
+ */
+case class MapPartitions[T, U](
+ func: Iterator[T] => Iterator[U],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ output: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val tBoundEncoder = tEncoder.bind(child.output)
+ func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
+ }
+ }
+}
+
+/**
+ * Applies the given function to each input row, appending the encoded result at the end of the row.
+ */
+case class AppendColumns[T, U](
+ func: T => U,
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ newColumns: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output ++ newColumns
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val tBoundEncoder = tEncoder.bind(child.output)
+ val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
+ iter.map { row =>
+ val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
+ combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
+ }
+ }
+ }
+}
+
+/**
+ * Groups the input rows together and calls the function with each group and an iterator containing
+ * all elements in the group. The result of this function is encoded and flattened before
+ * being output.
+ */
+case class MapGroups[K, T, U](
+ func: (K, Iterator[T]) => Iterator[U],
+ kEncoder: Encoder[K],
+ tEncoder: Encoder[T],
+ uEncoder: Encoder[U],
+ groupingAttributes: Seq[Attribute],
+ output: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode {
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+ val groupKeyEncoder = kEncoder.bind(groupingAttributes)
+
+ grouped.flatMap { case (key, rowIter) =>
+ val result = func(
+ groupKeyEncoder.fromRow(key),
+ rowIter.map(tEncoder.fromRow))
+ result.map(uEncoder.toRow)
+ }
+ }
+ }
+}
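To see how the typed API lowers into these physical operators, one can build a small pipeline and inspect its plan; a sketch (the exact plan output is indicative only):
```scala
import sqlContext.implicits._

val agged = Seq(1, 2, 3, 4).toDS()
  .groupBy(_ % 2)                                        // plans an AppendColumns for the grouping key
  .mapGroups { case (k, it) => Iterator((k, it.size)) }  // plans a MapGroups

// The physical plan should contain AppendColumns and MapGroups nodes.
agged.toDF().explain()
```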
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
new file mode 100644
index 0000000000..32443557fb
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+case class IntClass(value: Int)
+
+class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("toDS") {
+ val data = Seq(1, 2, 3, 4, 5, 6)
+ checkAnswer(
+ data.toDS(),
+ data: _*)
+ }
+
+ test("as case class / collect") {
+ val ds = Seq(1, 2, 3).toDS().as[IntClass]
+ checkAnswer(
+ ds,
+ IntClass(1), IntClass(2), IntClass(3))
+
+ assert(ds.collect().head == IntClass(1))
+ }
+
+ test("map") {
+ val ds = Seq(1, 2, 3).toDS()
+ checkAnswer(
+ ds.map(_ + 1),
+ 2, 3, 4)
+ }
+
+ test("filter") {
+ val ds = Seq(1, 2, 3, 4).toDS()
+ checkAnswer(
+ ds.filter(_ % 2 == 0),
+ 2, 4)
+ }
+
+ test("foreach") {
+ val ds = Seq(1, 2, 3).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreach(acc +=)
+ assert(acc.value == 6)
+ }
+
+ test("foreachPartition") {
+ val ds = Seq(1, 2, 3).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreachPartition(_.foreach(acc +=))
+ assert(acc.value == 6)
+ }
+
+ test("reduce") {
+ val ds = Seq(1, 2, 3).toDS()
+ assert(ds.reduce(_ + _) == 6)
+ }
+
+ test("fold") {
+ val ds = Seq(1, 2, 3).toDS()
+ assert(ds.fold(0)(_ + _) == 6)
+ }
+
+ test("groupBy function, keys") {
+ val ds = Seq(1, 2, 3, 4, 5).toDS()
+ val grouped = ds.groupBy(_ % 2)
+ checkAnswer(
+ grouped.keys,
+ 0, 1)
+ }
+
+ test("groupBy function, mapGroups") {
+ val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
+ val grouped = ds.groupBy(_ % 2)
+ val agged = grouped.mapGroups { case (g, iter) =>
+ val name = if (g == 0) "even" else "odd"
+ Iterator((name, iter.size))
+ }
+
+ checkAnswer(
+ agged,
+ ("even", 5), ("odd", 6))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
new file mode 100644
index 0000000000..08496249c6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+case class ClassData(a: String, b: Int)
+
+class DatasetSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("toDS") {
+ val data = Seq(("a", 1) , ("b", 2), ("c", 3))
+ checkAnswer(
+ data.toDS(),
+ data: _*)
+ }
+
+ test("as case class / collect") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
+ checkAnswer(
+ ds,
+ ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+ assert(ds.collect().head == ClassData("a", 1))
+ }
+
+ test("as case class - reordered fields by name") {
+ val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
+ assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))
+ }
+
+ test("map") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.map(v => (v._1, v._2 + 1)),
+ ("a", 2), ("b", 3), ("c", 4))
+ }
+
+ test("select") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(expr("_2 + 1").as[Int]),
+ 2, 3, 4)
+ }
+
+ test("select 3") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.select(
+ expr("_1").as[String],
+ expr("_2").as[Int],
+ expr("_2 + 1").as[Int]),
+ ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+ }
+
+ test("filter") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ checkAnswer(
+ ds.filter(_._1 == "b"),
+ ("b", 2))
+ }
+
+ test("foreach") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreach(v => acc += v._2)
+ assert(acc.value == 6)
+ }
+
+ test("foreachPartition") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ val acc = sparkContext.accumulator(0)
+ ds.foreachPartition(_.foreach(v => acc += v._2))
+ assert(acc.value == 6)
+ }
+
+ test("reduce") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
+ }
+
+ test("fold") {
+ val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+ assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
+ }
+
+ test("groupBy function, keys") {
+ val ds = Seq(("a", 1), ("b", 1)).toDS()
+ val grouped = ds.groupBy(v => (1, v._2))
+ checkAnswer(
+ grouped.keys,
+ (1, 1))
+ }
+
+ test("groupBy function, mapGroups") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy(v => (v._1, "word"))
+ val agged = grouped.mapGroups { case (g, iter) =>
+ Iterator((g._1, iter.map(_._2).sum))
+ }
+
+ checkAnswer(
+ agged,
+ ("a", 30), ("b", 3), ("c", 1))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index e3c5a42667..aba567512f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
import scala.collection.JavaConverters._
+import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
abstract class QueryTest extends PlanTest {
@@ -53,6 +55,12 @@ abstract class QueryTest extends PlanTest {
}
}
+ protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+ checkAnswer(
+ ds.toDF(),
+ sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+ }
+
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param df the [[DataFrame]] to be executed