diff options
9 files changed, 56 insertions, 40 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 68e9c50d60..42eafcb0f5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -317,6 +317,7 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 19ec6fcc5d..43e9baece2 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -315,6 +315,8 @@ class Column(object): sc = SparkContext._active_spark_context return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.") + @ignore_unicode_prefix @since(1.3) def cast(self, dataType): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7e1854c43b..5cfc348a69 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -911,14 +911,24 @@ class DataFrame(object): """ return self.groupBy().agg(*exprs) + @since(2.0) + def union(self, other): + """ Return a new :class:`DataFrame` containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union + (that does deduplication of elements), use this function followed by a distinct. + """ + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + @since(1.3) def unionAll(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. - This is equivalent to `UNION ALL` in SQL. + .. note:: Deprecated in 2.0, use union instead. """ - return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + return self.union(other) @since(1.3) def intersect(self, other): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 622a62abad..d64736e111 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -856,7 +856,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def alias(alias: String): Column = as(alias) + def alias(alias: String): Column = name(alias) /** * Gives the column an alias. @@ -871,12 +871,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = withExpr { - expr match { - case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias)() - } - } + def as(alias: String): Column = name(alias) /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -937,6 +932,26 @@ class Column(protected[sql] val expr: Expression) extends Logging { } /** + * Gives the column a name (alias). + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".name("colB")) + * }}} + * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * + * @group expr_ops + * @since 2.0.0 + */ + def name(alias: String): Column = withExpr { + expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } + } + + /** * Casts the column to a different data type. * {{{ * // Casts colA to IntegerType. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index be0dfe7c33..31864d63ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1350,20 +1350,24 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan { - // This breaks caching, but it's usually ok because it addresses a very specific use case: - // using union to union many files or partitions. - CombineUnions(Union(logicalPlan, other.logicalPlan)) - } + @deprecated("use union()", "2.0.0") + def unionAll(other: Dataset[T]): Dataset[T] = union(other) /** * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. * This is equivalent to `UNION ALL` in SQL. * + * To do a SQL-style set union (that does deduplication of elements), use this function followed + * by a [[distinct]]. + * * @group typedrel * @since 2.0.0 */ - def union(other: Dataset[T]): Dataset[T] = unionAll(other) + def union(other: Dataset[T]): Dataset[T] = withTypedPlan { + // This breaks caching, but it's usually ok because it addresses a very specific use case: + // using union to union many files or partitions. + CombineUnions(Union(logicalPlan, other.logicalPlan)) + } /** * Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset. @@ -1394,18 +1398,6 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * - * @group typedrel - * @since 2.0.0 - */ - def subtract(other: Dataset[T]): Dataset[T] = except(other) - - /** * Returns a new [[Dataset]] by sampling a fraction of rows. * * @param withReplacement Sample with replacement or not. @@ -1756,7 +1748,7 @@ class Dataset[T] private[sql]( outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } - val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq + val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index f0f96825e2..8bb75bf2bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -190,7 +190,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def reduce(f: (V, V) => V): Dataset[(K, V)] = { + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) @@ -203,15 +203,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduce(f.call _) + def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { + reduceGroups(f.call _) } - // This is here to prevent us from adding overloads that would be ambiguous. - @scala.annotation.varargs - private def agg(exprs: Column*): DataFrame = - groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) - private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3bff129ae2..18f17a85a9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -204,7 +204,7 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); - Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() { + Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() { @Override public String call(String v1, String v2) throws Exception { return v1 + v2; @@ -300,7 +300,7 @@ public class JavaDatasetSuite implements Serializable { Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"), unioned.collectAsList()); - Dataset<String> subtracted = ds.subtract(ds2); + Dataset<String> subtracted = ds.except(ds2); Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index c2434e46f7..351b03b38b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -105,10 +105,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row("a") :: Nil) } - test("alias") { + test("alias and name") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") assert(df.select(df("a").alias("b")).columns.head === "b") + assert(df.select(df("a").name("b")).columns.head === "b") } test("as propagates metadata") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 677f84eb60..0bcc512d71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -305,7 +305,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, reduce") { val ds = Seq("abc", "xyz", "hello").toDS() - val agged = ds.groupByKey(_.length).reduce(_ + _) + val agged = ds.groupByKey(_.length).reduceGroups(_ + _) checkDataset( agged, |