aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-03-22 23:43:09 -0700
committerReynold Xin <rxin@databricks.com>2016-03-22 23:43:09 -0700
commit926a93e54b83f1ee596096f3301fef015705b627 (patch)
tree97817dcf1069bcc8f148f996873bef5bb6643126
parent1a22cf1e9b6447005c9a329856d734d80a496a06 (diff)
downloadspark-926a93e54b83f1ee596096f3301fef015705b627.tar.gz
spark-926a93e54b83f1ee596096f3301fef015705b627.tar.bz2
spark-926a93e54b83f1ee596096f3301fef015705b627.zip
[SPARK-14088][SQL] Some Dataset API touch-up
## What changes were proposed in this pull request? 1. Deprecated unionAll. It is pretty confusing to have both "union" and "unionAll" when the two do the same thing in Spark but are different in SQL. 2. Rename reduce in KeyValueGroupedDataset to reduceGroups so it is more consistent with rest of the functions in KeyValueGroupedDataset. Also makes it more obvious what "reduce" and "reduceGroups" mean. Previously it was confusing because it could be reducing a Dataset, or just reducing groups. 3. Added a "name" function, which is more natural to name columns than "as" for non-SQL users. 4. Remove "subtract" function since it is just an alias for "except". ## How was this patch tested? All changes should be covered by existing tests. Also added couple test cases to cover "name". Author: Reynold Xin <rxin@databricks.com> Closes #11908 from rxin/SPARK-14088.
-rw-r--r--project/MimaExcludes.scala1
-rw-r--r--python/pyspark/sql/column.py2
-rw-r--r--python/pyspark/sql/dataframe.py14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala29
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala30
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala11
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala2
9 files changed, 56 insertions, 40 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 68e9c50d60..42eafcb0f5 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -317,6 +317,7 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 19ec6fcc5d..43e9baece2 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,6 +315,8 @@ class Column(object):
sc = SparkContext._active_spark_context
return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
+ name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")
+
@ignore_unicode_prefix
@since(1.3)
def cast(self, dataType):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7e1854c43b..5cfc348a69 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -911,14 +911,24 @@ class DataFrame(object):
"""
return self.groupBy().agg(*exprs)
+ @since(2.0)
+ def union(self, other):
+ """ Return a new :class:`DataFrame` containing union of rows in this
+ frame and another frame.
+
+ This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union
+ (that does deduplication of elements), use this function followed by a distinct.
+ """
+ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+
@since(1.3)
def unionAll(self, other):
""" Return a new :class:`DataFrame` containing union of rows in this
frame and another frame.
- This is equivalent to `UNION ALL` in SQL.
+ .. note:: Deprecated in 2.0, use union instead.
"""
- return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+ return self.union(other)
@since(1.3)
def intersect(self, other):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 622a62abad..d64736e111 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -856,7 +856,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.4.0
*/
- def alias(alias: String): Column = as(alias)
+ def alias(alias: String): Column = name(alias)
/**
* Gives the column an alias.
@@ -871,12 +871,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.3.0
*/
- def as(alias: String): Column = withExpr {
- expr match {
- case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
- case other => Alias(other, alias)()
- }
- }
+ def as(alias: String): Column = name(alias)
/**
* (Scala-specific) Assigns the given aliases to the results of a table generating function.
@@ -937,6 +932,26 @@ class Column(protected[sql] val expr: Expression) extends Logging {
}
/**
+ * Gives the column a name (alias).
+ * {{{
+ * // Renames colA to colB in select output.
+ * df.select($"colA".name("colB"))
+ * }}}
+ *
+ * If the current column has metadata associated with it, this metadata will be propagated
+ * to the new column. If this not desired, use `as` with explicitly empty metadata.
+ *
+ * @group expr_ops
+ * @since 2.0.0
+ */
+ def name(alias: String): Column = withExpr {
+ expr match {
+ case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
+ case other => Alias(other, alias)()
+ }
+ }
+
+ /**
* Casts the column to a different data type.
* {{{
* // Casts colA to IntegerType.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index be0dfe7c33..31864d63ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1350,20 +1350,24 @@ class Dataset[T] private[sql](
* @group typedrel
* @since 2.0.0
*/
- def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan {
- // This breaks caching, but it's usually ok because it addresses a very specific use case:
- // using union to union many files or partitions.
- CombineUnions(Union(logicalPlan, other.logicalPlan))
- }
+ @deprecated("use union()", "2.0.0")
+ def unionAll(other: Dataset[T]): Dataset[T] = union(other)
/**
* Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset.
* This is equivalent to `UNION ALL` in SQL.
*
+ * To do a SQL-style set union (that does deduplication of elements), use this function followed
+ * by a [[distinct]].
+ *
* @group typedrel
* @since 2.0.0
*/
- def union(other: Dataset[T]): Dataset[T] = unionAll(other)
+ def union(other: Dataset[T]): Dataset[T] = withTypedPlan {
+ // This breaks caching, but it's usually ok because it addresses a very specific use case:
+ // using union to union many files or partitions.
+ CombineUnions(Union(logicalPlan, other.logicalPlan))
+ }
/**
* Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset.
@@ -1394,18 +1398,6 @@ class Dataset[T] private[sql](
}
/**
- * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset.
- * This is equivalent to `EXCEPT` in SQL.
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- *
- * @group typedrel
- * @since 2.0.0
- */
- def subtract(other: Dataset[T]): Dataset[T] = except(other)
-
- /**
* Returns a new [[Dataset]] by sampling a fraction of rows.
*
* @param withReplacement Sample with replacement or not.
@@ -1756,7 +1748,7 @@ class Dataset[T] private[sql](
outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
}
- val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+ val row = groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
// Pivot the data so each summary is one row
row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index f0f96825e2..8bb75bf2bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -190,7 +190,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
*
* @since 1.6.0
*/
- def reduce(f: (V, V) => V): Dataset[(K, V)] = {
+ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder)
@@ -203,15 +203,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
*
* @since 1.6.0
*/
- def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = {
- reduce(f.call _)
+ def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
+ reduceGroups(f.call _)
}
- // This is here to prevent us from adding overloads that would be ambiguous.
- @scala.annotation.varargs
- private def agg(exprs: Column*): DataFrame =
- groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)
-
private def withEncoder(c: Column): Column = c match {
case tc: TypedColumn[_, _] =>
tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 3bff129ae2..18f17a85a9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -204,7 +204,7 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
- Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
+ Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
@Override
public String call(String v1, String v2) throws Exception {
return v1 + v2;
@@ -300,7 +300,7 @@ public class JavaDatasetSuite implements Serializable {
Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo", "abc", "abc", "xyz"),
unioned.collectAsList());
- Dataset<String> subtracted = ds.subtract(ds2);
+ Dataset<String> subtracted = ds.except(ds2);
Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index c2434e46f7..351b03b38b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -105,10 +105,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
Row("a") :: Nil)
}
- test("alias") {
+ test("alias and name") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
assert(df.select(df("a").as("b")).columns.head === "b")
assert(df.select(df("a").alias("b")).columns.head === "b")
+ assert(df.select(df("a").name("b")).columns.head === "b")
}
test("as propagates metadata") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 677f84eb60..0bcc512d71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -305,7 +305,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy function, reduce") {
val ds = Seq("abc", "xyz", "hello").toDS()
- val agged = ds.groupByKey(_.length).reduce(_ + _)
+ val agged = ds.groupByKey(_.length).reduceGroups(_ + _)
checkDataset(
agged,