Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala            |  35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala         |  26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala               |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala  |  27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api.scala               |  11
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java    | 120
6 files changed, 209 insertions(+), 16 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ca50fd6f05..68c9cb0c02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -56,7 +56,7 @@ object Column {
class Column(
sqlContext: Option[SQLContext],
plan: Option[LogicalPlan],
- val expr: Expression)
+ protected[sql] val expr: Expression)
extends DataFrame(sqlContext, plan) with ExpressionApi {
/** Turns a Catalyst expression into a `Column`. */
@@ -437,9 +437,7 @@ class Column(
override def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
/**
- * An expression that gets an
- * @param ordinal
- * @return
+ * An expression that gets an item at position `ordinal` out of an array.
*/
override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
@@ -490,11 +488,38 @@ class Column(
* {{{
* // Casts colA to IntegerType.
* import org.apache.spark.sql.types.IntegerType
- * df.select(df("colA").as(IntegerType))
+ * df.select(df("colA").cast(IntegerType))
+ *
+ * // equivalent to
+ * df.select(df("colA").cast("int"))
* }}}
*/
override def cast(to: DataType): Column = Cast(expr, to)
+ /**
+ * Casts the column to a different data type, using the canonical string representation
+ * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`,
+ * `float`, `double`, `decimal`, `date`, `timestamp`.
+ * {{{
+ * // Casts colA to integer.
+ * df.select(df("colA").cast("int"))
+ * }}}
+ */
+ override def cast(to: String): Column = Cast(expr, to.toLowerCase match {
+ case "string" => StringType
+ case "boolean" => BooleanType
+ case "byte" => ByteType
+ case "short" => ShortType
+ case "int" => IntegerType
+ case "long" => LongType
+ case "float" => FloatType
+ case "double" => DoubleType
+ case "decimal" => DecimalType.Unlimited
+ case "date" => DateType
+ case "timestamp" => TimestampType
+ case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+ })
+
override def desc: Column = SortOrder(expr, Descending)
override def asc: Column = SortOrder(expr, Ascending)
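For reference, a minimal usage sketch of the two cast variants above; `df` and the column name "colA" are hypothetical:

  import org.apache.spark.sql.types.IntegerType

  // Cast via a DataType object.
  df.select(df("colA").cast(IntegerType))
  // Equivalent cast via the type's canonical string name; an unsupported
  // name (e.g. "integer") throws a RuntimeException as soon as cast is called.
  df.select(df("colA").cast("int"))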
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 94c13a5c26..1ff25adcf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -208,7 +208,7 @@ class DataFrame protected[sql](
}
/**
- * Returns a new [[DataFrame]] sorted by the specified column, in ascending column.
+ * Returns a new [[DataFrame]] sorted by the specified columns, all in ascending order.
* {{{
* // The following 3 are equivalent
* df.sort("sortcol")
@@ -216,8 +216,9 @@ class DataFrame protected[sql](
* df.sort($"sortcol".asc)
* }}}
*/
- override def sort(colName: String): DataFrame = {
- Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan)
+ @scala.annotation.varargs
+ override def sort(sortCol: String, sortCols: String*): DataFrame = {
+ orderBy(apply(sortCol), sortCols.map(apply) :_*)
}
/**
@@ -244,6 +245,15 @@ class DataFrame protected[sql](
* This is an alias of the `sort` function.
*/
@scala.annotation.varargs
+ override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
+ sort(sortCol, sortCols :_*)
+ }
+
+ /**
+ * Returns a new [[DataFrame]] sorted by the given expressions.
+ * This is an alias of the `sort` function.
+ */
+ @scala.annotation.varargs
override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
sort(sortExpr, sortExprs :_*)
}
@@ -405,6 +415,16 @@ class DataFrame protected[sql](
* Aggregates on the entire [[DataFrame]] without groups.
* {{{
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
+ * df.agg(Map("age" -> "max", "salary" -> "avg"))
+ * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+ * }}}
+ */
+ override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap)
+
+ /**
+ * Aggregates on the entire [[DataFrame]] without groups.
+ * {{{
+ * // df.agg(...) is a shorthand for df.groupBy().agg(...)
* df.agg(max($"age"), avg($"salary"))
* df.groupBy().agg(max($"age"), avg($"salary"))
* }}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index f47ff995e9..75717e7cd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -63,6 +63,11 @@ object Dsl {
def col(colName: String): Column = new Column(colName)
/**
+ * Returns a [[Column]] based on the given column name. Alias of [[col]].
+ */
+ def column(colName: String): Column = new Column(colName)
+
+ /**
* Creates a [[Column]] of literal value.
*/
def lit(literal: Any): Column = {
@@ -96,6 +101,7 @@ object Dsl {
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
def count(e: Column): Column = Count(e.expr)
+ @scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
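To illustrate the Dsl additions, a small sketch (column names assumed):

  import org.apache.spark.sql.Dsl._

  val c1 = col("name")      // existing helper
  val c2 = column("name")   // new alias, identical result
  // countDistinct now carries @scala.annotation.varargs, so Java callers can
  // pass multiple columns directly; the Scala call site is unchanged:
  val d = countDistinct(col("a"), col("b"))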
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 1f1e9bd989..1c948cbbfe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -58,7 +58,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods.
+ * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+ * [[DataFrame]] will also contain the grouping columns.
+ *
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
@@ -76,7 +78,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods.
+ * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+ * [[DataFrame]] will also contain the grouping columns.
+ *
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
@@ -91,12 +95,15 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
}
/**
- * Compute aggregates by specifying a series of aggregate columns.
- * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
+ * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
+ * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+ *
+ * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+ *
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* import org.apache.spark.sql.Dsl._
- * df.groupBy("department").agg(max($"age"), sum($"expense"))
+ * df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
* }}}
*/
@scala.annotation.varargs
@@ -109,31 +116,39 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
}
- /** Count the number of rows for each group. */
+ /**
+ * Count the number of rows for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
+ */
override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
/**
* Compute the average value for each numeric column for each group. This is an alias for `avg`.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def mean(): DataFrame = aggregateNumericColumns(Average)
/**
* Compute the max value for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def max(): DataFrame = aggregateNumericColumns(Max)
/**
* Compute the mean value for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def avg(): DataFrame = aggregateNumericColumns(Average)
/**
* Compute the min value for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def min(): DataFrame = aggregateNumericColumns(Min)
/**
* Compute the sum for each numeric column for each group.
+ * The resulting [[DataFrame]] will also contain the grouping columns.
*/
override def sum(): DataFrame = aggregateNumericColumns(Sum)
}
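A sketch contrasting the two agg flavors documented above, assuming columns "department", "age" and "expense":

  import org.apache.spark.sql.Dsl._

  // Map-based agg: the grouping column "department" is included in the
  // result automatically.
  df.groupBy("department").agg(Map("age" -> "max", "expense" -> "sum"))

  // Column-based agg: grouping columns must be selected explicitly.
  df.groupBy("department").agg($"department", max($"age"), sum($"expense"))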
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 59634082f6..eb0eb3f325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -113,16 +113,22 @@ private[sql] trait DataFrameSpecificApi {
def agg(exprs: Map[String, String]): DataFrame
+ def agg(exprs: java.util.Map[String, String]): DataFrame
+
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame
- def sort(colName: String): DataFrame
+ @scala.annotation.varargs
+ def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+
+ @scala.annotation.varargs
+ def sort(sortCol: String, sortCols: String*): DataFrame
@scala.annotation.varargs
def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
@scala.annotation.varargs
- def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+ def orderBy(sortCol: String, sortCols: String*): DataFrame
def join(right: DataFrame): DataFrame
@@ -257,6 +263,7 @@ private[sql] trait ExpressionApi {
def getField(fieldName: String): Column
def cast(to: DataType): Column
+ def cast(to: String): Column
def asc: Column
def desc: Column
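The @scala.annotation.varargs annotations in this trait are there for Java interop: scalac additionally emits an overload taking an array, so Java callers get native varargs. A minimal sketch of the effect, using the trait's sort method:

  // Scala call sites are unchanged:
  df.sort("a", "b", "c")
  // Java call sites can now write df.sort("a", "b", "c"); as well, instead of
  // constructing a Seq[String] by hand.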
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
new file mode 100644
index 0000000000..639436368c
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.types.DataTypes;
+
+import static org.apache.spark.sql.Dsl.*;
+
+/**
+ * This test doesn't actually run anything. It is here only to check API compatibility for Java.
+ */
+public class JavaDsl {
+
+ public static void testDataFrame(final DataFrame df) {
+ DataFrame df1 = df.select("colA");
+ df1 = df.select("colA", "colB");
+
+ df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1));
+
+ df1 = df.filter(col("colA"));
+
+ java.util.Map<String, String> aggExprs = ImmutableMap.<String, String>builder()
+ .put("colA", "sum")
+ .put("colB", "avg")
+ .build();
+
+ df1 = df.agg(aggExprs);
+
+ df1 = df.groupBy("groupCol").agg(aggExprs);
+
+ df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer");
+
+ df.orderBy("colA");
+ df.orderBy("colA", "colB", "colC");
+ df.orderBy(col("colA").desc());
+ df.orderBy(col("colA").desc(), col("colB").asc());
+
+ df.sort("colA");
+ df.sort("colA", "colB", "colC");
+ df.sort(col("colA").desc());
+ df.sort(col("colA").desc(), col("colB").asc());
+
+ df.as("b");
+
+ df.limit(5);
+
+ df.unionAll(df1);
+ df.intersect(df1);
+ df.except(df1);
+
+ df.sample(true, 0.1, 234);
+
+ df.head();
+ df.head(5);
+ df.first();
+ df.count();
+ }
+
+ public static void testColumn(final Column c) {
+ c.asc();
+ c.desc();
+
+ c.endsWith("abcd");
+ c.startsWith("afgasdf");
+
+ c.like("asdf%");
+ c.rlike("wef%asdf");
+
+ c.as("newcol");
+
+ c.cast("int");
+ c.cast(DataTypes.IntegerType);
+ }
+
+ public static void testDsl() {
+ // Creating a column.
+ Column c = col("abcd");
+ Column c1 = column("abcd");
+
+ // Literals
+ Column l1 = lit(1);
+ Column l2 = lit(1.0);
+ Column l3 = lit("abcd");
+
+ // Functions
+ Column a = upper(c);
+ a = lower(c);
+ a = sqrt(c);
+ a = abs(c);
+
+ // Aggregates
+ a = min(c);
+ a = max(c);
+ a = sum(c);
+ a = sumDistinct(c);
+ a = countDistinct(c, a);
+ a = avg(c);
+ a = first(c);
+ a = last(c);
+ }
+}