Diffstat (limited to 'sql/core')
26 files changed, 657 insertions, 1106 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 24f61992d4..17a91975f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -26,30 +27,38 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.function._ import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression -import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, - QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan) + val qe = sqlContext.executePlan(logicalPlan) + qe.assertAnalyzed() + new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema)) + } +} + +private[sql] object Dataset { + def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = { + new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]]) } } @@ -112,28 +121,19 @@ private[sql] object DataFrame { * @since 1.3.0 */ @Experimental -class DataFrame private[sql]( +class Dataset[T] private[sql]( @transient override val sqlContext: SQLContext, - @DeveloperApi @transient override val queryExecution: QueryExecution) + @DeveloperApi @transient override val queryExecution: QueryExecution, + encoder: Encoder[T]) extends Queryable with Serializable { + queryExecution.assertAnalyzed() + // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. - /** - * A constructor that automatically analyzes the logical plan. - * - * This reports error eagerly as the [[DataFrame]] is constructed, unless - * [[SQLConf.dataFrameEagerAnalysis]] is turned off. 
- */ - def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = { - this(sqlContext, { - val qe = sqlContext.executePlan(logicalPlan) - if (sqlContext.conf.dataFrameEagerAnalysis) { - qe.assertAnalyzed() // This should force analysis and throw errors if there are any - } - qe - }) + def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { + this(sqlContext, sqlContext.executePlan(logicalPlan), encoder) } @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match { @@ -147,6 +147,26 @@ class DataFrame private[sql]( queryExecution.analyzed } + /** + * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is + * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the + * same object type (that will be possibly resolved to a different schema). + */ + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder) + unresolvedTEncoder.validate(logicalPlan.output) + + /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ + private[sql] val resolvedTEncoder: ExpressionEncoder[T] = + unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + + /** + * The encoder where the expressions used to construct an object from an input row have been + * bound to the ordinals of this [[Dataset]]'s output schema. + */ + private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) + + private implicit def classTag = unresolvedTEncoder.clsTag + protected[sql] def resolve(colName: String): NamedExpression = { queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( @@ -173,7 +193,11 @@ class DataFrame private[sql]( // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { + case r: Row => r + case tuple: Product => Row.fromTuple(tuple) + case o => Row(o) + }.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -196,7 +220,7 @@ class DataFrame private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = this + def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema)) /** * :: Experimental :: @@ -206,7 +230,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @Experimental - def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) + def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan) /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion @@ -360,7 +384,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.1 */ - def na: DataFrameNaFunctions = new DataFrameNaFunctions(this) + def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF()) /** * Returns a [[DataFrameStatFunctions]] for working statistic functions support. @@ -372,7 +396,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this) + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) /** * Cartesian join with another [[DataFrame]]. 
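The hunks above merge DataFrame into the typed Dataset class: `toDF()` now re-wraps the plan with a `RowEncoder`, and `as[U]` switches to a typed view of the same plan. The following is not part of the patch, just a minimal usage sketch; the `Person` case class, the input path, and the `sqlContext` value are assumptions, and `import sqlContext.implicits._` is needed for the encoder:

{{{
case class Person(name: String, age: Long)

// Untyped view: after this change, DataFrame is simply Dataset[Row].
val df: DataFrame = sqlContext.read.json("examples/people.json")

// Typed view over the same plan; needs an Encoder[Person] from the implicits.
val people: Dataset[Person] = df.as[Person]

// Back to the generic Row-based view.
val back: DataFrame = people.toDF()
}}}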
@@ -573,6 +597,62 @@ class DataFrame private[sql]( } /** + * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to + * true. + * + * This is similar to the relation `join` function with one important difference in the + * result schema. Since `joinWith` preserves objects present on either side of the join, the + * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * + * This type of join can be useful both for preserving type-safety with the original object + * types as well as working with relational data where either side of the join has column + * names in common. + * + * @param other Right side of the join. + * @param condition Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { + val left = this.logicalPlan + val right = other.logicalPlan + + val joined = sqlContext.executePlan(Join(left, right, joinType = + JoinType(joinType), Some(condition.expr))) + val leftOutput = joined.analyzed.output.take(left.output.length) + val rightOutput = joined.analyzed.output.takeRight(right.output.length) + + val leftData = this.unresolvedTEncoder match { + case e if e.flat => Alias(leftOutput.head, "_1")() + case _ => Alias(CreateStruct(leftOutput), "_1")() + } + val rightData = other.unresolvedTEncoder match { + case e if e.flat => Alias(rightOutput.head, "_2")() + case _ => Alias(CreateStruct(rightOutput), "_2")() + } + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) + withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) => + Project( + leftData :: rightData :: Nil, + joined.analyzed) + } + } + + /** + * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * where `condition` evaluates to true. + * + * @param other Right side of the join. + * @param condition Join expression. + * @since 1.6.0 + */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + joinWith(other, condition, "inner") + } + + /** * Returns a new [[DataFrame]] with each partition sorted by the given expressions. * * This is the same operation as "SORT BY" in SQL (Hive QL). 
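The `joinWith` operator added above keeps the complete object from each side of the join instead of flattening its columns. A hedged usage sketch, not taken from the patch; the `Click`/`User` case classes and column names are illustrative and `import sqlContext.implicits._` is assumed:

{{{
case class Click(userId: Long, url: String)
case class User(id: Long, name: String)

val clicks: Dataset[Click] = Seq(Click(1L, "/home")).toDS()
val users: Dataset[User]   = Seq(User(1L, "Alice")).toDS()

// Each result element is a (Click, User) pair, nested under columns _1 and _2.
val joined: Dataset[(Click, User)] =
  clicks.joinWith(users, clicks("userId") === users("id"), "inner")
}}}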
@@ -581,7 +661,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = { + def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*) } @@ -594,7 +674,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def sortWithinPartitions(sortExprs: Column*): DataFrame = { + def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { sortInternal(global = false, sortExprs) } @@ -610,7 +690,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): DataFrame = { + def sort(sortCol: String, sortCols: String*): Dataset[T] = { sort((sortCol +: sortCols).map(apply) : _*) } @@ -623,7 +703,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def sort(sortExprs: Column*): DataFrame = { + def sort(sortExprs: Column*): Dataset[T] = { sortInternal(global = true, sortExprs) } @@ -634,7 +714,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) + def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -643,7 +723,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) + def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -672,7 +752,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def as(alias: String): DataFrame = withPlan { + def as(alias: String): Dataset[T] = withTypedPlan { SubqueryAlias(alias, logicalPlan) } @@ -681,21 +761,21 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def as(alias: Symbol): DataFrame = as(alias.name) + def as(alias: Symbol): Dataset[T] = as(alias.name) /** * Returns a new [[DataFrame]] with an alias set. Same as `as`. * @group dfops * @since 1.6.0 */ - def alias(alias: String): DataFrame = as(alias) + def alias(alias: String): Dataset[T] = as(alias) /** * (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`. * @group dfops * @since 1.6.0 */ - def alias(alias: Symbol): DataFrame = as(alias) + def alias(alias: Symbol): Dataset[T] = as(alias) /** * Selects a set of column based expressions. @@ -745,6 +825,80 @@ class DataFrame private[sql]( } /** + * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * + * {{{ + * val ds = Seq(1, 2, 3).toDS() + * val newDS = ds.select(expr("value + 1").as[Int]) + * }}} + * @since 1.6.0 + */ + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { + new Dataset[U1]( + sqlContext, + Project( + c1.withInputType( + boundTEncoder, + logicalPlan.output).named :: Nil, + logicalPlan), + implicitly[Encoder[U1]]) + } + + /** + * Internal helper function for building typed selects that return tuples. For simplicity and + * code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. 
+ */ + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) + val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) + + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + } + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3, U4]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3, U4, U5]( + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] + + /** * Filters rows using the given condition. * {{{ * // The following are equivalent: @@ -754,7 +908,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def filter(condition: Column): DataFrame = withPlan { + def filter(condition: Column): Dataset[T] = withTypedPlan { Filter(condition.expr, logicalPlan) } @@ -766,7 +920,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def filter(conditionExpr: String): DataFrame = { + def filter(conditionExpr: String): Dataset[T] = { filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) } @@ -780,7 +934,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def where(condition: Column): DataFrame = filter(condition) + def where(condition: Column): Dataset[T] = filter(condition) /** * Filters rows using the given SQL expression. 
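The typed `select` overloads above route through `selectUntyped`, mirroring the example already given in the scaladoc (`ds.select(expr("value + 1").as[Int])`). A short sketch of how the tuple-returning variants are used; the `Record` case class is an assumption and `import sqlContext.implicits._` supplies the encoders:

{{{
case class Record(key: String, value: Int)
val ds: Dataset[Record] = Seq(Record("a", 1), Record("b", 2)).toDS()

// One typed column yields Dataset[U1].
val values: Dataset[Int] = ds.select($"value".as[Int])

// Two typed columns go through selectUntyped and come back as a tuple Dataset.
val pairs: Dataset[(String, Int)] = ds.select($"key".as[String], $"value".as[Int])
}}}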
@@ -790,7 +944,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.5.0 */ - def where(conditionExpr: String): DataFrame = { + def where(conditionExpr: String): Dataset[T] = { filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) } @@ -813,7 +967,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def groupBy(cols: Column*): GroupedData = { - GroupedData(this, cols.map(_.expr), GroupedData.GroupByType) + GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType) } /** @@ -836,7 +990,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def rollup(cols: Column*): GroupedData = { - GroupedData(this, cols.map(_.expr), GroupedData.RollupType) + GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType) } /** @@ -858,7 +1012,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ @scala.annotation.varargs - def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType) + def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -883,10 +1037,73 @@ class DataFrame private[sql]( @scala.annotation.varargs def groupBy(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + } + + /** + * (Scala-specific) + * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: (T, T) => T): T = rdd.reduce(func) + + /** + * (Java-specific) + * Reduces the elements of this Dataset using the specified binary function. The given `func` + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) + + /** + * (Scala-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. + * @since 1.6.0 + */ + def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = { + val inputPlan = logicalPlan + val withGroupingKey = AppendColumns(func, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new GroupedDataset( + encoderFor[K], + encoderFor[T], + executed, + inputPlan.output, + withGroupingKey.newColumns) + } + + /** + * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. + * @since 1.6.0 + */ + @scala.annotation.varargs + def groupByKey(cols: Column*): GroupedDataset[Row, T] = { + val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) + val withKey = Project(withKeyColumns, logicalPlan) + val executed = sqlContext.executePlan(withKey) + + val dataAttributes = executed.analyzed.output.dropRight(cols.size) + val keyAttributes = executed.analyzed.output.takeRight(cols.size) + + new GroupedDataset( + RowEncoder(keyAttributes.toStructType), + encoderFor[T], + executed, + dataAttributes, + keyAttributes) } /** + * (Java-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. 
+ * @since 1.6.0 + */ + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupByKey(func.call(_))(encoder) + + /** * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, * so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. @@ -910,7 +1127,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def rollup(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType) + GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType) } /** @@ -937,7 +1154,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def cube(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) + GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType) } /** @@ -997,7 +1214,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def limit(n: Int): DataFrame = withPlan { + def limit(n: Int): Dataset[T] = withTypedPlan { Limit(Literal(n), logicalPlan) } @@ -1007,19 +1224,21 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def unionAll(other: DataFrame): DataFrame = withPlan { + def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. CombineUnions(Union(logicalPlan, other.logicalPlan)) } + def union(other: Dataset[T]): Dataset[T] = unionAll(other) + /** * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. * This is equivalent to `INTERSECT` in SQL. * @group dfops * @since 1.3.0 */ - def intersect(other: DataFrame): DataFrame = withPlan { + def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan { Intersect(logicalPlan, other.logicalPlan) } @@ -1029,10 +1248,12 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def except(other: DataFrame): DataFrame = withPlan { + def except(other: Dataset[T]): Dataset[T] = withTypedPlan { Except(logicalPlan, other.logicalPlan) } + def subtract(other: Dataset[T]): Dataset[T] = except(other) + /** * Returns a new [[DataFrame]] by sampling a fraction of rows. * @@ -1042,7 +1263,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan { + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } @@ -1054,7 +1275,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { sample(withReplacement, fraction, Utils.random.nextLong) } @@ -1066,7 +1287,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. 
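`reduce` and `groupByKey` are the typed aggregation entry points carried onto the merged class (the functional `groupBy` from the old Dataset.scala appears here as `groupByKey`, since the relational `groupBy(cols: Column*)` keeps its name). A sketch under the assumption that `GroupedDataset.mapGroups` is available and that `import sqlContext.implicits._` provides the encoders:

{{{
val words: Dataset[String] = Seq("a", "b", "a").toDS()

// reduce aggregates the whole Dataset down to a single value.
val concatenated: String = words.reduce(_ + _)

// groupByKey + mapGroups runs a typed aggregation per key.
val counts: Dataset[(String, Int)] =
  words.groupByKey(w => w).mapGroups { (key, values) => (key, values.size) }
}}}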
To prevent this, we explicitly sort each input partition to make the @@ -1075,7 +1296,8 @@ class DataFrame private[sql]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)()) + new Dataset[T]( + sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) }.toArray } @@ -1086,7 +1308,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def randomSplit(weights: Array[Double]): Array[DataFrame] = { + def randomSplit(weights: Array[Double]): Array[Dataset[T]] = { randomSplit(weights, Utils.random.nextLong) } @@ -1097,7 +1319,7 @@ class DataFrame private[sql]( * @param seed Seed for sampling. * @group dfops */ - private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { + private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = { randomSplit(weights.toArray, seed) } @@ -1238,7 +1460,7 @@ class DataFrame private[sql]( } select(columns : _*) } else { - this + toDF() } } @@ -1264,7 +1486,7 @@ class DataFrame private[sql]( val remainingCols = schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) if (remainingCols.size == this.schema.size) { - this + toDF() } else { this.select(remainingCols: _*) } @@ -1297,7 +1519,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def dropDuplicates(): DataFrame = dropDuplicates(this.columns) + def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns) /** * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only @@ -1306,7 +1528,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan { + def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { val groupCols = colNames.map(resolve) val groupColExprIds = groupCols.map(_.exprId) val aggCols = logicalPlan.output.map { attr => @@ -1326,7 +1548,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.4.0 */ - def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq) + def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq) /** * Computes statistics for numeric columns, including count, mean, stddev, min, and max. @@ -1396,7 +1618,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df => + def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df => df.collect(needCallback = false) } @@ -1405,14 +1627,14 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def head(): Row = head(1).head + def head(): T = head(1).head /** * Returns the first row. Alias for head(). * @group action * @since 1.3.0 */ - def first(): Row = head() + def first(): T = head() /** * Concise syntax for chaining custom transformations. @@ -1425,27 +1647,113 @@ class DataFrame private[sql]( * }}} * @since 1.6.0 */ - def transform[U](t: DataFrame => DataFrame): DataFrame = t(this) + def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) + + /** + * (Scala-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. 
+ * @since 1.6.0 + */ + def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + + /** + * (Java-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) + + /** + * (Scala-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = + map(t => func.call(t))(encoder) + + /** + * (Scala-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * @since 1.6.0 + */ + def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + new Dataset[U]( + sqlContext, + MapPartitions[T, U](func, logicalPlan), + implicitly[Encoder[U]]) + } + + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * @since 1.6.0 + */ + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala + mapPartitions(func)(encoder) + } + + /** + * (Scala-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = + mapPartitions(_.flatMap(func)) + + /** + * (Java-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (T) => Iterator[U] = x => f.call(x).asScala + flatMap(func)(encoder) + } /** * Applies a function `f` to all rows. * @group rdd * @since 1.3.0 */ - def foreach(f: Row => Unit): Unit = withNewExecutionId { + def foreach(f: T => Unit): Unit = withNewExecutionId { rdd.foreach(f) } /** + * (Java-specific) + * Runs `func` on each element of this [[Dataset]]. + * @since 1.6.0 + */ + def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) + + /** * Applies a function f to each partition of this [[DataFrame]]. * @group rdd * @since 1.3.0 */ - def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId { + def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId { rdd.foreachPartition(f) } /** + * (Java-specific) + * Runs `func` on each partition of this [[Dataset]]. + * @since 1.6.0 + */ + def foreachPartition(func: ForeachPartitionFunction[T]): Unit = + foreachPartition(it => func.call(it.asJava)) + + /** * Returns the first `n` rows in the [[DataFrame]]. * * Running take requires moving data into the application's driver process, and doing so with @@ -1454,7 +1762,11 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def take(n: Int): Array[Row] = head(n) + def take(n: Int): Array[T] = head(n) + + def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds => + ds.collectRows(needCallback = false) + } /** * Returns the first `n` rows in the [[DataFrame]] as a list. 
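The typed functional operators (`filter`, `map`, `mapPartitions`, `flatMap`) brought over from the old Dataset.scala now sit alongside the untyped operations. A minimal sketch, assuming `import sqlContext.implicits._` for the primitive encoders:

{{{
val nums: Dataset[Int] = Seq(1, 2, 3, 4).toDS()

val evens: Dataset[Int]    = nums.filter(_ % 2 == 0)          // typed predicate
val doubled: Dataset[Int]  = nums.map(_ * 2)                  // per-element map
val repeated: Dataset[Int] = nums.flatMap(n => Seq.fill(n)(n))
val shifted: Dataset[Int]  = nums.mapPartitions(_.map(_ + 1)) // per-partition transform
}}}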
@@ -1465,7 +1777,7 @@ class DataFrame private[sql]( * @group action * @since 1.6.0 */ - def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) + def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*) /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. @@ -1478,7 +1790,9 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def collect(): Array[Row] = collect(needCallback = true) + def collect(): Array[T] = collect(needCallback = true) + + def collectRows(): Array[Row] = collectRows(needCallback = true) /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. @@ -1489,19 +1803,32 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => + def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) + val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + java.util.Arrays.asList(values : _*) } } - private def collect(needCallback: Boolean): Array[Row] = { + private def collect(needCallback: Boolean): Array[T] = { + def execute(): Array[T] = withNewExecutionId { + queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) + } + + if (needCallback) { + withCallback("collect", toDF())(_ => execute()) + } else { + execute() + } + } + + private def collectRows(needCallback: Boolean): Array[Row] = { def execute(): Array[Row] = withNewExecutionId { queryExecution.executedPlan.executeCollectPublic() } if (needCallback) { - withCallback("collect", this)(_ => execute()) + withCallback("collect", toDF())(_ => execute()) } else { execute() } @@ -1521,7 +1848,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def repartition(numPartitions: Int): DataFrame = withPlan { + def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { Repartition(numPartitions, shuffle = true, logicalPlan) } @@ -1535,7 +1862,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan { + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions)) } @@ -1549,7 +1876,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): DataFrame = withPlan { + def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None) } @@ -1561,7 +1888,7 @@ class DataFrame private[sql]( * @group rdd * @since 1.4.0 */ - def coalesce(numPartitions: Int): DataFrame = withPlan { + def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { Repartition(numPartitions, shuffle = false, logicalPlan) } @@ -1571,7 +1898,7 @@ class DataFrame private[sql]( * @group dfops * @since 1.3.0 */ - def distinct(): DataFrame = dropDuplicates() + def distinct(): Dataset[T] = dropDuplicates() /** * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). 
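Because `collect` and `take` now return the typed element `T`, the patch also adds Row-based variants (`collectRows`, `takeRows`) for callers that still want generic rows. A sketch reusing the hypothetical `Person` case class from the earlier example:

{{{
val people: Dataset[Person] = Seq(Person("Alice", 29L)).toDS()

val typed: Array[Person] = people.collect()      // decoded through the Dataset's encoder
val generic: Array[Row]  = people.collectRows()  // Row-based variant added in this patch
val firstTwo: Array[Row] = people.takeRows(2)    // Row-based take, also added here
}}}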
@@ -1632,12 +1959,11 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - lazy val rdd: RDD[Row] = { + lazy val rdd: RDD[T] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema queryExecution.toRdd.mapPartitions { rows => - val converter = CatalystTypeConverters.createToScalaConverter(schema) - rows.map(converter(_).asInstanceOf[Row]) + rows.map(boundTEncoder.fromRow) } } @@ -1646,14 +1972,14 @@ class DataFrame private[sql]( * @group rdd * @since 1.3.0 */ - def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD() /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd * @since 1.3.0 */ - def javaRDD: JavaRDD[Row] = toJavaRDD + def javaRDD: JavaRDD[T] = toJavaRDD /** * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this @@ -1663,7 +1989,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def registerTempTable(tableName: String): Unit = { - sqlContext.registerDataFrameAsTable(this, tableName) + sqlContext.registerDataFrameAsTable(toDF(), tableName) } /** @@ -1674,7 +2000,7 @@ class DataFrame private[sql]( * @since 1.4.0 */ @Experimental - def write: DataFrameWriter = new DataFrameWriter(this) + def write: DataFrameWriter = new DataFrameWriter(toDF()) /** * Returns the content of the [[DataFrame]] as a RDD of JSON strings. @@ -1745,7 +2071,7 @@ class DataFrame private[sql]( * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. */ - private[sql] def withNewExecutionId[T](body: => T): T = { + private[sql] def withNewExecutionId[U](body: => U): U = { SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body) } @@ -1753,7 +2079,7 @@ class DataFrame private[sql]( * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the * user-registered callback functions. */ - private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = { + private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = { try { df.queryExecution.executedPlan.foreach { plan => plan.resetMetrics() @@ -1770,7 +2096,24 @@ class DataFrame private[sql]( } } - private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = { + private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = { + try { + ds.queryExecution.executedPlan.foreach { plan => + plan.resetMetrics() + } + val start = System.nanoTime() + val result = action(ds) + val end = System.nanoTime() + sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start) + result + } catch { + case e: Exception => + sqlContext.listenerManager.onFailure(name, ds.queryExecution, e) + throw e + } + } + + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { val sortOrder: Seq[SortOrder] = sortExprs.map { col => col.expr match { case expr: SortOrder => @@ -1779,14 +2122,23 @@ class DataFrame private[sql]( SortOrder(expr, Ascending) } } - withPlan { + withTypedPlan { Sort(sortOrder, global = global, logicalPlan) } } /** A convenient function to wrap a logical plan and produce a DataFrame. */ @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { - new DataFrame(sqlContext, logicalPlan) + DataFrame(sqlContext, logicalPlan) + } + + /** A convenient function to wrap a logical plan and produce a DataFrame. 
*/ + @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = { + new Dataset[T](sqlContext, logicalPlan, encoder) } + private[sql] def withTypedPlan[R]( + other: Dataset[_], encoder: Encoder[R])( + f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = + new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 509b29956f..822702429d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -345,7 +345,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { InferSchema.infer(jsonRDD, sqlContext.conf.columnNameOfCorruptRecord, parsedOptions) } - new DataFrame( + DataFrame( sqlContext, LogicalRDD( schema.toAttributes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala deleted file mode 100644 index daddf6e0c5..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ /dev/null @@ -1,794 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import scala.collection.JavaConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.function._ -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.encoders._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.CombineUnions -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{Queryable, QueryExecution} -import org.apache.spark.sql.types.StructType -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - -/** - * :: Experimental :: - * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel - * using functional or relational operations. - * - * A [[Dataset]] differs from an [[RDD]] in the following ways: - * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored - * in the encoded form. This representation allows for additional logical operations and - * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to - * an object. - * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be - * used to serialize the object into a binary format. 
Encoders are also capable of mapping the - * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime - * reflection based serialization. Operations that change the type of object stored in the - * dataset also need an encoder for the new type. - * - * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific - * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into - * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed - * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`. - * - * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However, - * making this change to the class hierarchy would break the function signatures for the existing - * functional operations (map, flatMap, etc). As such, this class should be considered a preview - * of the final API. Changes will be made to the interface after Spark 1.6. - * - * @since 1.6.0 - */ -@Experimental -class Dataset[T] private[sql]( - @transient override val sqlContext: SQLContext, - @transient override val queryExecution: QueryExecution, - tEncoder: Encoder[T]) extends Queryable with Serializable with Logging { - - /** - * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is - * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the - * same object type (that will be possibly resolved to a different schema). - */ - private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) - unresolvedTEncoder.validate(logicalPlan.output) - - /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ - private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) - - /** - * The encoder where the expressions used to construct an object from an input row have been - * bound to the ordinals of this [[Dataset]]'s output schema. - */ - private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) - - private implicit def classTag = unresolvedTEncoder.clsTag - - private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = - this(sqlContext, new QueryExecution(sqlContext, plan), encoder) - - /** - * Returns the schema of the encoded form of the objects in this [[Dataset]]. - * @since 1.6.0 - */ - override def schema: StructType = resolvedTEncoder.schema - - /** - * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format. - * @since 1.6.0 - */ - override def printSchema(): Unit = toDF().printSchema() - - /** - * Prints the plans (logical and physical) to the console for debugging purposes. - * @since 1.6.0 - */ - override def explain(extended: Boolean): Unit = toDF().explain(extended) - - /** - * Prints the physical plan to the console for debugging purposes. - * @since 1.6.0 - */ - override def explain(): Unit = toDF().explain() - - /* ************* * - * Conversions * - * ************* */ - - /** - * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The - * method used to map columns depend on the type of `U`: - * - When `U` is a class, fields for the class will be mapped to columns of the same name - * (case sensitivity is determined by `spark.sql.caseSensitive`) - * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. 
the first column will - * be assigned to `_1`). - * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the - * [[DataFrame]] will be used. - * - * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select` - * along with `alias` or `as` to rearrange or rename as required. - * @since 1.6.0 - */ - def as[U : Encoder]: Dataset[U] = { - new Dataset(sqlContext, queryExecution, encoderFor[U]) - } - - /** - * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have - * the same name after two Datasets have been joined. - * @since 1.6.0 - */ - def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _)) - - /** - * Converts this strongly typed collection of data to generic Dataframe. In contrast to the - * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] - * objects that allow fields to be accessed by ordinal or name. - */ - // This is declared with parentheses to prevent the Scala compiler from treating - // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) - - /** - * Returns this [[Dataset]]. - * @since 1.6.0 - */ - // This is declared with parentheses to prevent the Scala compiler from treating - // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset. - def toDS(): Dataset[T] = this - - /** - * Converts this [[Dataset]] to an [[RDD]]. - * @since 1.6.0 - */ - def rdd: RDD[T] = { - queryExecution.toRdd.mapPartitions { iter => - iter.map(boundTEncoder.fromRow) - } - } - - /** - * Returns the number of elements in the [[Dataset]]. - * @since 1.6.0 - */ - def count(): Long = toDF().count() - - /** - * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters - * will be truncated, and all cells will be aligned right. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows Number of rows to show - * - * @since 1.6.0 - */ - def show(numRows: Int): Unit = show(numRows, truncate = true) - - /** - * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters - * will be truncated, and all cells will be aligned right. - * - * @since 1.6.0 - */ - def show(): Unit = show(20) - - /** - * Displays the top 20 rows of [[Dataset]] in a tabular form. - * - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right - * - * @since 1.6.0 - */ - def show(truncate: Boolean): Unit = show(20, truncate) - - /** - * Displays the [[Dataset]] in a tabular form. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows Number of rows to show - * @param truncate Whether truncate long strings. 
If true, strings more than 20 characters will - * be truncated and all cells will be aligned right - * - * @since 1.6.0 - */ - // scalastyle:off println - def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) - // scalastyle:on println - - /** - * Compose the string representing rows for output - * @param _numRows Number of rows to show - * @param truncate Whether truncate long strings and align cells right - */ - override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { - val numRows = _numRows.max(0) - val takeResult = take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) - - // For array values, replace Seq and Array with square brackets - // For cells that are beyond 20 characters, replace it with the first 17 and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map { - case r: Row => r - case tuple: Product => Row.fromTuple(tuple) - case o => Row(o) - } map { row => - row.toSeq.map { cell => - val str = cell match { - case null => "null" - case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]") - case array: Array[_] => array.mkString("[", ", ", "]") - case seq: Seq[_] => seq.mkString("[", ", ", "]") - case _ => cell.toString - } - if (truncate && str.length > 20) str.substring(0, 17) + "..." else str - }: Seq[String] - }) - - formatString ( rows, numRows, hasMoreData, truncate ) - } - - /** - * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. - * @since 1.6.0 - */ - def repartition(numPartitions: Int): Dataset[T] = withPlan { - Repartition(numPartitions, shuffle = true, _) - } - - /** - * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. - * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. - * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of - * the 100 new partitions will claim 10 of the current partitions. - * @since 1.6.0 - */ - def coalesce(numPartitions: Int): Dataset[T] = withPlan { - Repartition(numPartitions, shuffle = false, _) - } - - /* *********************** * - * Functional Operations * - * *********************** */ - - /** - * Concise syntax for chaining custom transformations. - * {{{ - * def featurize(ds: Dataset[T]) = ... - * - * dataset - * .transform(featurize) - * .transform(...) - * }}} - * @since 1.6.0 - */ - def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) - - /** - * (Scala-specific) - * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. - * @since 1.6.0 - */ - def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) - - /** - * (Java-specific) - * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. - * @since 1.6.0 - */ - def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) - - /** - * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. - * @since 1.6.0 - */ - def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) - - /** - * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. 
- * @since 1.6.0 - */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = - map(t => func.call(t))(encoder) - - /** - * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. - * @since 1.6.0 - */ - def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - new Dataset[U]( - sqlContext, - MapPartitions[T, U](func, logicalPlan)) - } - - /** - * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. - * @since 1.6.0 - */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala - mapPartitions(func)(encoder) - } - - /** - * (Scala-specific) - * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], - * and then flattening the results. - * @since 1.6.0 - */ - def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = - mapPartitions(_.flatMap(func)) - - /** - * (Java-specific) - * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], - * and then flattening the results. - * @since 1.6.0 - */ - def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - val func: (T) => Iterator[U] = x => f.call(x).asScala - flatMap(func)(encoder) - } - - /* ************** * - * Side effects * - * ************** */ - - /** - * (Scala-specific) - * Runs `func` on each element of this [[Dataset]]. - * @since 1.6.0 - */ - def foreach(func: T => Unit): Unit = rdd.foreach(func) - - /** - * (Java-specific) - * Runs `func` on each element of this [[Dataset]]. - * @since 1.6.0 - */ - def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) - - /** - * (Scala-specific) - * Runs `func` on each partition of this [[Dataset]]. - * @since 1.6.0 - */ - def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) - - /** - * (Java-specific) - * Runs `func` on each partition of this [[Dataset]]. - * @since 1.6.0 - */ - def foreachPartition(func: ForeachPartitionFunction[T]): Unit = - foreachPartition(it => func.call(it.asJava)) - - /* ************* * - * Aggregation * - * ************* */ - - /** - * (Scala-specific) - * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` - * must be commutative and associative or the result may be non-deterministic. - * @since 1.6.0 - */ - def reduce(func: (T, T) => T): T = rdd.reduce(func) - - /** - * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given `func` - * must be commutative and associative or the result may be non-deterministic. - * @since 1.6.0 - */ - def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) - - /** - * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. - * @since 1.6.0 - */ - def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { - val inputPlan = logicalPlan - val withGroupingKey = AppendColumns(func, inputPlan) - val executed = sqlContext.executePlan(withGroupingKey) - - new GroupedDataset( - encoderFor[K], - encoderFor[T], - executed, - inputPlan.output, - withGroupingKey.newColumns) - } - - /** - * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. 
- * @since 1.6.0 - */ - @scala.annotation.varargs - def groupBy(cols: Column*): GroupedDataset[Row, T] = { - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) - val withKey = Project(withKeyColumns, logicalPlan) - val executed = sqlContext.executePlan(withKey) - - val dataAttributes = executed.analyzed.output.dropRight(cols.size) - val keyAttributes = executed.analyzed.output.takeRight(cols.size) - - new GroupedDataset( - RowEncoder(keyAttributes.toStructType), - encoderFor[T], - executed, - dataAttributes, - keyAttributes) - } - - /** - * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. - * @since 1.6.0 - */ - def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = - groupBy(func.call(_))(encoder) - - /* ****************** * - * Typed Relational * - * ****************** */ - - /** - * Returns a new [[DataFrame]] by selecting a set of column based expressions. - * {{{ - * df.select($"colA", $"colB" + 1) - * }}} - * @since 1.6.0 - */ - // Copied from Dataframe to make sure we don't have invalid overloads. - @scala.annotation.varargs - protected def select(cols: Column*): DataFrame = toDF().select(cols: _*) - - /** - * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. - * - * {{{ - * val ds = Seq(1, 2, 3).toDS() - * val newDS = ds.select(expr("value + 1").as[Int]) - * }}} - * @since 1.6.0 - */ - def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - new Dataset[U1]( - sqlContext, - Project( - c1.withInputType( - boundTEncoder, - logicalPlan.output).named :: Nil, - logicalPlan)) - } - - /** - * Internal helper function for building typed selects that return tuples. For simplicity and - * code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. - */ - protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(_.encoder) - val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) - val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) - - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) - } - - /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - * @since 1.6.0 - */ - def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = - selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] - - /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - * @since 1.6.0 - */ - def select[U1, U2, U3]( - c1: TypedColumn[T, U1], - c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = - selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] - - /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - * @since 1.6.0 - */ - def select[U1, U2, U3, U4]( - c1: TypedColumn[T, U1], - c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3], - c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = - selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] - - /** - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. 
- * @since 1.6.0 - */ - def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[T, U1], - c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3], - c4: TypedColumn[T, U4], - c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = - selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] - - /** - * Returns a new [[Dataset]] by sampling a fraction of records. - * @since 1.6.0 - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] = - withPlan(Sample(0.0, fraction, withReplacement, seed, _)()) - - /** - * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed. - * @since 1.6.0 - */ - def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = { - sample(withReplacement, fraction, Utils.random.nextLong) - } - - /* **************** * - * Set operations * - * **************** */ - - /** - * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]]. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * @since 1.6.0 - */ - def distinct: Dataset[T] = withPlan(Distinct) - - /** - * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also - * present in `other`. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * @since 1.6.0 - */ - def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect) - - /** - * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] - * combined. - * - * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analogous to `UNION ALL` in SQL. - * @since 1.6.0 - */ - def union(other: Dataset[T]): Dataset[T] = withPlan[T](other) { (left, right) => - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(left, right)) - } - - /** - * Returns a new [[Dataset]] where any elements present in `other` have been removed. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * @since 1.6.0 - */ - def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) - - /* ****** * - * Joins * - * ****** */ - - /** - * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to - * true. - * - * This is similar to the relation `join` function with one important difference in the - * result schema. Since `joinWith` preserves objects present on either side of the join, the - * result schema is similarly nested into a tuple under the column names `_1` and `_2`. - * - * This type of join can be useful both for preserving type-safety with the original object - * types as well as working with relational data where either side of the join has column - * names in common. - * - * @param other Right side of the join. - * @param condition Join expression. - * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. 
- * @since 1.6.0 - */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - val left = this.logicalPlan - val right = other.logicalPlan - - val joined = sqlContext.executePlan(Join(left, right, joinType = - JoinType(joinType), Some(condition.expr))) - val leftOutput = joined.analyzed.output.take(left.output.length) - val rightOutput = joined.analyzed.output.takeRight(right.output.length) - - val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(leftOutput.head, "_1")() - case _ => Alias(CreateStruct(leftOutput), "_1")() - } - val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(rightOutput.head, "_2")() - case _ => Alias(CreateStruct(rightOutput), "_2")() - } - - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) - withPlan[(T, U)](other) { (left, right) => - Project( - leftData :: rightData :: Nil, - joined.analyzed) - } - } - - /** - * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair - * where `condition` evaluates to true. - * - * @param other Right side of the join. - * @param condition Join expression. - * @since 1.6.0 - */ - def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { - joinWith(other, condition, "inner") - } - - /* ************************** * - * Gather to Driver Actions * - * ************************** */ - - /** - * Returns the first element in this [[Dataset]]. - * @since 1.6.0 - */ - def first(): T = take(1).head - - /** - * Returns an array that contains all the elements in this [[Dataset]]. - * - * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. - * - * For Java API, use [[collectAsList]]. - * @since 1.6.0 - */ - def collect(): Array[T] = { - // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders - // to convert the rows into objects of type T. - queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow) - } - - /** - * Returns an array that contains all the elements in this [[Dataset]]. - * - * Running collect requires moving all the data into the application's driver process, and - * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError. - * - * For Java API, use [[collectAsList]]. - * @since 1.6.0 - */ - def collectAsList(): java.util.List[T] = collect().toSeq.asJava - - /** - * Returns the first `num` elements of this [[Dataset]] as an array. - * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `num` can crash the driver process with OutOfMemoryError. - * @since 1.6.0 - */ - def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect() - - /** - * Returns the first `num` elements of this [[Dataset]] as an array. - * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `num` can crash the driver process with OutOfMemoryError. - * @since 1.6.0 - */ - def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) - - /** - * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). 
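Note: joinWith keeps both joined objects intact and nests them under `_1`/`_2`, which is what makes it safe for typed data even when the two sides share column names. A minimal sketch mirroring the test suite, assuming `sqlContext.implicits._` is imported:

    val ds1 = Seq(1, 2, 3).toDS().as("a")
    val ds2 = Seq(1, 2).toDS().as("b")

    // Each matching pair is returned as a tuple, so neither side is flattened away.
    val joined: Dataset[(Int, Int)] = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner")
    // joined.collect() yields (1, 1) and (2, 2)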
- * @since 1.6.0 - */ - def persist(): this.type = { - sqlContext.cacheManager.cacheQuery(this) - this - } - - /** - * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). - * @since 1.6.0 - */ - def cache(): this.type = persist() - - /** - * Persist this [[Dataset]] with the given storage level. - * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, - * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, - * `MEMORY_AND_DISK_2`, etc. - * @group basic - * @since 1.6.0 - */ - def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheManager.cacheQuery(this, None, newLevel) - this - } - - /** - * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. - * @param blocking Whether to block until all blocks are deleted. - * @since 1.6.0 - */ - def unpersist(blocking: Boolean): this.type = { - sqlContext.cacheManager.tryUncacheQuery(this, blocking) - this - } - - /** - * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. - * @since 1.6.0 - */ - def unpersist(): this.type = unpersist(blocking = false) - - /* ******************** * - * Internal Functions * - * ******************** */ - - private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed - - private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) - - private[sql] def withPlan[R : Encoder]( - other: Dataset[_])( - f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a7258d742a..2a0f77349a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.NumericType /** * :: Experimental :: - * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. + * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. * * The main method is the agg function, which has multiple variants. This class also contains * convenience some first order statistics such as mean, sum for convenience. 
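Note: the caching methods above delegate to the shared cacheManager, so a Dataset is persisted and released the same way a DataFrame is. A small usage sketch:

    import org.apache.spark.storage.StorageLevel

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    ds.persist(StorageLevel.MEMORY_AND_DISK)   // ds.cache() picks the same default level
    ds.collect()                               // the first action materialises the cached blocks
    ds.unpersist(blocking = false)             // release them when no longer needed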
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index cd8ed472ec..1639cc8db6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -64,7 +64,7 @@ class GroupedDataset[K, V] private[sql]( private def groupedData = new GroupedData( - new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) + DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified @@ -86,7 +86,7 @@ class GroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def keys: Dataset[K] = { - new Dataset[K]( + Dataset[K]( sqlContext, Distinct( Project(groupingAttributes, logicalPlan))) @@ -111,7 +111,7 @@ class GroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = { - new Dataset[U]( + Dataset[U]( sqlContext, MapGroups( f, @@ -308,7 +308,7 @@ class GroupedDataset[K, V] private[sql]( other: GroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit val uEncoder = other.unresolvedVEncoder - new Dataset[R]( + Dataset[R]( sqlContext, CoGroup( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c742bf2f89..54dbd6bda5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -464,7 +464,7 @@ class SQLContext private[sql]( val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) - new Dataset[T](this, plan) + Dataset[T](this, plan) } def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { @@ -473,7 +473,7 @@ class SQLContext private[sql]( val encoded = data.map(d => enc.toRow(d)) val plan = LogicalRDD(attributes, encoded)(self) - new Dataset[T](this, plan) + Dataset[T](this, plan) } def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 16c4095db7..e23d5e1261 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -126,6 +126,7 @@ abstract class SQLImplicits { /** * Creates a [[Dataset]] from an RDD. 
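Note: the call sites above move from `new DataFrame(...)` / `new Dataset(...)` to the private[sql] companion `apply` factories, so construction is centralised in one place. User-facing entry points are unchanged; `createDataset`, for example, still only needs an encoder in scope. A sketch (the `Person` class is purely illustrative, not part of the patch):

    case class Person(name: String, age: Int)

    import sqlContext.implicits._   // supplies the Encoder[Person]
    val fromSeq = sqlContext.createDataset(Seq(Person("Ann", 30), Person("Bob", 25)))
    val fromRdd = sqlContext.createDataset(sparkContext.parallelize(Seq(Person("Cal", 41))))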
+ * * @since 1.6.0 */ implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8616fe3170..19ab3ea132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} @@ -31,7 +31,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { - def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed) + def assertAnalyzed(): Unit = try sqlContext.analyzer.checkAnalysis(analyzed) catch { + case e: AnalysisException => + throw new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) + } lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index e048ee1441..60ec67c8f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -154,7 +154,7 @@ case class DataSource( } def dataFrameBuilder(files: Array[String]): DataFrame = { - new DataFrame( + DataFrame( sqlContext, LogicalRelation( DataSource( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index a191759813..0dc34814fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging { StructField(v._1 + "_freqItems", ArrayType(v._2, false)) } val schema = StructType(outputCols).toAttributes - new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) + DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 26e4eda542..daa065e5cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging { } val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bc7c520930..7d7c51b158 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -211,7 +211,7 @@ class StreamExecution( // Construct the batch and send it to the sink. val batchOffset = streamProgress.toCompositeOffset(sources) - val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan)) + val nextBatch = new Batch(batchOffset, DataFrame(sqlContext, newPlan)) sink.addBatch(nextBatch) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 8124df15af..3b764c5558 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -55,11 +55,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def schema: StructType = encoder.schema def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { - new Dataset(sqlContext, logicalPlan) + Dataset(sqlContext, logicalPlan) } def toDF()(implicit sqlContext: SQLContext): DataFrame = { - new DataFrame(sqlContext, logicalPlan) + DataFrame(sqlContext, logicalPlan) } def addData(data: A*): Offset = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 6eea924517..844f3051fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -46,7 +46,6 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @tparam I The input type for the aggregation. * @tparam B The type of the intermediate value of the reduction. * @tparam O The type of the final output result. 
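Note: the Aggregator referenced here follows the 1.6-era contract (zero / reduce / merge / finish), and `toColumn` turns an instance into a TypedColumn that the typed agg methods accept. A hedged sketch, assuming the usual implicits provide the buffer and output encoders:

    import org.apache.spark.sql.expressions.Aggregator

    val sumOfValues = new Aggregator[(String, Int), Long, Long] {
      def zero: Long = 0L
      def reduce(b: Long, a: (String, Int)): Long = b + a._2
      def merge(b1: Long, b2: Long): Long = b1 + b2
      def finish(reduction: Long): Long = reduction
    }.toColumn

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    ds.groupByKey(_._1).agg(sumOfValues)   // Dataset[(String, Long)]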
- * * @since 1.6.0 */ abstract class Aggregator[-I, B, O] extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index bd73a36fd4..97e35bb104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -42,4 +42,5 @@ package object sql { @DeveloperApi type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + type DataFrame = Dataset[Row] } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 51f987fda9..42af813bc1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); + Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows(); List<Row> expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); @@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.createDataFrame(rowRDD, schema); + Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() { @Override @@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable { null, "this is another simple string.")); - DataFrame df1 = sqlContext.read().json(jsonRDD); + Dataset<Row> df1 = sqlContext.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ee85626435..47cc74dbc1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -64,13 +64,13 @@ public class 
JavaDataFrameSuite { @Test public void testExecution() { - DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(1, df.select("key").collect()[0].get(0)); + Dataset<Row> df = context.table("testData").filter("key = 1"); + Assert.assertEquals(1, df.select("key").collectRows()[0].get(0)); } @Test public void testCollectAndTake() { - DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); Assert.assertEquals(3, df.select("key").collectAsList().size()); Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } @@ -80,7 +80,7 @@ public class JavaDataFrameSuite { */ @Test public void testVarargMethods() { - DataFrame df = context.table("testData"); + Dataset<Row> df = context.table("testData"); df.toDF("key1", "value1"); @@ -109,7 +109,7 @@ public class JavaDataFrameSuite { df.select(coalesce(col("key"))); // Varargs with mathfunctions - DataFrame df2 = context.table("testData2"); + Dataset<Row> df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -123,7 +123,7 @@ public class JavaDataFrameSuite { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - DataFrame df = context.table("testData"); + Dataset<Row> df = context.table("testData"); df.show(); df.show(1000); } @@ -151,7 +151,7 @@ public class JavaDataFrameSuite { } } - void validateDataFrameWithBeans(Bean bean, DataFrame df) { + void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) { StructType schema = df.schema(); Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), schema.apply("a")); @@ -191,7 +191,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List<Bean> data = Arrays.asList(bean); - DataFrame df = context.createDataFrame(data, Bean.class); + Dataset<Row> df = context.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -199,7 +199,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean)); - DataFrame df = context.createDataFrame(rdd, Bean.class); + Dataset<Row> df = context.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -207,8 +207,8 @@ public class JavaDataFrameSuite { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List<Row> rows = Arrays.asList(RowFactory.create(0)); - DataFrame df = context.createDataFrame(rows, schema); - Row[] result = df.collect(); + Dataset<Row> df = context.createDataFrame(rows, schema); + Row[] result = df.collectRows(); Assert.assertEquals(1, result.length); } @@ -235,13 +235,13 @@ public class JavaDataFrameSuite { @Test public void testCrosstab() { - DataFrame df = context.table("testData2"); - DataFrame crosstab = df.stat().crosstab("a", "b"); + Dataset<Row> df = context.table("testData2"); + Dataset<Row> crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); Assert.assertEquals("2", columnNames[1]); Assert.assertEquals("1", columnNames[2]); - Row[] rows = crosstab.collect(); + Row[] rows = crosstab.collectRows(); Arrays.sort(rows, crosstabRowComparator); Integer count = 1; 
for (Row row : rows) { @@ -254,31 +254,31 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { - DataFrame df = context.table("testData2"); + Dataset<Row> df = context.table("testData2"); String[] cols = {"a"}; - DataFrame results = df.stat().freqItems(cols, 0.2); - Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); + Dataset<Row> results = df.stat().freqItems(cols, 0.2); + Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1)); } @Test public void testCorrelation() { - DataFrame df = context.table("testData2"); + Dataset<Row> df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - DataFrame df = context.table("testData2"); + Dataset<Row> df = context.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); - DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); - Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows(); Assert.assertEquals(0, actual[0].getLong(0)); Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8); Assert.assertEquals(1, actual[1].getLong(0)); @@ -287,10 +287,10 @@ public class JavaDataFrameSuite { @Test public void pivot() { - DataFrame df = context.table("courseSales"); + Dataset<Row> df = context.table("courseSales"); Row[] actual = df.groupBy("year") .pivot("course", Arrays.<Object>asList("dotNET", "Java")) - .agg(sum("earnings")).orderBy("year").collect(); + .agg(sum("earnings")).orderBy("year").collectRows(); Assert.assertEquals(2012, actual[0].getInt(0)); Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); @@ -303,11 +303,11 @@ public class JavaDataFrameSuite { @Test public void testGenericLoad() { - DataFrame df1 = context.read().format("text").load( + Dataset<Row> df1 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); Assert.assertEquals(4L, df1.count()); - DataFrame df2 = context.read().format("text").load( + Dataset<Row> df2 = context.read().format("text").load( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); @@ -315,11 +315,11 @@ public class JavaDataFrameSuite { @Test public void testTextLoad() { - DataFrame df1 = context.read().text( + Dataset<Row> df1 = context.read().text( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString()); Assert.assertEquals(4L, df1.count()); - DataFrame df2 = context.read().text( + Dataset<Row> df2 = context.read().text( Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(), Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString()); Assert.assertEquals(5L, df2.count()); @@ -327,7 +327,7 @@ public class JavaDataFrameSuite { @Test public void testCountMinSketch() { - DataFrame df = context.range(1000); + Dataset<Row> df 
= context.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -352,7 +352,7 @@ public class JavaDataFrameSuite { @Test public void testBloomFilter() { - DataFrame df = context.range(1000); + Dataset<Row> df = context.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index b054b1095b..79b6e61767 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable { public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() { + GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable { List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); - GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() { + GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() { @Override public Integer call(Integer v) throws Exception { return v / 2; @@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = context.createDataset(data, Encoders.STRING()); GroupedDataset<Integer, String> grouped = - ds.groupBy(length(col("value"))).keyAs(Encoders.INT()); + ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); Dataset<String> mapped = grouped.mapGroups( new MapGroupsFunction<Integer, String, String>() { @@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable { Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder); - GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy( + GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey( new MapFunction<Tuple2<String, Integer>, String>() { @Override public String call(Tuple2<String, Integer> value) throws Exception { @@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable { { Row row = new GenericRow(new Object[] { null }); - DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - 
DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 9e241f2098..0f9e453d26 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -42,9 +42,9 @@ public class JavaSaveLoadSuite { String originalDefaultSource; File path; - DataFrame df; + Dataset<Row> df; - private static void checkAnswer(DataFrame actual, List<Row> expected) { + private static void checkAnswer(Dataset<Row> actual, List<Row> expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -85,7 +85,7 @@ public class JavaSaveLoadSuite { Map<String, String> options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -98,7 +98,7 @@ public class JavaSaveLoadSuite { List<StructField> fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 26775c3700..f4a5107eaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -38,23 +38,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("analysis error should be eagerly reported") { - // Eager analysis. 
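Note: with analysis now asserted when a frame is constructed, the DATAFRAME_EAGER_ANALYSIS escape hatch disappears and the test below no longer needs its withSQLConf guards. A test-style sketch of the always-eager behaviour (assumes ScalaTest's intercept and the suite's testImplicits):

    val df = Seq((1, "a")).toDF("key", "value")
    intercept[AnalysisException] {
      df.select('doesNotExist)   // resolution fails as soon as the plan is built, not at execution time
    }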
- withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) - } + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) } - - // No more eager analysis once the flag is turned off - withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { - testData.select('nonExistentName) + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) } } @@ -72,7 +64,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Nil) } - test("invalid plan toString, debug mode") { + ignore("invalid plan toString, debug mode") { // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ @@ -941,7 +933,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") + DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3258f3782d..84770169f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -119,16 +119,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("typed aggregation: TypedAggregator") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkAnswer( - ds.groupBy(_._1).agg(sum(_._2)), + checkDataset( + ds.groupByKey(_._1).agg(sum(_._2)), ("a", 30), ("b", 3), ("c", 1)) } test("typed aggregation: TypedAggregator, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkAnswer( - ds.groupBy(_._1).agg( + checkDataset( + ds.groupByKey(_._1).agg( sum(_._2), expr("sum(_2)").as[Long], count("*")), @@ -138,8 +138,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("typed aggregation: complex case") { val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() - checkAnswer( - ds.groupBy(_._1).agg( + checkDataset( + ds.groupByKey(_._1).agg( expr("avg(_2)").as[Double], TypedAverage.toColumn), ("a", 2.0, 2.0), ("b", 3.0, 3.0)) @@ -148,8 +148,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("typed aggregation: complex result type") { val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() - checkAnswer( - ds.groupBy(_._1).agg( + checkDataset( + ds.groupByKey(_._1).agg( expr("avg(_2)").as[Double], ComplexResultAgg.toColumn), ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) @@ -158,10 +158,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("typed aggregation: in project list") { val ds = Seq(1, 3, 2, 5).toDS() - checkAnswer( + checkDataset( ds.select(sum((i: Int) => i)), 11) - checkAnswer( + 
checkDataset( ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), 11 -> 22) } @@ -169,7 +169,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("typed aggregation: class input") { val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() - checkAnswer( + checkDataset( ds.select(ClassInputAgg.toColumn), 3) } @@ -177,33 +177,33 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("typed aggregation: class input with reordering") { val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData] - checkAnswer( + checkDataset( ds.select(ClassInputAgg.toColumn), 1) - checkAnswer( + checkDataset( ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn), (1.0, 1)) - checkAnswer( - ds.groupBy(_.b).agg(ClassInputAgg.toColumn), + checkDataset( + ds.groupByKey(_.b).agg(ClassInputAgg.toColumn), ("one", 1)) } test("typed aggregation: complex input") { val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() - checkAnswer( + checkDataset( ds.select(ComplexBufferAgg.toColumn), 2 ) - checkAnswer( + checkDataset( ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), (1.5, 2)) - checkAnswer( - ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn), + checkDataset( + ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn), ("one", 1), ("two", 1)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 848f1af655..2e5179a8d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -34,7 +34,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { // Make sure, the Dataset is indeed cached. assertCached(cached) // Check result. - checkAnswer( + checkDataset( cached, 2, 3, 4) // Drop the cache. 
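Note: these test updates track two renames: the typed grouping is now `groupByKey`, leaving `groupBy` for the relational, Column-based form, and the typed result checker is `checkDataset`. Roughly, from user code (assuming the implicits and `org.apache.spark.sql.functions._` are imported):

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

    ds.groupByKey(_._1).agg(sum("_2").as[Long])   // typed grouping by a function of T
    ds.groupByKey($"_1").keyAs[String]            // typed grouping keyed by columns
    ds.groupBy($"_1").count()                     // untyped, relational grouping (a DataFrame result)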
@@ -52,7 +52,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(ds2) val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") - checkAnswer(joined, ("2", 2)) + checkDataset(joined, ("2", 2)) assertCached(joined, 2) ds1.unpersist() @@ -63,11 +63,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { test("persist and then groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").keyAs[String] + val grouped = ds.groupByKey($"_1").keyAs[String] val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } agged.persist() - checkAnswer( + checkDataset( agged.filter(_._1 == "b"), ("b", 3)) assertCached(agged.filter(_._1 == "b")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 243d13b19d..6e9840e4a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -28,14 +28,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("toDS") { val data = Seq(1, 2, 3, 4, 5, 6) - checkAnswer( + checkDataset( data.toDS(), data: _*) } test("as case class / collect") { val ds = Seq(1, 2, 3).toDS().as[IntClass] - checkAnswer( + checkDataset( ds, IntClass(1), IntClass(2), IntClass(3)) @@ -44,14 +44,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("map") { val ds = Seq(1, 2, 3).toDS() - checkAnswer( + checkDataset( ds.map(_ + 1), 2, 3, 4) } test("filter") { val ds = Seq(1, 2, 3, 4).toDS() - checkAnswer( + checkDataset( ds.filter(_ % 2 == 0), 2, 4) } @@ -77,54 +77,54 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, keys") { val ds = Seq(1, 2, 3, 4, 5).toDS() - val grouped = ds.groupBy(_ % 2) - checkAnswer( + val grouped = ds.groupByKey(_ % 2) + checkDataset( grouped.keys, 0, 1) } test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() - val grouped = ds.groupBy(_ % 2) + val grouped = ds.groupByKey(_ % 2) val agged = grouped.mapGroups { case (g, iter) => val name = if (g == 0) "even" else "odd" (name, iter.size) } - checkAnswer( + checkDataset( agged, ("even", 5), ("odd", 6)) } test("groupBy function, flatMap") { val ds = Seq("a", "b", "c", "xyz", "hello").toDS() - val grouped = ds.groupBy(_.length) + val grouped = ds.groupByKey(_.length) val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } - checkAnswer( + checkDataset( agged, "1", "abc", "3", "xyz", "5", "hello") } test("Arrays and Lists") { - checkAnswer(Seq(Seq(1)).toDS(), Seq(1)) - checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) - checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) - checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) - checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) - checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) - checkAnswer(Seq(Seq(true)).toDS(), Seq(true)) - checkAnswer(Seq(Seq("test")).toDS(), Seq("test")) - checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) - - checkAnswer(Seq(Array(1)).toDS(), Array(1)) - checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) - checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) - checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) - checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) - 
checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) - checkAnswer(Seq(Array(true)).toDS(), Array(true)) - checkAnswer(Seq(Array("test")).toDS(), Array("test")) - checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) + checkDataset(Seq(Seq(1)).toDS(), Seq(1)) + checkDataset(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong)) + checkDataset(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble)) + checkDataset(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat)) + checkDataset(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte)) + checkDataset(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort)) + checkDataset(Seq(Seq(true)).toDS(), Seq(true)) + checkDataset(Seq(Seq("test")).toDS(), Seq("test")) + checkDataset(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1))) + + checkDataset(Seq(Array(1)).toDS(), Array(1)) + checkDataset(Seq(Array(1.toLong)).toDS(), Array(1.toLong)) + checkDataset(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble)) + checkDataset(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat)) + checkDataset(Seq(Array(1.toByte)).toDS(), Array(1.toByte)) + checkDataset(Seq(Array(1.toShort)).toDS(), Array(1.toShort)) + checkDataset(Seq(Array(true)).toDS(), Array(true)) + checkDataset(Seq(Array("test")).toDS(), Array("test")) + checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 79e10215f4..9f32c8bf95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -34,14 +34,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("toDS") { val data = Seq(("a", 1), ("b", 2), ("c", 3)) - checkAnswer( + checkDataset( data.toDS(), data: _*) } test("toDS with RDD") { val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS() - checkAnswer( + checkDataset( ds.mapPartitions(_ => Iterator(1)), 1, 1, 1) } @@ -71,26 +71,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = data.toDS() assert(ds.repartition(10).rdd.partitions.length == 10) - checkAnswer( + checkDataset( ds.repartition(10), data: _*) assert(ds.coalesce(1).rdd.partitions.length == 1) - checkAnswer( + checkDataset( ds.coalesce(1), data: _*) } test("as tuple") { val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") - checkAnswer( + checkDataset( data.as[(String, Int)], ("a", 1), ("b", 2)) } test("as case class / collect") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] - checkAnswer( + checkDataset( ds, ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) assert(ds.collect().head == ClassData("a", 1)) @@ -108,7 +108,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("map") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.map(v => (v._1, v._2 + 1)), ("a", 2), ("b", 3), ("c", 4)) } @@ -116,7 +116,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("map with type change with the exact matched number of attributes") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.map(identity[(String, Int)]) .as[OtherTuple] .map(identity[OtherTuple]), @@ -126,7 +126,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("map with type change with less attributes") { val ds = Seq(("a", 1, 3), ("b", 2, 4), ("c", 3, 5)).toDS() - checkAnswer( + checkDataset( ds.as[OtherTuple] .map(identity[OtherTuple]), OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3)) @@ -137,23 
+137,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() .map(c => ClassData(c.a, c.b + 1)) - .groupBy(p => p).count() + .groupByKey(p => p).count() - checkAnswer( + checkDataset( ds, (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.select(expr("_2 + 1").as[Int]), 2, 3, 4) } test("select 2") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.select( expr("_1").as[String], expr("_2").as[Int]) : Dataset[(String, Int)], @@ -162,7 +162,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("select 2, primitive and tuple") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.select( expr("_1").as[String], expr("struct(_2, _2)").as[(Int, Int)]), @@ -171,7 +171,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("select 2, primitive and class") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.select( expr("_1").as[String], expr("named_struct('a', _1, 'b', _2)").as[ClassData]), @@ -189,7 +189,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("filter") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkAnswer( + checkDataset( ds.filter(_._1 == "b"), ("b", 2)) } @@ -217,7 +217,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds1 = Seq(1, 2, 3).toDS().as("a") val ds2 = Seq(1, 2).toDS().as("b") - checkAnswer( + checkDataset( ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), (1, 1), (2, 2)) } @@ -230,7 +230,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(("a", new Integer(1)), ("b", new Integer(2))).toDS() - checkAnswer( + checkDataset( ds1.joinWith(ds2, $"_1" === $"a", "outer"), (ClassNullableData("a", 1), ("a", new Integer(1))), (ClassNullableData("c", 3), (nullString, nullInteger)), @@ -241,7 +241,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() - checkAnswer( + checkDataset( ds1.joinWith(ds2, $"value" === $"_2"), (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))) } @@ -260,7 +260,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b") val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c") - checkAnswer( + checkDataset( ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"), ((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))) @@ -268,48 +268,48 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() - val grouped = ds.groupBy(v => (1, v._2)) - checkAnswer( + val grouped = ds.groupByKey(v => (1, v._2)) + checkDataset( grouped.keys, (1, 1)) } test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy(v => (v._1, "word")) + val grouped = ds.groupByKey(v => (v._1, "word")) val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } - checkAnswer( + checkDataset( agged, ("a", 30), ("b", 3), ("c", 1)) } test("groupBy function, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy(v => (v._1, "word")) + val 
grouped = ds.groupByKey(v => (v._1, "word")) val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } - checkAnswer( + checkDataset( agged, "a", "30", "b", "3", "c", "1") } test("groupBy function, reduce") { val ds = Seq("abc", "xyz", "hello").toDS() - val agged = ds.groupBy(_.length).reduce(_ + _) + val agged = ds.groupByKey(_.length).reduce(_ + _) - checkAnswer( + checkDataset( agged, 3 -> "abcxyz", 5 -> "hello") } test("groupBy single field class, count") { val ds = Seq("abc", "xyz", "hello").toDS() - val count = ds.groupBy(s => Tuple1(s.length)).count() + val count = ds.groupByKey(s => Tuple1(s.length)).count() - checkAnswer( + checkDataset( count, (Tuple1(3), 2L), (Tuple1(5), 1L) ) @@ -317,49 +317,49 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1") + val grouped = ds.groupByKey($"_1") val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } - checkAnswer( + checkDataset( agged, ("a", 30), ("b", 3), ("c", 1)) } test("groupBy columns, count") { val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() - val count = ds.groupBy($"_1").count() + val count = ds.groupByKey($"_1").count() - checkAnswer( + checkDataset( count, (Row("a"), 2L), (Row("b"), 1L)) } test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1").keyAs[String] + val grouped = ds.groupByKey($"_1").keyAs[String] val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - checkAnswer( + checkDataset( agged, ("a", 30), ("b", 3), ("c", 1)) } test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)] + val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)] val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - checkAnswer( + checkDataset( agged, (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) } test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] + val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData] val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } - checkAnswer( + checkDataset( agged, (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) } @@ -367,32 +367,32 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("typed aggregation: expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Long]), + checkDataset( + ds.groupByKey(_._1).agg(sum("_2").as[Long]), ("a", 30L), ("b", 3L), ("c", 1L)) } test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), + checkDataset( + ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) } test("typed aggregation: expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkAnswer( - ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), + checkDataset( + 
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) } test("typed aggregation: expr, expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkAnswer( - ds.groupBy(_._1).agg( + checkDataset( + ds.groupByKey(_._1).agg( sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*").as[Long], @@ -403,11 +403,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() - val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) => Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString)) } - checkAnswer( + checkDataset( cogrouped, 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } @@ -415,11 +415,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("cogroup with complex data") { val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS() val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS() - val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) => + val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) => Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) } - checkAnswer( + checkDataset( cogrouped, 1 -> "a", 2 -> "bc", 3 -> "d") } @@ -427,7 +427,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("sample with replacement") { val n = 100 val data = sparkContext.parallelize(1 to n, 2).toDS() - checkAnswer( + checkDataset( data.sample(withReplacement = true, 0.05, seed = 13), 5, 10, 52, 73) } @@ -435,7 +435,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("sample without replacement") { val n = 100 val data = sparkContext.parallelize(1 to n, 2).toDS() - checkAnswer( + checkDataset( data.sample(withReplacement = false, 0.05, seed = 13), 3, 17, 27, 58, 62) } @@ -445,13 +445,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds2 = Seq(2, 3).toDS().as("b") val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") - checkAnswer(joined, ("2", 2)) + checkDataset(joined, ("2", 2)) } test("self join") { val ds = Seq("1", "2").toDS().as("a") val joined = ds.joinWith(ds, lit(true)) - checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) + checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")) } test("toString") { @@ -477,7 +477,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() - assert(ds.groupBy(p => p).count().collect().toSet == + assert(ds.groupByKey(p => p).count().collect().toSet == Set((KryoData(1), 1L), (KryoData(2), 1L))) } @@ -496,7 +496,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() - assert(ds.groupBy(p => p).count().collect().toSeq == + assert(ds.groupByKey(p => p).count().collect().toSeq == Seq((JavaData(1), 1L), (JavaData(2), 1L))) } @@ -516,7 +516,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() val ds2 = 
Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() - checkAnswer( + checkDataset( ds1.joinWith(ds2, lit(true)), ((nullInt, "1"), (nullInt, "1")), ((new java.lang.Integer(22), "2"), (nullInt, "1")), @@ -550,7 +550,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] } - checkAnswer( + checkDataset( buildDataset(Row(Row("hello", 1))), NestedStruct(ClassData("hello", 1)) ) @@ -567,11 +567,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-12478: top level null field") { val ds0 = Seq(NestedStruct(null)).toDS() - checkAnswer(ds0, NestedStruct(null)) + checkDataset(ds0, NestedStruct(null)) checkAnswer(ds0.toDF(), Row(null)) val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS() - checkAnswer(ds1, DeepNestedStruct(NestedStruct(null))) + checkDataset(ds1, DeepNestedStruct(NestedStruct(null))) checkAnswer(ds1.toDF(), Row(Row(null))) } @@ -579,26 +579,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val outer = new OuterClass OuterScopes.addOuterScope(outer) val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS() - checkAnswer(ds.map(_.a), "1", "2") + checkDataset(ds.map(_.a), "1", "2") } test("grouping key and grouped value has field with same name") { val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS() - val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups { + val agged = ds.groupByKey(d => ClassNullableData(d.a, null)).mapGroups { case (key, values) => key.a + values.map(_.b).sum } - checkAnswer(agged, "a3") + checkDataset(agged, "a3") } test("cogroup's left and right side has field with same name") { val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS() val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS() - val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) { + val cogrouped = left.groupByKey(_.a).cogroup(right.groupByKey(_.a)) { case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum) } - checkAnswer(cogrouped, "a13", "b24") + checkDataset(cogrouped, "a13", "b24") } test("give nice error message when the real number of fields doesn't match encoder schema") { @@ -626,13 +626,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-13440: Resolving option fields") { val df = Seq(1, 2, 3).toDS() val ds = df.as[Option[Int]] - checkAnswer( + checkDataset( ds.filter(_ => true), Some(1), Some(2), Some(3)) } test("SPARK-13540 Dataset of nested class defined in Scala object") { - checkAnswer( + checkDataset( Seq(OuterObject.InnerClass("foo")).toDS(), OuterObject.InnerClass("foo")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index c05aa5486a..855295d5f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -72,7 +72,7 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. 
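Note: the QueryTest hunk that follows renames the typed helper to checkDataset so it no longer overloads checkAnswer; the former compares the decoded objects of type T, the latter compares Rows. A sketch of how a suite is expected to use the two (ClassData is the existing test fixture case class):

    val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()

    checkDataset(ds, ClassData("a", 1), ClassData("b", 2))        // compares decoded T values
    checkAnswer(ds.toDF(), Row("a", 1) :: Row("b", 2) :: Nil)     // compares untyped Rows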
*/ - protected def checkAnswer[T]( + protected def checkDataset[T]( ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( @@ -123,17 +123,17 @@ abstract class QueryTest extends PlanTest { protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { val analyzedDF = try df catch { case ae: AnalysisException => - val currentValue = sqlContext.conf.dataFrameEagerAnalysis - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - val partiallyAnalzyedPlan = df.queryExecution.analyzed - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) - fail( - s""" - |Failed to analyze query: $ae - |$partiallyAnalzyedPlan - | - |${stackTraceToString(ae)} - |""".stripMargin) + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } } checkJsonFormat(analyzedDF) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index bb5135826e..493a5a6437 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -65,9 +65,9 @@ import org.apache.spark.sql.execution.streaming._ trait StreamTest extends QueryTest with Timeouts { implicit class RichSource(s: Source) { - def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s)) + def toDF(): DataFrame = DataFrame(sqlContext, StreamingRelation(s)) - def toDS[A: Encoder](): Dataset[A] = new Dataset(sqlContext, StreamingRelation(s)) + def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s)) } /** How long to wait for an active stream to catch up when checking a result. */ @@ -168,10 +168,6 @@ trait StreamTest extends QueryTest with Timeouts { } } - /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */ - def testStream(stream: Dataset[_])(actions: StreamAction*): Unit = - testStream(stream.toDF())(actions: _*) - /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. @@ -179,7 +175,8 @@ trait StreamTest extends QueryTest with Timeouts { * Note that if the stream is not explicitly started before an action that requires it to be * running then it will be automatically started before performing any other actions. */ - def testStream(stream: DataFrame)(actions: StreamAction*): Unit = { + def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = { + val stream = _stream.toDF() var pos = 0 var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null |
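Note: because package.scala now defines `type DataFrame = Dataset[Row]`, helpers like testStream no longer need separate DataFrame and Dataset overloads; a single `Dataset[_]` entry point can simply call `toDF()` internally. From the user's side the alias is transparent:

    val df: DataFrame = Seq((1, "a"), (2, "b")).toDF("id", "name")   // a DataFrame is just a Dataset[Row]
    val ds: Dataset[(Int, String)] = df.as[(Int, String)]            // typed view over the same plan
    ds.toDF()                                                        // and back to the untyped alias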