path: root/sql/core/src/main/scala/org
author     Cheng Lian <lian@databricks.com>  2016-03-10 17:00:17 -0800
committer  Yin Huai <yhuai@databricks.com>   2016-03-10 17:00:17 -0800
commit     1d542785b9949e7f92025e6754973a779cc37c52 (patch)
tree       ceda7492e40c9d9a9231a5011c91e30bf0b1f390 /sql/core/src/main/scala/org
parent     27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff)
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?

This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and making `DataFrame` a type alias of `Dataset[Row]`.

Most Scala code changes are source compatible, but the Java API is broken because Java knows nothing about Scala type aliases (mostly, `DataFrame` has to be replaced with `Dataset<Row>`).

There are several noticeable API changes related to methods that return arrays:

1. `collect`/`take`

   - Old APIs in class `DataFrame`:

     ```scala
     def collect(): Array[Row]
     def take(n: Int): Array[Row]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def collect(): Array[T]
     def take(n: Int): Array[T]
     def collectRows(): Array[Row]
     def takeRows(n: Int): Array[Row]
     ```

   The two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays; for example, `DataFrame.collect(): Array[T]` actually returns `Object` instead of `Array<T>` on the Java side. Normally, Java users may fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid a performance regression in ML-related code (but maybe I'm wrong and they are not necessary here).

1. `randomSplit`

   - Old APIs in class `DataFrame`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
     def randomSplit(weights: Array[Double]): Array[DataFrame]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
     def randomSplit(weights: Array[Double]): Array[Dataset[T]]
     ```

   This has the same problem as above, but it hasn't been addressed for the Java API yet. We can probably add `randomSplitAsList` to fix this one.

1. `groupBy`

   Some original `DataFrame.groupBy` methods have conflicting signatures with the original `Dataset.groupBy` methods. To distinguish the two, the typed `Dataset.groupBy` methods are renamed to `groupByKey`.

Other noticeable changes:

1. Datasets always do eager analysis now.

   We used to support disabling DataFrame eager analysis to help report the partially analyzed, malformed logical plan on analysis failure. However, Dataset encoders require eager analysis during Dataset construction. To preserve the error reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed in by `QueryExecution.assertAnalyzed`.

## How was this patch tested?

Existing tests do the work.

## TODO

- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>

Closes #11443 from liancheng/ds-to-df.
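To make the new surface concrete, here is a minimal Scala sketch of the behaviour described above (the `sqlContext` value, its imported implicits, and the `Person` case class are assumptions for illustration, not part of this patch):

```scala
import org.apache.spark.sql.Row
import sqlContext.implicits._

case class Person(name: String, age: Int)

val ds = Seq(Person("a", 1), Person("b", 2)).toDS()

// Typed actions now return the element type instead of Row:
val people: Array[Person] = ds.collect()

// Row-based results remain available through the untyped DataFrame view
// (or, for Java callers, via the specialized collectRows/takeRows helpers):
val rows: Array[Row] = ds.toDF().collect()

// The typed grouping operator is now spelled groupByKey, leaving groupBy
// for the relational (Column-based) variant:
val countsByAge = ds.groupByKey(_.age).count()
```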
Diffstat (limited to 'sql/core/src/main/scala/org')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                           532
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala                       2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                             794
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala                           2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                        8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                            4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala                          1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala              7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala      2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala          2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala          2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala            4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala                1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/package.scala                               1
15 files changed, 463 insertions, 901 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 24f61992d4..17a91975f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
+import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -26,30 +27,38 @@ import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.function._
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable,
- QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
+ val qe = sqlContext.executePlan(logicalPlan)
+ qe.assertAnalyzed()
+ new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema))
+ }
+}
+
+private[sql] object Dataset {
+ def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = {
+ new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]])
}
}
@@ -112,28 +121,19 @@ private[sql] object DataFrame {
* @since 1.3.0
*/
@Experimental
-class DataFrame private[sql](
+class Dataset[T] private[sql](
@transient override val sqlContext: SQLContext,
- @DeveloperApi @transient override val queryExecution: QueryExecution)
+ @DeveloperApi @transient override val queryExecution: QueryExecution,
+ encoder: Encoder[T])
extends Queryable with Serializable {
+ queryExecution.assertAnalyzed()
+
// Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure
// you wrap it with `withNewExecutionId` if this actions doesn't call other action.
- /**
- * A constructor that automatically analyzes the logical plan.
- *
- * This reports error eagerly as the [[DataFrame]] is constructed, unless
- * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
- */
- def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
- this(sqlContext, {
- val qe = sqlContext.executePlan(logicalPlan)
- if (sqlContext.conf.dataFrameEagerAnalysis) {
- qe.assertAnalyzed() // This should force analysis and throw errors if there are any
- }
- qe
- })
+ def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
+ this(sqlContext, sqlContext.executePlan(logicalPlan), encoder)
}
@transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
@@ -147,6 +147,26 @@ class DataFrame private[sql](
queryExecution.analyzed
}
+ /**
+ * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
+ * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
+ * same object type (that will be possibly resolved to a different schema).
+ */
+ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder)
+ unresolvedTEncoder.validate(logicalPlan.output)
+
+ /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+ private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
+ unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+ /**
+ * The encoder where the expressions used to construct an object from an input row have been
+ * bound to the ordinals of this [[Dataset]]'s output schema.
+ */
+ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
+
+ private implicit def classTag = unresolvedTEncoder.clsTag
+
protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
throw new AnalysisException(
@@ -173,7 +193,11 @@ class DataFrame private[sql](
// For array values, replace Seq and Array with square brackets
// For cells that are beyond 20 characters, replace it with the first 17 and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map {
+ case r: Row => r
+ case tuple: Product => Row.fromTuple(tuple)
+ case o => Row(o)
+ }.map { row =>
row.toSeq.map { cell =>
val str = cell match {
case null => "null"
@@ -196,7 +220,7 @@ class DataFrame private[sql](
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DataFrame = this
+ def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema))
/**
* :: Experimental ::
@@ -206,7 +230,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
+ def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan)
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
@@ -360,7 +384,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.1
*/
- def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
+ def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF())
/**
* Returns a [[DataFrameStatFunctions]] for working statistic functions support.
@@ -372,7 +396,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)
+ def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
/**
* Cartesian join with another [[DataFrame]].
@@ -573,6 +597,62 @@ class DataFrame private[sql](
}
/**
+ * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
+ * true.
+ *
+ * This is similar to the relation `join` function with one important difference in the
+ * result schema. Since `joinWith` preserves objects present on either side of the join, the
+ * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types as well as working with relational data where either side of the join has column
+ * names in common.
+ *
+ * @param other Right side of the join.
+ * @param condition Join expression.
+ * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
+ * @since 1.6.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
+ val left = this.logicalPlan
+ val right = other.logicalPlan
+
+ val joined = sqlContext.executePlan(Join(left, right, joinType =
+ JoinType(joinType), Some(condition.expr)))
+ val leftOutput = joined.analyzed.output.take(left.output.length)
+ val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+
+ val leftData = this.unresolvedTEncoder match {
+ case e if e.flat => Alias(leftOutput.head, "_1")()
+ case _ => Alias(CreateStruct(leftOutput), "_1")()
+ }
+ val rightData = other.unresolvedTEncoder match {
+ case e if e.flat => Alias(rightOutput.head, "_2")()
+ case _ => Alias(CreateStruct(rightOutput), "_2")()
+ }
+
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
+ withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+ Project(
+ leftData :: rightData :: Nil,
+ joined.analyzed)
+ }
+ }
+
+ /**
+ * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair
+ * where `condition` evaluates to true.
+ *
+ * @param other Right side of the join.
+ * @param condition Join expression.
+ * @since 1.6.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ joinWith(other, condition, "inner")
+ }
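A hedged usage sketch of `joinWith` as documented above (the case classes, sample data, and `sqlContext` implicits are assumptions for illustration):

```scala
import org.apache.spark.sql.Dataset
import sqlContext.implicits._

case class Click(userId: Long, url: String)
case class User(id: Long, name: String)

val clicks = Seq(Click(1L, "/home"), Click(2L, "/docs")).toDS()
val users  = Seq(User(1L, "alice"), User(3L, "carol")).toDS()

// Each result element is a (Click, User) pair under columns `_1` and `_2`,
// so both sides keep their original object types instead of being flattened:
val joined: Dataset[(Click, User)] =
  clicks.joinWith(users, clicks("userId") === users("id"), "inner")
```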
+
+ /**
* Returns a new [[DataFrame]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
@@ -581,7 +661,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = {
+ def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
}
@@ -594,7 +674,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def sortWithinPartitions(sortExprs: Column*): DataFrame = {
+ def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
sortInternal(global = false, sortExprs)
}
@@ -610,7 +690,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def sort(sortCol: String, sortCols: String*): DataFrame = {
+ def sort(sortCol: String, sortCols: String*): Dataset[T] = {
sort((sortCol +: sortCols).map(apply) : _*)
}
@@ -623,7 +703,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def sort(sortExprs: Column*): DataFrame = {
+ def sort(sortExprs: Column*): Dataset[T] = {
sortInternal(global = true, sortExprs)
}
@@ -634,7 +714,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*)
+ def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*)
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
@@ -643,7 +723,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
+ def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*)
/**
* Selects column based on the column name and return it as a [[Column]].
@@ -672,7 +752,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def as(alias: String): DataFrame = withPlan {
+ def as(alias: String): Dataset[T] = withTypedPlan {
SubqueryAlias(alias, logicalPlan)
}
@@ -681,21 +761,21 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def as(alias: Symbol): DataFrame = as(alias.name)
+ def as(alias: Symbol): Dataset[T] = as(alias.name)
/**
* Returns a new [[DataFrame]] with an alias set. Same as `as`.
* @group dfops
* @since 1.6.0
*/
- def alias(alias: String): DataFrame = as(alias)
+ def alias(alias: String): Dataset[T] = as(alias)
/**
* (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`.
* @group dfops
* @since 1.6.0
*/
- def alias(alias: Symbol): DataFrame = as(alias)
+ def alias(alias: Symbol): Dataset[T] = as(alias)
/**
* Selects a set of column based expressions.
@@ -745,6 +825,80 @@ class DataFrame private[sql](
}
/**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
+ *
+ * {{{
+ * val ds = Seq(1, 2, 3).toDS()
+ * val newDS = ds.select(expr("value + 1").as[Int])
+ * }}}
+ * @since 1.6.0
+ */
+ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
+ new Dataset[U1](
+ sqlContext,
+ Project(
+ c1.withInputType(
+ boundTEncoder,
+ logicalPlan.output).named :: Nil,
+ logicalPlan),
+ implicitly[Encoder[U1]])
+ }
+
+ /**
+ * Internal helper function for building typed selects that return tuples. For simplicity and
+ * code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ */
+ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
+ val encoders = columns.map(_.encoder)
+ val namedColumns =
+ columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
+ val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
+
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+ }
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3](
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4](
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3],
+ c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4, U5](
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3],
+ c4: TypedColumn[T, U4],
+ c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
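For the multi-column overloads above, a small sketch of the tuple-returning typed select (the input data and `sqlContext` implicits are assumptions):

```scala
import org.apache.spark.sql.functions.expr
import sqlContext.implicits._

val ds = Seq(("a", 1), ("b", 2)).toDS()

// Selecting two TypedColumns yields a Dataset of pairs rather than a DataFrame:
val pairs = ds.select(expr("_1").as[String], expr("_2").as[Int])
// pairs: Dataset[(String, Int)]
```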
+
+ /**
* Filters rows using the given condition.
* {{{
* // The following are equivalent:
@@ -754,7 +908,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def filter(condition: Column): DataFrame = withPlan {
+ def filter(condition: Column): Dataset[T] = withTypedPlan {
Filter(condition.expr, logicalPlan)
}
@@ -766,7 +920,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def filter(conditionExpr: String): DataFrame = {
+ def filter(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
@@ -780,7 +934,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def where(condition: Column): DataFrame = filter(condition)
+ def where(condition: Column): Dataset[T] = filter(condition)
/**
* Filters rows using the given SQL expression.
@@ -790,7 +944,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.5.0
*/
- def where(conditionExpr: String): DataFrame = {
+ def where(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
@@ -813,7 +967,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def groupBy(cols: Column*): GroupedData = {
- GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+ GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType)
}
/**
@@ -836,7 +990,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def rollup(cols: Column*): GroupedData = {
- GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+ GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType)
}
/**
@@ -858,7 +1012,7 @@ class DataFrame private[sql](
* @since 1.4.0
*/
@scala.annotation.varargs
- def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
+ def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType)
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -883,10 +1037,73 @@ class DataFrame private[sql](
@scala.annotation.varargs
def groupBy(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+ GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+ }
+
+ /**
+ * (Scala-specific)
+ * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: (T, T) => T): T = rdd.reduce(func)
+
+ /**
+ * (Java-specific)
+ * Reduces the elements of this Dataset using the specified binary function. The given `func`
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
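A minimal sketch of the Scala `reduce` variant (input data assumed):

```scala
import sqlContext.implicits._

val nums = Seq(1, 2, 3, 4).toDS()

// Addition is commutative and associative, so the result is deterministic:
val total = nums.reduce(_ + _)  // 10
```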
+
+ /**
+ * (Scala-specific)
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
+ * @since 1.6.0
+ */
+ def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = {
+ val inputPlan = logicalPlan
+ val withGroupingKey = AppendColumns(func, inputPlan)
+ val executed = sqlContext.executePlan(withGroupingKey)
+
+ new GroupedDataset(
+ encoderFor[K],
+ encoderFor[T],
+ executed,
+ inputPlan.output,
+ withGroupingKey.newColumns)
+ }
+
+ /**
+ * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def groupByKey(cols: Column*): GroupedDataset[Row, T] = {
+ val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
+ val withKey = Project(withKeyColumns, logicalPlan)
+ val executed = sqlContext.executePlan(withKey)
+
+ val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+ val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+ new GroupedDataset(
+ RowEncoder(keyAttributes.toStructType),
+ encoderFor[T],
+ executed,
+ dataAttributes,
+ keyAttributes)
}
/**
+ * (Java-specific)
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
+ * @since 1.6.0
+ */
+ def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ groupByKey(func.call(_))(encoder)
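A hedged sketch of grouping by a typed key and then aggregating through [[GroupedDataset]] (the `Event` case class and sample data are assumptions):

```scala
import sqlContext.implicits._

case class Event(user: String, bytes: Long)

val events = Seq(Event("a", 10L), Event("a", 5L), Event("b", 7L)).toDS()

// groupByKey derives the key from each object; mapGroups then sees the
// grouped values as Event instances rather than Rows:
val totals = events
  .groupByKey(_.user)
  .mapGroups { (user, it) => (user, it.map(_.bytes).sum) }
// totals: Dataset[(String, Long)]
```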
+
+ /**
* Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
* so we can run aggregation on them.
* See [[GroupedData]] for all the available aggregate functions.
@@ -910,7 +1127,7 @@ class DataFrame private[sql](
@scala.annotation.varargs
def rollup(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+ GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType)
}
/**
@@ -937,7 +1154,7 @@ class DataFrame private[sql](
@scala.annotation.varargs
def cube(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
+ GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType)
}
/**
@@ -997,7 +1214,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def limit(n: Int): DataFrame = withPlan {
+ def limit(n: Int): Dataset[T] = withTypedPlan {
Limit(Literal(n), logicalPlan)
}
@@ -1007,19 +1224,21 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def unionAll(other: DataFrame): DataFrame = withPlan {
+ def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan {
// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
CombineUnions(Union(logicalPlan, other.logicalPlan))
}
+ def union(other: Dataset[T]): Dataset[T] = unionAll(other)
+
/**
* Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
* This is equivalent to `INTERSECT` in SQL.
* @group dfops
* @since 1.3.0
*/
- def intersect(other: DataFrame): DataFrame = withPlan {
+ def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan {
Intersect(logicalPlan, other.logicalPlan)
}
@@ -1029,10 +1248,12 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def except(other: DataFrame): DataFrame = withPlan {
+ def except(other: Dataset[T]): Dataset[T] = withTypedPlan {
Except(logicalPlan, other.logicalPlan)
}
+ def subtract(other: Dataset[T]): Dataset[T] = except(other)
+
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows.
*
@@ -1042,7 +1263,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
}
@@ -1054,7 +1275,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
sample(withReplacement, fraction, Utils.random.nextLong)
}
@@ -1066,7 +1287,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
@@ -1075,7 +1296,8 @@ class DataFrame private[sql](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
+ new Dataset[T](
+ sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
}.toArray
}
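For illustration, a small sketch of the now-typed split (input data and `sqlContext` implicits assumed):

```scala
import sqlContext.implicits._

val ds = (1 to 100).toDS()

// Both halves are Dataset[Int] after this change, not DataFrame:
val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)
```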
@@ -1086,7 +1308,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+ def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
randomSplit(weights, Utils.random.nextLong)
}
@@ -1097,7 +1319,7 @@ class DataFrame private[sql](
* @param seed Seed for sampling.
* @group dfops
*/
- private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+ private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = {
randomSplit(weights.toArray, seed)
}
@@ -1238,7 +1460,7 @@ class DataFrame private[sql](
}
select(columns : _*)
} else {
- this
+ toDF()
}
}
@@ -1264,7 +1486,7 @@ class DataFrame private[sql](
val remainingCols =
schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
if (remainingCols.size == this.schema.size) {
- this
+ toDF()
} else {
this.select(remainingCols: _*)
}
@@ -1297,7 +1519,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def dropDuplicates(): DataFrame = dropDuplicates(this.columns)
+ def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns)
/**
* (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only
@@ -1306,7 +1528,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan {
+ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
val groupCols = colNames.map(resolve)
val groupColExprIds = groupCols.map(_.exprId)
val aggCols = logicalPlan.output.map { attr =>
@@ -1326,7 +1548,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq)
+ def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
/**
* Computes statistics for numeric columns, including count, mean, stddev, min, and max.
@@ -1396,7 +1618,7 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df =>
+ def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df =>
df.collect(needCallback = false)
}
@@ -1405,14 +1627,14 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def head(): Row = head(1).head
+ def head(): T = head(1).head
/**
* Returns the first row. Alias for head().
* @group action
* @since 1.3.0
*/
- def first(): Row = head()
+ def first(): T = head()
/**
* Concise syntax for chaining custom transformations.
@@ -1425,27 +1647,113 @@ class DataFrame private[sql](
* }}}
* @since 1.6.0
*/
- def transform[U](t: DataFrame => DataFrame): DataFrame = t(this)
+ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ map(t => func.call(t))(encoder)
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+ * @since 1.6.0
+ */
+ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+ new Dataset[U](
+ sqlContext,
+ MapPartitions[T, U](func, logicalPlan),
+ implicitly[Encoder[U]])
+ }
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+ * @since 1.6.0
+ */
+ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
+ mapPartitions(func)(encoder)
+ }
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+ * and then flattening the results.
+ * @since 1.6.0
+ */
+ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+ mapPartitions(_.flatMap(func))
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+ * and then flattening the results.
+ * @since 1.6.0
+ */
+ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (T) => Iterator[U] = x => f.call(x).asScala
+ flatMap(func)(encoder)
+ }
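A brief sketch of the functional operators added above, working directly on typed elements (input data assumed):

```scala
import sqlContext.implicits._

val lines = Seq("spark sql", "dataset api").toDS()

// flatMap/filter/map operate on the element type and need no Row handling:
val words = lines.flatMap(_.split(" "))
val upper = words.filter(_.length > 3).map(_.toUpperCase)
// upper: Dataset[String]
```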
/**
* Applies a function `f` to all rows.
* @group rdd
* @since 1.3.0
*/
- def foreach(f: Row => Unit): Unit = withNewExecutionId {
+ def foreach(f: T => Unit): Unit = withNewExecutionId {
rdd.foreach(f)
}
/**
+ * (Java-specific)
+ * Runs `func` on each element of this [[Dataset]].
+ * @since 1.6.0
+ */
+ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
+
+ /**
* Applies a function f to each partition of this [[DataFrame]].
* @group rdd
* @since 1.3.0
*/
- def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
+ def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId {
rdd.foreachPartition(f)
}
/**
+ * (Java-specific)
+ * Runs `func` on each partition of this [[Dataset]].
+ * @since 1.6.0
+ */
+ def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
+ foreachPartition(it => func.call(it.asJava))
+
+ /**
* Returns the first `n` rows in the [[DataFrame]].
*
* Running take requires moving data into the application's driver process, and doing so with
@@ -1454,7 +1762,11 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def take(n: Int): Array[Row] = head(n)
+ def take(n: Int): Array[T] = head(n)
+
+ def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds =>
+ ds.collectRows(needCallback = false)
+ }
/**
* Returns the first `n` rows in the [[DataFrame]] as a list.
@@ -1465,7 +1777,7 @@ class DataFrame private[sql](
* @group action
* @since 1.6.0
*/
- def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
+ def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*)
/**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
@@ -1478,7 +1790,9 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collect(): Array[Row] = collect(needCallback = true)
+ def collect(): Array[T] = collect(needCallback = true)
+
+ def collectRows(): Array[Row] = collectRows(needCallback = true)
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
@@ -1489,19 +1803,32 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
withNewExecutionId {
- java.util.Arrays.asList(rdd.collect() : _*)
+ val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+ java.util.Arrays.asList(values : _*)
}
}
- private def collect(needCallback: Boolean): Array[Row] = {
+ private def collect(needCallback: Boolean): Array[T] = {
+ def execute(): Array[T] = withNewExecutionId {
+ queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+ }
+
+ if (needCallback) {
+ withCallback("collect", toDF())(_ => execute())
+ } else {
+ execute()
+ }
+ }
+
+ private def collectRows(needCallback: Boolean): Array[Row] = {
def execute(): Array[Row] = withNewExecutionId {
queryExecution.executedPlan.executeCollectPublic()
}
if (needCallback) {
- withCallback("collect", this)(_ => execute())
+ withCallback("collect", toDF())(_ => execute())
} else {
execute()
}
@@ -1521,7 +1848,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def repartition(numPartitions: Int): DataFrame = withPlan {
+ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
Repartition(numPartitions, shuffle = true, logicalPlan)
}
@@ -1535,7 +1862,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan {
+ def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
}
@@ -1549,7 +1876,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def repartition(partitionExprs: Column*): DataFrame = withPlan {
+ def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
}
@@ -1561,7 +1888,7 @@ class DataFrame private[sql](
* @group rdd
* @since 1.4.0
*/
- def coalesce(numPartitions: Int): DataFrame = withPlan {
+ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
Repartition(numPartitions, shuffle = false, logicalPlan)
}
@@ -1571,7 +1898,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def distinct(): DataFrame = dropDuplicates()
+ def distinct(): Dataset[T] = dropDuplicates()
/**
* Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
@@ -1632,12 +1959,11 @@ class DataFrame private[sql](
* @group rdd
* @since 1.3.0
*/
- lazy val rdd: RDD[Row] = {
+ lazy val rdd: RDD[T] = {
// use a local variable to make sure the map closure doesn't capture the whole DataFrame
val schema = this.schema
queryExecution.toRdd.mapPartitions { rows =>
- val converter = CatalystTypeConverters.createToScalaConverter(schema)
- rows.map(converter(_).asInstanceOf[Row])
+ rows.map(boundTEncoder.fromRow)
}
}
@@ -1646,14 +1972,14 @@ class DataFrame private[sql](
* @group rdd
* @since 1.3.0
*/
- def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD()
+ def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD()
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
* @group rdd
* @since 1.3.0
*/
- def javaRDD: JavaRDD[Row] = toJavaRDD
+ def javaRDD: JavaRDD[T] = toJavaRDD
/**
* Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this
@@ -1663,7 +1989,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def registerTempTable(tableName: String): Unit = {
- sqlContext.registerDataFrameAsTable(this, tableName)
+ sqlContext.registerDataFrameAsTable(toDF(), tableName)
}
/**
@@ -1674,7 +2000,7 @@ class DataFrame private[sql](
* @since 1.4.0
*/
@Experimental
- def write: DataFrameWriter = new DataFrameWriter(this)
+ def write: DataFrameWriter = new DataFrameWriter(toDF())
/**
* Returns the content of the [[DataFrame]] as a RDD of JSON strings.
@@ -1745,7 +2071,7 @@ class DataFrame private[sql](
* Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with
* an execution.
*/
- private[sql] def withNewExecutionId[T](body: => T): T = {
+ private[sql] def withNewExecutionId[U](body: => U): U = {
SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
}
@@ -1753,7 +2079,7 @@ class DataFrame private[sql](
* Wrap a DataFrame action to track the QueryExecution and time cost, then report to the
* user-registered callback functions.
*/
- private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
+ private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = {
try {
df.queryExecution.executedPlan.foreach { plan =>
plan.resetMetrics()
@@ -1770,7 +2096,24 @@ class DataFrame private[sql](
}
}
- private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
+ private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = {
+ try {
+ ds.queryExecution.executedPlan.foreach { plan =>
+ plan.resetMetrics()
+ }
+ val start = System.nanoTime()
+ val result = action(ds)
+ val end = System.nanoTime()
+ sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start)
+ result
+ } catch {
+ case e: Exception =>
+ sqlContext.listenerManager.onFailure(name, ds.queryExecution, e)
+ throw e
+ }
+ }
+
+ private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
@@ -1779,14 +2122,23 @@ class DataFrame private[sql](
SortOrder(expr, Ascending)
}
}
- withPlan {
+ withTypedPlan {
Sort(sortOrder, global = global, logicalPlan)
}
}
/** A convenient function to wrap a logical plan and produce a DataFrame. */
@inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
+ DataFrame(sqlContext, logicalPlan)
+ }
+
+ /** A convenient function to wrap a logical plan and produce a DataFrame. */
+ @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
+ new Dataset[T](sqlContext, logicalPlan, encoder)
}
+ private[sql] def withTypedPlan[R](
+ other: Dataset[_], encoder: Encoder[R])(
+ f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
+ new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 509b29956f..822702429d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -345,7 +345,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
InferSchema.infer(jsonRDD, sqlContext.conf.columnNameOfCorruptRecord, parsedOptions)
}
- new DataFrame(
+ DataFrame(
sqlContext,
LogicalRDD(
schema.toAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
deleted file mode 100644
index daddf6e0c5..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ /dev/null
@@ -1,794 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.function._
-import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{Queryable, QueryExecution}
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-/**
- * :: Experimental ::
- * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel
- * using functional or relational operations.
- *
- * A [[Dataset]] differs from an [[RDD]] in the following ways:
- * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored
- * in the encoded form. This representation allows for additional logical operations and
- * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to
- * an object.
- * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be
- * used to serialize the object into a binary format. Encoders are also capable of mapping the
- * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime
- * reflection based serialization. Operations that change the type of object stored in the
- * dataset also need an encoder for the new type.
- *
- * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific
- * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into
- * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed
- * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`.
- *
- * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However,
- * making this change to the class hierarchy would break the function signatures for the existing
- * functional operations (map, flatMap, etc). As such, this class should be considered a preview
- * of the final API. Changes will be made to the interface after Spark 1.6.
- *
- * @since 1.6.0
- */
-@Experimental
-class Dataset[T] private[sql](
- @transient override val sqlContext: SQLContext,
- @transient override val queryExecution: QueryExecution,
- tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
-
- /**
- * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
- * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
- * same object type (that will be possibly resolved to a different schema).
- */
- private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
- unresolvedTEncoder.validate(logicalPlan.output)
-
- /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
- private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
-
- /**
- * The encoder where the expressions used to construct an object from an input row have been
- * bound to the ordinals of this [[Dataset]]'s output schema.
- */
- private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
-
- private implicit def classTag = unresolvedTEncoder.clsTag
-
- private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
- this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
-
- /**
- * Returns the schema of the encoded form of the objects in this [[Dataset]].
- * @since 1.6.0
- */
- override def schema: StructType = resolvedTEncoder.schema
-
- /**
- * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
- * @since 1.6.0
- */
- override def printSchema(): Unit = toDF().printSchema()
-
- /**
- * Prints the plans (logical and physical) to the console for debugging purposes.
- * @since 1.6.0
- */
- override def explain(extended: Boolean): Unit = toDF().explain(extended)
-
- /**
- * Prints the physical plan to the console for debugging purposes.
- * @since 1.6.0
- */
- override def explain(): Unit = toDF().explain()
-
- /* ************* *
- * Conversions *
- * ************* */
-
- /**
- * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
- * method used to map columns depend on the type of `U`:
- * - When `U` is a class, fields for the class will be mapped to columns of the same name
- * (case sensitivity is determined by `spark.sql.caseSensitive`)
- * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will
- * be assigned to `_1`).
- * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the
- * [[DataFrame]] will be used.
- *
- * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
- * along with `alias` or `as` to rearrange or rename as required.
- * @since 1.6.0
- */
- def as[U : Encoder]: Dataset[U] = {
- new Dataset(sqlContext, queryExecution, encoderFor[U])
- }
-
- /**
- * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
- * the same name after two Datasets have been joined.
- * @since 1.6.0
- */
- def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))
-
- /**
- * Converts this strongly typed collection of data to generic Dataframe. In contrast to the
- * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]]
- * objects that allow fields to be accessed by ordinal or name.
- */
- // This is declared with parentheses to prevent the Scala compiler from treating
- // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
-
- /**
- * Returns this [[Dataset]].
- * @since 1.6.0
- */
- // This is declared with parentheses to prevent the Scala compiler from treating
- // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset.
- def toDS(): Dataset[T] = this
-
- /**
- * Converts this [[Dataset]] to an [[RDD]].
- * @since 1.6.0
- */
- def rdd: RDD[T] = {
- queryExecution.toRdd.mapPartitions { iter =>
- iter.map(boundTEncoder.fromRow)
- }
- }
-
- /**
- * Returns the number of elements in the [[Dataset]].
- * @since 1.6.0
- */
- def count(): Long = toDF().count()
-
- /**
- * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters
- * will be truncated, and all cells will be aligned right. For example:
- * {{{
- * year month AVG('Adj Close) MAX('Adj Close)
- * 1980 12 0.503218 0.595103
- * 1981 01 0.523289 0.570307
- * 1982 02 0.436504 0.475256
- * 1983 03 0.410516 0.442194
- * 1984 04 0.450090 0.483521
- * }}}
- * @param numRows Number of rows to show
- *
- * @since 1.6.0
- */
- def show(numRows: Int): Unit = show(numRows, truncate = true)
-
- /**
- * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
- * will be truncated, and all cells will be aligned right.
- *
- * @since 1.6.0
- */
- def show(): Unit = show(20)
-
- /**
- * Displays the top 20 rows of [[Dataset]] in a tabular form.
- *
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
- *
- * @since 1.6.0
- */
- def show(truncate: Boolean): Unit = show(20, truncate)
-
- /**
- * Displays the [[Dataset]] in a tabular form. For example:
- * {{{
- * year month AVG('Adj Close) MAX('Adj Close)
- * 1980 12 0.503218 0.595103
- * 1981 01 0.523289 0.570307
- * 1982 02 0.436504 0.475256
- * 1983 03 0.410516 0.442194
- * 1984 04 0.450090 0.483521
- * }}}
- * @param numRows Number of rows to show
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
- *
- * @since 1.6.0
- */
- // scalastyle:off println
- def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate))
- // scalastyle:on println
-
- /**
- * Compose the string representing rows for output
- * @param _numRows Number of rows to show
- * @param truncate Whether truncate long strings and align cells right
- */
- override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
- val numRows = _numRows.max(0)
- val takeResult = take(numRows + 1)
- val hasMoreData = takeResult.length > numRows
- val data = takeResult.take(numRows)
-
- // For array values, replace Seq and Array with square brackets
- // For cells that are beyond 20 characters, replace it with the first 17 and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map {
- case r: Row => r
- case tuple: Product => Row.fromTuple(tuple)
- case o => Row(o)
- } map { row =>
- row.toSeq.map { cell =>
- val str = cell match {
- case null => "null"
- case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
- case array: Array[_] => array.mkString("[", ", ", "]")
- case seq: Seq[_] => seq.mkString("[", ", ", "]")
- case _ => cell.toString
- }
- if (truncate && str.length > 20) str.substring(0, 17) + "..." else str
- }: Seq[String]
- })
-
- formatString ( rows, numRows, hasMoreData, truncate )
- }
-
- /**
- * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
- * @since 1.6.0
- */
- def repartition(numPartitions: Int): Dataset[T] = withPlan {
- Repartition(numPartitions, shuffle = true, _)
- }
-
- /**
- * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
- * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
- * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
- * the 100 new partitions will claim 10 of the current partitions.
- * @since 1.6.0
- */
- def coalesce(numPartitions: Int): Dataset[T] = withPlan {
- Repartition(numPartitions, shuffle = false, _)
- }
-
- /* *********************** *
- * Functional Operations *
- * *********************** */
-
- /**
- * Concise syntax for chaining custom transformations.
- * {{{
- * def featurize(ds: Dataset[T]) = ...
- *
- * dataset
- * .transform(featurize)
- * .transform(...)
- * }}}
- * @since 1.6.0
- */
- def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
- * @since 1.6.0
- */
- def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
- * @since 1.6.0
- */
- def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
- * @since 1.6.0
- */
- def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
- * @since 1.6.0
- */
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
- map(t => func.call(t))(encoder)
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
- * @since 1.6.0
- */
- def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
- new Dataset[U](
- sqlContext,
- MapPartitions[T, U](func, logicalPlan))
- }
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
- * @since 1.6.0
- */
- def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
- mapPartitions(func)(encoder)
- }
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
- * and then flattening the results.
- * @since 1.6.0
- */
- def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
- mapPartitions(_.flatMap(func))
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
- * and then flattening the results.
- * @since 1.6.0
- */
- def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- val func: (T) => Iterator[U] = x => f.call(x).asScala
- flatMap(func)(encoder)
- }
-
- /* ************** *
- * Side effects *
- * ************** */
-
- /**
- * (Scala-specific)
- * Runs `func` on each element of this [[Dataset]].
- * @since 1.6.0
- */
- def foreach(func: T => Unit): Unit = rdd.foreach(func)
-
- /**
- * (Java-specific)
- * Runs `func` on each element of this [[Dataset]].
- * @since 1.6.0
- */
- def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
-
- /**
- * (Scala-specific)
- * Runs `func` on each partition of this [[Dataset]].
- * @since 1.6.0
- */
- def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
-
- /**
- * (Java-specific)
- * Runs `func` on each partition of this [[Dataset]].
- * @since 1.6.0
- */
- def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
- foreachPartition(it => func.call(it.asJava))
-
- /* ************* *
- * Aggregation *
- * ************* */
-
- /**
- * (Scala-specific)
- * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
- * must be commutative and associative or the result may be non-deterministic.
- * @since 1.6.0
- */
- def reduce(func: (T, T) => T): T = rdd.reduce(func)
-
- /**
- * (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given `func`
- * must be commutative and associative or the result may be non-deterministic.
- * @since 1.6.0
- */
- def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
-
- /**
- * (Scala-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
- * @since 1.6.0
- */
- def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
- val inputPlan = logicalPlan
- val withGroupingKey = AppendColumns(func, inputPlan)
- val executed = sqlContext.executePlan(withGroupingKey)
-
- new GroupedDataset(
- encoderFor[K],
- encoderFor[T],
- executed,
- inputPlan.output,
- withGroupingKey.newColumns)
- }
-
- /**
- * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
- * @since 1.6.0
- */
- @scala.annotation.varargs
- def groupBy(cols: Column*): GroupedDataset[Row, T] = {
- val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
- val withKey = Project(withKeyColumns, logicalPlan)
- val executed = sqlContext.executePlan(withKey)
-
- val dataAttributes = executed.analyzed.output.dropRight(cols.size)
- val keyAttributes = executed.analyzed.output.takeRight(cols.size)
-
- new GroupedDataset(
- RowEncoder(keyAttributes.toStructType),
- encoderFor[T],
- executed,
- dataAttributes,
- keyAttributes)
- }
-
- /**
- * (Java-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
- * @since 1.6.0
- */
- def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
- groupBy(func.call(_))(encoder)
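A sketch of the typed aggregation API as it stood before this patch (the key-function `groupBy` overload above together with `GroupedDataset.mapGroups`); the `words` Dataset is purely illustrative:

```scala
import sqlContext.implicits._

val words = Seq("spark", "sql", "spark").toDS()

// Whole-Dataset reduction; `func` must be commutative and associative.
val longest = words.reduce((a, b) => if (a.length >= b.length) a else b)

// Typed grouping: group by a key function, then fold each group.
val counts = words
  .groupBy((w: String) => w)                     // GroupedDataset[String, String]
  .mapGroups((key, values) => (key, values.size))
```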
-
- /* ****************** *
- * Typed Relational *
- * ****************** */
-
- /**
- * Returns a new [[DataFrame]] by selecting a set of column-based expressions.
- * {{{
- * df.select($"colA", $"colB" + 1)
- * }}}
- * @since 1.6.0
- */
- // Copied from DataFrame to make sure we don't have invalid overloads.
- @scala.annotation.varargs
- protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
- *
- * {{{
- * val ds = Seq(1, 2, 3).toDS()
- * val newDS = ds.select(expr("value + 1").as[Int])
- * }}}
- * @since 1.6.0
- */
- def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- new Dataset[U1](
- sqlContext,
- Project(
- c1.withInputType(
- boundTEncoder,
- logicalPlan.output).named :: Nil,
- logicalPlan))
- }
-
- /**
- * Internal helper function for building typed selects that return tuples. For simplicity and
- * code reuse, we do this without the help of the type system and then use helper functions
- * that cast appropriately for the user-facing interface.
- */
- protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoders = columns.map(_.encoder)
- val namedColumns =
- columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
- val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
-
- new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
- }
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
- selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2, U3](
- c1: TypedColumn[T, U1],
- c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
- selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2, U3, U4](
- c1: TypedColumn[T, U1],
- c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3],
- c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
- selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2, U3, U4, U5](
- c1: TypedColumn[T, U1],
- c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3],
- c4: TypedColumn[T, U4],
- c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
- selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
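A sketch of the typed `select` variants above; the `Person` case class and data are illustrative only, and the setup assumes `sqlContext.implicits._` plus `org.apache.spark.sql.functions.expr`:

```scala
import org.apache.spark.sql.functions.expr
import sqlContext.implicits._

case class Person(name: String, age: Int)
val people = Seq(Person("Ann", 30), Person("Bob", 25)).toDS()

// Single typed column: Dataset[Int].
val ages = people.select(expr("age").as[Int])

// Several typed columns come back as a Dataset of tuples: Dataset[(String, Int)].
val nextYear = people.select(expr("name").as[String], expr("age + 1").as[Int])
```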
-
- /**
- * Returns a new [[Dataset]] by sampling a fraction of records.
- * @since 1.6.0
- */
- def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
- withPlan(Sample(0.0, fraction, withReplacement, seed, _)())
-
- /**
- * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.
- * @since 1.6.0
- */
- def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = {
- sample(withReplacement, fraction, Utils.random.nextLong)
- }
-
- /* **************** *
- * Set operations *
- * **************** */
-
- /**
- * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]].
- *
- * Note that equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- * @since 1.6.0
- */
- def distinct: Dataset[T] = withPlan(Distinct)
-
- /**
- * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also
- * present in `other`.
- *
- * Note that equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- * @since 1.6.0
- */
- def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
-
- /**
- * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
- * combined.
- *
- * Note that this function is not a typical set union operation, in that it does not eliminate
- * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
- * @since 1.6.0
- */
- def union(other: Dataset[T]): Dataset[T] = withPlan[T](other) { (left, right) =>
- // This breaks caching, but it's usually ok because it addresses a very specific use case:
- // using union to union many files or partitions.
- CombineUnions(Union(left, right))
- }
-
- /**
- * Returns a new [[Dataset]] where any elements present in `other` have been removed.
- *
- * Note that equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- * @since 1.6.0
- */
- def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
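The set operations above keep the documented semantics; a small sketch, with the inputs purely illustrative:

```scala
import sqlContext.implicits._

val a = Seq(1, 2, 2, 3).toDS()
val b = Seq(3, 4).toDS()

val deduped  = a.distinct        // unique elements of `a`
val all      = a.union(b)        // keeps duplicates, analogous to UNION ALL
val common   = a.intersect(b)    // elements of `a` also present in `b`
val leftOnly = a.subtract(b)     // elements of `a` with those present in `b` removed
```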
-
- /* ****** *
- * Joins *
- * ****** */
-
- /**
- * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
- * true.
- *
- * This is similar to the relation `join` function with one important difference in the
- * result schema. Since `joinWith` preserves objects present on either side of the join, the
- * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
- *
- * This type of join can be useful both for preserving type-safety with the original object
- * types and for working with relational data where the two sides of the join have column
- * names in common.
- *
- * @param other Right side of the join.
- * @param condition Join expression.
- * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
- * @since 1.6.0
- */
- def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
- val left = this.logicalPlan
- val right = other.logicalPlan
-
- val joined = sqlContext.executePlan(Join(left, right, joinType =
- JoinType(joinType), Some(condition.expr)))
- val leftOutput = joined.analyzed.output.take(left.output.length)
- val rightOutput = joined.analyzed.output.takeRight(right.output.length)
-
- val leftData = this.unresolvedTEncoder match {
- case e if e.flat => Alias(leftOutput.head, "_1")()
- case _ => Alias(CreateStruct(leftOutput), "_1")()
- }
- val rightData = other.unresolvedTEncoder match {
- case e if e.flat => Alias(rightOutput.head, "_2")()
- case _ => Alias(CreateStruct(rightOutput), "_2")()
- }
-
- implicit val tuple2Encoder: Encoder[(T, U)] =
- ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withPlan[(T, U)](other) { (left, right) =>
- Project(
- leftData :: rightData :: Nil,
- joined.analyzed)
- }
- }
-
- /**
- * Joins this [[Dataset]] with `other` using an inner equi-join, returning a [[Tuple2]] for
- * each pair where `condition` evaluates to true.
- *
- * @param other Right side of the join.
- * @param condition Join expression.
- * @since 1.6.0
- */
- def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
- joinWith(other, condition, "inner")
- }
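A sketch of `joinWith`, which keeps both sides as whole objects nested under `_1` and `_2`; the case classes and column names are illustrative:

```scala
import sqlContext.implicits._

case class User(id: Long, name: String)
case class Order(userId: Long, amount: Double)

val users  = Seq(User(1L, "Ann"), User(2L, "Bob")).toDS()
val orders = Seq(Order(1L, 9.99), Order(1L, 5.0)).toDS()

// Each matching pair is returned as a tuple of the original objects,
// so the result is a Dataset[(User, Order)].
val joined = users.joinWith(orders, $"id" === $"userId", "inner")
```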
-
- /* ************************** *
- * Gather to Driver Actions *
- * ************************** */
-
- /**
- * Returns the first element in this [[Dataset]].
- * @since 1.6.0
- */
- def first(): T = take(1).head
-
- /**
- * Returns an array that contains all the elements in this [[Dataset]].
- *
- * Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
- *
- * For Java API, use [[collectAsList]].
- * @since 1.6.0
- */
- def collect(): Array[T] = {
- // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
- // to convert the rows into objects of type T.
- queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
- }
-
- /**
- * Returns an array that contains all the elements in this [[Dataset]].
- *
- * Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
- *
- * For Java API, use [[collectAsList]].
- * @since 1.6.0
- */
- def collectAsList(): java.util.List[T] = collect().toSeq.asJava
-
- /**
- * Returns the first `num` elements of this [[Dataset]] as an array.
- *
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `num` can crash the driver process with OutOfMemoryError.
- * @since 1.6.0
- */
- def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
-
- /**
- * Returns the first `num` elements of this [[Dataset]] as an array.
- *
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `num` can crash the driver process with OutOfMemoryError.
- * @since 1.6.0
- */
- def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
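The driver-side actions above behave as documented; a short sketch, with the usual caveat that `collect` pulls the entire Dataset into driver memory:

```scala
import sqlContext.implicits._

val ds = Seq(10, 20, 30).toDS()

ds.first()          // 10
ds.take(2)          // Array(10, 20)
ds.collect()        // Array(10, 20, 30), all rows moved to the driver
ds.collectAsList()  // java.util.List with the same elements, for Java callers
ds.takeAsList(2)    // java.util.List of the first two elements
```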
-
- /**
- * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
- * @since 1.6.0
- */
- def persist(): this.type = {
- sqlContext.cacheManager.cacheQuery(this)
- this
- }
-
- /**
- * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
- * @since 1.6.0
- */
- def cache(): this.type = persist()
-
- /**
- * Persist this [[Dataset]] with the given storage level.
- * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
- * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
- * `MEMORY_AND_DISK_2`, etc.
- * @group basic
- * @since 1.6.0
- */
- def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheManager.cacheQuery(this, None, newLevel)
- this
- }
-
- /**
- * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
- * @param blocking Whether to block until all blocks are deleted.
- * @since 1.6.0
- */
- def unpersist(blocking: Boolean): this.type = {
- sqlContext.cacheManager.tryUncacheQuery(this, blocking)
- this
- }
-
- /**
- * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
- * @since 1.6.0
- */
- def unpersist(): this.type = unpersist(blocking = false)
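Caching follows the usual Spark pattern; a sketch using the storage levels named above, assuming the `ds` from the earlier examples:

```scala
import org.apache.spark.storage.StorageLevel

ds.persist(StorageLevel.MEMORY_AND_DISK)  // or ds.cache() for the default level
ds.collect()                              // first action materializes the cached data
ds.collect()                              // later actions reuse the cached blocks
ds.unpersist()                            // non-blocking; use unpersist(true) to block
```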
-
- /* ******************** *
- * Internal Functions *
- * ******************** */
-
- private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
-
- private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
-
- private[sql] def withPlan[R : Encoder](
- other: Dataset[_])(
- f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index a7258d742a..2a0f77349a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.NumericType
/**
* :: Experimental ::
- * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
+ * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]].
*
* The main method is the agg function, which has multiple variants. This class also contains
* some first-order statistics, such as mean and sum, for convenience.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index cd8ed472ec..1639cc8db6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -64,7 +64,7 @@ class GroupedDataset[K, V] private[sql](
private def groupedData =
new GroupedData(
- new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+ DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
/**
* Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
@@ -86,7 +86,7 @@ class GroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def keys: Dataset[K] = {
- new Dataset[K](
+ Dataset[K](
sqlContext,
Distinct(
Project(groupingAttributes, logicalPlan)))
@@ -111,7 +111,7 @@ class GroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
- new Dataset[U](
+ Dataset[U](
sqlContext,
MapGroups(
f,
@@ -308,7 +308,7 @@ class GroupedDataset[K, V] private[sql](
other: GroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit val uEncoder = other.unresolvedVEncoder
- new Dataset[R](
+ Dataset[R](
sqlContext,
CoGroup(
f,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c742bf2f89..54dbd6bda5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -464,7 +464,7 @@ class SQLContext private[sql](
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
- new Dataset[T](this, plan)
+ Dataset[T](this, plan)
}
def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
@@ -473,7 +473,7 @@ class SQLContext private[sql](
val encoded = data.map(d => enc.toRow(d))
val plan = LogicalRDD(attributes, encoded)(self)
- new Dataset[T](this, plan)
+ Dataset[T](this, plan)
}
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
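For context, a sketch of how these `createDataset` entry points are called; only the construction of the result changes in this patch (the `Dataset` companion factory instead of `new`):

```scala
import org.apache.spark.sql.Encoders
import sqlContext.implicits._

// Scala: the encoder comes from the implicit scope.
val ints = sqlContext.createDataset(Seq(1, 2, 3))

// Java-style: the encoder is supplied explicitly.
val strs = sqlContext.createDataset(java.util.Arrays.asList("a", "b"))(Encoders.STRING)
```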
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 16c4095db7..e23d5e1261 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -126,6 +126,7 @@ abstract class SQLImplicits {
/**
* Creates a [[Dataset]] from an RDD.
+ *
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 8616fe3170..19ab3ea132 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
@@ -31,7 +31,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
*/
class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
- def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed)
+ def assertAnalyzed(): Unit = try sqlContext.analyzer.checkAnalysis(analyzed) catch {
+ case e: AnalysisException =>
+ throw new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
+ }
lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
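The re-thrown exception now carries the partially analyzed plan, so failure reporting can show the plan tree. A hedged sketch of how calling code might surface it, assuming the new constructor argument is exposed as `AnalysisException.plan` (the helper below is illustrative, not part of this patch):

```scala
import org.apache.spark.sql.AnalysisException

// Illustrative helper: prints the analysis error together with the partially
// analyzed plan attached by assertAnalyzed(), when one is available.
def describeAnalysisFailure(e: AnalysisException): String = e.plan match {
  case Some(plan) => s"${e.getMessage}\nPartially analyzed plan:\n$plan"
  case None       => e.getMessage
}
```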
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index e048ee1441..60ec67c8f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -154,7 +154,7 @@ case class DataSource(
}
def dataFrameBuilder(files: Array[String]): DataFrame = {
- new DataFrame(
+ DataFrame(
sqlContext,
LogicalRelation(
DataSource(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index a191759813..0dc34814fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging {
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
- new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
+ DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 26e4eda542..daa065e5cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging {
}
val schema = StructType(StructField(tableName, StringType) +: headerNames)
- new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+ DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index bc7c520930..7d7c51b158 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -211,7 +211,7 @@ class StreamExecution(
// Construct the batch and send it to the sink.
val batchOffset = streamProgress.toCompositeOffset(sources)
- val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan))
+ val nextBatch = new Batch(batchOffset, DataFrame(sqlContext, newPlan))
sink.addBatch(nextBatch)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 8124df15af..3b764c5558 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -55,11 +55,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def schema: StructType = encoder.schema
def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
- new Dataset(sqlContext, logicalPlan)
+ Dataset(sqlContext, logicalPlan)
}
def toDF()(implicit sqlContext: SQLContext): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
+ DataFrame(sqlContext, logicalPlan)
}
def addData(data: A*): Offset = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 6eea924517..844f3051fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -46,7 +46,6 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
* @tparam I The input type for the aggregation.
* @tparam B The type of the intermediate value of the reduction.
* @tparam O The type of the final output result.
- *
* @since 1.6.0
*/
abstract class Aggregator[-I, B, O] extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index bd73a36fd4..97e35bb104 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -42,4 +42,5 @@ package object sql {
@DeveloperApi
type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
+ type DataFrame = Dataset[Row]
}
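With the alias in place, `DataFrame` and `Dataset[Row]` are interchangeable at the source level for Scala callers; a minimal illustration:

```scala
import org.apache.spark.sql.{DataFrame, Dataset, Row}

// The two parameter types below are identical to the compiler; no conversion
// is needed in either direction.
def schemaOf(df: DataFrame): String = df.schema.treeString
def schemaOfDs(ds: Dataset[Row]): String = schemaOf(ds)
```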