aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-01-29 17:24:00 -0800
committerReynold Xin <rxin@databricks.com>2015-01-29 17:24:00 -0800
commitce9c43ba8ca1ba6507fd3bf3c647ab7396d33653 (patch)
tree2158cd36fc258300601aa1c9b6768a47f188e810 /sql
parentd2071e8f45e74117f78a42770b0c610cb98e5075 (diff)
downloadspark-ce9c43ba8ca1ba6507fd3bf3c647ab7396d33653.tar.gz
spark-ce9c43ba8ca1ba6507fd3bf3c647ab7396d33653.tar.bz2
spark-ce9c43ba8ca1ba6507fd3bf3c647ab7396d33653.zip
[SQL] DataFrame API improvements
1. Added Dsl.column in case Dsl.col is shadowed. 2. Allow using String to specify the target data type in cast. 3. Support sorting on multiple columns using column names. 4. Added Java API test file. Author: Reynold Xin <rxin@databricks.com> Closes #4280 from rxin/dsl1 and squashes the following commits: 33ecb7a [Reynold Xin] Add the Java test. d06540a [Reynold Xin] [SQL] DataFrame API improvements.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala35
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala26
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala27
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/api.scala11
-rw-r--r--sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java120
6 files changed, 209 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ca50fd6f05..68c9cb0c02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -56,7 +56,7 @@ object Column {
class Column(
sqlContext: Option[SQLContext],
plan: Option[LogicalPlan],
- val expr: Expression)
+ protected[sql] val expr: Expression)
extends DataFrame(sqlContext, plan) with ExpressionApi {
/** Turns a Catalyst expression into a `Column`. */
@@ -437,9 +437,7 @@ class Column(
override def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
/**
- * An expression that gets an
- * @param ordinal
- * @return
+ * An expression that gets an item at position `ordinal` out of an array.
*/
override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
@@ -490,11 +488,38 @@ class Column(
* {{{
* // Casts colA to IntegerType.
* import org.apache.spark.sql.types.IntegerType
- * df.select(df("colA").as(IntegerType))
+ * df.select(df("colA").cast(IntegerType))
+ *
+ * // equivalent to
+ * df.select(df("colA").cast("int"))
* }}}
*/
override def cast(to: DataType): Column = Cast(expr, to)
+ /**
+ * Casts the column to a different data type, using the canonical string representation
+ * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`,
+ * `float`, `double`, `decimal`, `date`, `timestamp`.
+ * {{{
+ * // Casts colA to integer.
+ * df.select(df("colA").cast("int"))
+ * }}}
+ */
+ override def cast(to: String): Column = Cast(expr, to.toLowerCase match {
+ case "string" => StringType
+ case "boolean" => BooleanType
+ case "byte" => ByteType
+ case "short" => ShortType
+ case "int" => IntegerType
+ case "long" => LongType
+ case "float" => FloatType
+ case "double" => DoubleType
+ case "decimal" => DecimalType.Unlimited
+ case "date" => DateType
+ case "timestamp" => TimestampType
+ case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+ })
+
override def desc: Column = SortOrder(expr, Descending)
override def asc: Column = SortOrder(expr, Ascending)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 94c13a5c26..1ff25adcf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -208,7 +208,7 @@ class DataFrame protected[sql](
}
/**
- * Returns a new [[DataFrame]] sorted by the specified column, in ascending column.
+ * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
* {{{
* // The following 3 are equivalent
* df.sort("sortcol")
@@ -216,8 +216,9 @@ class DataFrame protected[sql](
* df.sort($"sortcol".asc)
* }}}
*/
- override def sort(colName: String): DataFrame = {
- Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan)
+ @scala.annotation.varargs
+ override def sort(sortCol: String, sortCols: String*): DataFrame = {
+ orderBy(apply(sortCol), sortCols.map(apply) :_*)
}
/**
@@ -244,6 +245,15 @@ class DataFrame protected[sql](
* This is an alias of the `sort` function.
*/
@scala.annotation.varargs
+ override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
+ sort(sortCol, sortCols :_*)
+ }
+
+ /**
+ * Returns a new [[DataFrame]] sorted by the given expressions.
+ * This is an alias of the `sort` function.
+ */
+ @scala.annotation.varargs
override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
sort(sortExpr, sortExprs :_*)
}
@@ -405,6 +415,16 @@ class DataFrame protected[sql](
* Aggregates on the entire [[DataFrame]] without groups.
* {{
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(Map("age" -> "max", "salary" -> "avg"))
+ * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+ * }}
+ */
+ override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap)
+
+ /**
+ * Aggregates on the entire [[DataFrame]] without groups.
+ * {{
+ * // df.agg(...) is a shorthand for df.groupBy().agg(...)
* df.agg(max($"age"), avg($"salary"))
* df.groupBy().agg(max($"age"), avg($"salary"))
* }}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index f47ff995e9..75717e7cd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -63,6 +63,11 @@ object Dsl {
def col(colName: String): Column = new Column(colName)
/**
+ * Returns a [[Column]] based on the given column name. Alias of [[col]].
+ */
+ def column(colName: String): Column = new Column(colName)
+
+ /**
* Creates a [[Column]] of literal value.
*/
def lit(literal: Any): Column = {
@@ -96,6 +101,7 @@ object Dsl {
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
def count(e: Column): Column = Count(e.expr)
+ @scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 1f1e9bd989..1c948cbbfe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -58,7 +58,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods.
+ * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+ * [[DataFrame]] will also contain the grouping columns.
+ *
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
@@ -76,7 +78,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods.
+ * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+ * [[DataFrame]] will also contain the grouping columns.
+ *
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
@@ -91,12 +95,15 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
}
/**
- * Compute aggregates by specifying a series of aggregate columns.
- * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
+ * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
+ * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+ *
+ * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+ *
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* import org.apache.spark.sql.dsl._
- * df.groupBy("department").agg(max($"age"), sum($"expense"))
+ * df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
* }}}
*/
@scala.annotation.varargs
@@ -109,31 +116,39 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
}
- /** Count the number of rows for each group. */
+ /**
+ * Count the number of rows for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ */
override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def mean(): DataFrame = aggregateNumericColumns(Average)
/**
* Compute the max value for each numeric columns for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def max(): DataFrame = aggregateNumericColumns(Max)
/**
* Compute the mean value for each numeric columns for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def avg(): DataFrame = aggregateNumericColumns(Average)
/**
* Compute the min value for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def min(): DataFrame = aggregateNumericColumns(Min)
/**
* Compute the sum for each numeric columns for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def sum(): DataFrame = aggregateNumericColumns(Sum)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 59634082f6..eb0eb3f325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -113,16 +113,22 @@ private[sql] trait DataFrameSpecificApi {
def agg(exprs: Map[String, String]): DataFrame
+ def agg(exprs: java.util.Map[String, String]): DataFrame
+
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame
- def sort(colName: String): DataFrame
+ @scala.annotation.varargs
+ def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+
+ @scala.annotation.varargs
+ def sort(sortCol: String, sortCols: String*): DataFrame
@scala.annotation.varargs
def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
@scala.annotation.varargs
- def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+ def orderBy(sortCol: String, sortCols: String*): DataFrame
def join(right: DataFrame): DataFrame
@@ -257,6 +263,7 @@ private[sql] trait ExpressionApi {
def getField(fieldName: String): Column
def cast(to: DataType): Column
+ def cast(to: String): Column
def asc: Column
def desc: Column
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
new file mode 100644
index 0000000000..639436368c
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.types.DataTypes;
+
+import static org.apache.spark.sql.Dsl.*;
+
+/**
+ * This test doesn't actually run anything. It is here to check the API compatibility for Java.
+ */
+public class JavaDsl {
+
+ public static void testDataFrame(final DataFrame df) {
+ DataFrame df1 = df.select("colA");
+ df1 = df.select("colA", "colB");
+
+ df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1));
+
+ df1 = df.filter(col("colA"));
+
+ java.util.Map<String, String> aggExprs = ImmutableMap.<String, String>builder()
+ .put("colA", "sum")
+ .put("colB", "avg")
+ .build();
+
+ df1 = df.agg(aggExprs);
+
+ df1 = df.groupBy("groupCol").agg(aggExprs);
+
+ df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer");
+
+ df.orderBy("colA");
+ df.orderBy("colA", "colB", "colC");
+ df.orderBy(col("colA").desc());
+ df.orderBy(col("colA").desc(), col("colB").asc());
+
+ df.sort("colA");
+ df.sort("colA", "colB", "colC");
+ df.sort(col("colA").desc());
+ df.sort(col("colA").desc(), col("colB").asc());
+
+ df.as("b");
+
+ df.limit(5);
+
+ df.unionAll(df1);
+ df.intersect(df1);
+ df.except(df1);
+
+ df.sample(true, 0.1, 234);
+
+ df.head();
+ df.head(5);
+ df.first();
+ df.count();
+ }
+
+ public static void testColumn(final Column c) {
+ c.asc();
+ c.desc();
+
+ c.endsWith("abcd");
+ c.startsWith("afgasdf");
+
+ c.like("asdf%");
+ c.rlike("wef%asdf");
+
+ c.as("newcol");
+
+ c.cast("int");
+ c.cast(DataTypes.IntegerType);
+ }
+
+ public static void testDsl() {
+ // Creating a column.
+ Column c = col("abcd");
+ Column c1 = column("abcd");
+
+ // Literals
+ Column l1 = lit(1);
+ Column l2 = lit(1.0);
+ Column l3 = lit("abcd");
+
+ // Functions
+ Column a = upper(c);
+ a = lower(c);
+ a = sqrt(c);
+ a = abs(c);
+
+ // Aggregates
+ a = min(c);
+ a = max(c);
+ a = sum(c);
+ a = sumDistinct(c);
+ a = countDistinct(c, a);
+ a = avg(c);
+ a = first(c);
+ a = last(c);
+ }
+}