author    Reynold Xin <rxin@databricks.com>  2015-02-03 10:34:56 -0800
committer Reynold Xin <rxin@databricks.com>  2015-02-03 10:34:56 -0800
commit    4204a1271d5bff4dd64f46eed9ee80b30081f9dc (patch)
tree      966ff6c5e8a0065b73359207ab6b61038668758c
parent    f7948f3f5718b7c4a2d35634815670c4cbbe70fd (diff)
[SQL] DataFrame API update
1. Added Java-friendly versions of the expression operators (e.g. gt, geq).
2. Added JavaDoc for most operators.
3. Simplified the expression operators by having only one version of each function, accepting Any. Previously each expression operator had two methods, one accepting Any and another accepting Column.
4. The agg function now accepts varargs of (String, String).

Author: Reynold Xin <rxin@databricks.com>

Closes #4332 from rxin/df-update and squashes the following commits:

ab0aa69 [Reynold Xin] Added Java-friendly expression methods. Added JavaDoc. For each expression operator, have only one version of the function (that accepts Any). Previously we had two methods for each expression operator, one accepting Any and another accepting Column.
576d07a [Reynold Xin] random commit.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala             | 439
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala          |  45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala      |  19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala                |  87
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala   |  46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala |  10
6 files changed, 376 insertions(+), 270 deletions(-)
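Taken together, the user-visible surface of this patch looks roughly like the following. A minimal Scala sketch, assuming a DataFrame `people` with illustrative columns `age`, `dept`, and `expense`:

    import org.apache.spark.sql.Dsl._

    // Each operator now has a single overload accepting Any, so literals
    // no longer need an explicit lit():
    people.filter(people("age") > 21)      // Scala operator form
    people.filter(people("age").gt(21))    // Java-friendly alias added in this patch

    // agg now also accepts varargs of (String, String) pairs:
    people.groupBy("dept").agg("age" -> "max", "expense" -> "sum")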
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 0d6055ff23..4aa37219e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -65,7 +65,15 @@ trait Column extends DataFrame {
*/
def isComputable: Boolean
- private def constructColumn(other: Column)(newExpr: Expression): Column = {
+ private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
+ val plan = Project(Seq(expr match {
+ case named: NamedExpression => named
+ case unnamed: Expression => Alias(unnamed, "col")()
+ }), baseCol.plan)
+ Column(baseCol.sqlContext, plan, expr)
+ }
+
+ private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
// Removes all the top level projection and subquery so we can get to the underlying plan.
@tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
case Project(_, child) => stripProject(child)
@@ -73,392 +81,423 @@ trait Column extends DataFrame {
case _ => p
}
- def computableCol(baseCol: ComputableColumn, expr: Expression) = {
- val plan = Project(Seq(expr match {
- case named: NamedExpression => named
- case unnamed: Expression => Alias(unnamed, "col")()
- }), baseCol.plan)
- Column(baseCol.sqlContext, plan, expr)
- }
-
- (this, other) match {
+ (this, lit(otherValue)) match {
case (left: ComputableColumn, right: ComputableColumn) =>
if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
- computableCol(right, newExpr)
+ computableCol(right, newExpr(right))
} else {
- Column(newExpr)
+ Column(newExpr(right))
}
- case (left: ComputableColumn, _) => computableCol(left, newExpr)
- case (_, right: ComputableColumn) => computableCol(right, newExpr)
- case (_, _) => Column(newExpr)
+ case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
+ case (_, right: ComputableColumn) => computableCol(right, newExpr(right))
+ case (_, right) => Column(newExpr(right))
+ }
+ }
+
+ /** Creates a column based on the given expression. */
+ private def exprToColumn(newExpr: Expression, computable: Boolean = true): Column = {
+ this match {
+ case c: ComputableColumn if computable => computableCol(c, newExpr)
+ case _ => Column(newExpr)
}
}
/**
* Unary minus, i.e. negate the expression.
* {{{
- * // Select the amount column and negates all values.
+ * // Scala: select the amount column and negate all values.
* df.select( -df("amount") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.select( negate(col("amount")) );
* }}}
*/
- def unary_- : Column = constructColumn(null) { UnaryMinus(expr) }
+ def unary_- : Column = exprToColumn(UnaryMinus(expr))
/**
* Bitwise NOT.
* {{{
- * // Select the flags column and negate every bit.
+ * // Scala: select the flags column and negate every bit.
* df.select( ~df("flags") )
* }}}
*/
- def unary_~ : Column = constructColumn(null) { BitwiseNot(expr) }
+ def unary_~ : Column = exprToColumn(BitwiseNot(expr))
/**
* Inversion of boolean expression, i.e. NOT.
* {{{
- * // Select rows that are not active (isActive === false)
- * df.select( !df("isActive") )
+ * // Scala: select rows that are not active (isActive === false)
+ * df.filter( !df("isActive") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( not(df.col("isActive")) );
* }}
*/
- def unary_! : Column = constructColumn(null) { Not(expr) }
+ def unary_! : Column = exprToColumn(Not(expr))
/**
- * Equality test with an expression.
+ * Equality test.
* {{{
- * // The following two both select rows in which colA equals colB.
- * df.select( df("colA") === df("colB") )
- * df.select( df("colA".equalTo(df("colB")) )
+ * // Scala:
+ * df.filter( df("colA") === df("colB") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
- def === (other: Column): Column = constructColumn(other) {
- EqualTo(expr, other.expr)
+ def === (other: Any): Column = constructColumn(other) { o =>
+ EqualTo(expr, o.expr)
}
/**
- * Equality test with a literal value.
- * {{{
- * // The following two both select rows in which colA is "Zaharia".
- * df.select( df("colA") === "Zaharia")
- * df.select( df("colA".equalTo("Zaharia") )
- * }}}
- */
- def === (literal: Any): Column = this === lit(literal)
-
- /**
- * Equality test with an expression.
- * {{{
- * // The following two both select rows in which colA equals colB.
- * df.select( df("colA") === df("colB") )
- * df.select( df("colA".equalTo(df("colB")) )
- * }}}
- */
- def equalTo(other: Column): Column = this === other
-
- /**
- * Equality test with a literal value.
+ * Equality test.
* {{{
- * // The following two both select rows in which colA is "Zaharia".
- * df.select( df("colA") === "Zaharia")
- * df.select( df("colA".equalTo("Zaharia") )
+ * // Scala:
+ * df.filter( df("colA") === df("colB") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
- def equalTo(literal: Any): Column = this === literal
+ def equalTo(other: Any): Column = this === other
/**
- * Inequality test with an expression.
+ * Inequality test.
* {{{
- * // The following two both select rows in which colA does not equal colB.
+ * // Scala:
* df.select( df("colA") !== df("colB") )
* df.select( !(df("colA") === df("colB")) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( not(col("colA").equalTo(col("colB"))) );
* }}}
*/
- def !== (other: Column): Column = constructColumn(other) {
- Not(EqualTo(expr, other.expr))
+ def !== (other: Any): Column = constructColumn(other) { o =>
+ Not(EqualTo(expr, o.expr))
}
/**
- * Inequality test with a literal value.
- * {{{
- * // The following two both select rows in which colA does not equal equal 15.
- * df.select( df("colA") !== 15 )
- * df.select( !(df("colA") === 15) )
- * }}}
- */
- def !== (literal: Any): Column = this !== lit(literal)
-
- /**
- * Greater than an expression.
+ * Greater than.
* {{{
- * // The following selects people older than 21.
- * people.select( people("age") > Literal(21) )
+ * // Scala: The following selects people older than 21.
+ * people.select( people("age") > 21 )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * people.select( people("age").gt(21) );
* }}}
*/
- def > (other: Column): Column = constructColumn(other) {
- GreaterThan(expr, other.expr)
+ def > (other: Any): Column = constructColumn(other) { o =>
+ GreaterThan(expr, o.expr)
}
/**
- * Greater than a literal value.
+ * Greater than.
* {{{
- * // The following selects people older than 21.
- * people.select( people("age") > 21 )
+ * // Scala: The following selects people older than 21.
+ * people.select( people("age") > lit(21) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * people.select( people("age").gt(21) );
* }}}
*/
- def > (literal: Any): Column = this > lit(literal)
+ def gt(other: Any): Column = this > other
/**
- * Less than an expression.
+ * Less than.
* {{{
- * // The following selects people younger than 21.
- * people.select( people("age") < Literal(21) )
+ * // Scala: The following selects people younger than 21.
+ * people.select( people("age") < 21 )
+ *
+ * // Java:
+ * people.select( people("age").lt(21) );
* }}}
*/
- def < (other: Column): Column = constructColumn(other) {
- LessThan(expr, other.expr)
+ def < (other: Any): Column = constructColumn(other) { o =>
+ LessThan(expr, o.expr)
}
/**
- * Less than a literal value.
+ * Less than.
* {{{
- * // The following selects people younger than 21.
+ * // Scala: The following selects people younger than 21.
* people.select( people("age") < 21 )
+ *
+ * // Java:
+ * people.select( people("age").lt(21) );
* }}}
*/
- def < (literal: Any): Column = this < lit(literal)
+ def lt(other: Any): Column = this < other
/**
- * Less than or equal to an expression.
+ * Less than or equal to.
* {{{
- * // The following selects people age 21 or younger than 21.
- * people.select( people("age") <= Literal(21) )
+ * // Scala: The following selects people age 21 or younger.
+ * people.select( people("age") <= 21 )
+ *
+ * // Java:
+ * people.select( people("age").leq(21) );
* }}}
*/
- def <= (other: Column): Column = constructColumn(other) {
- LessThanOrEqual(expr, other.expr)
+ def <= (other: Any): Column = constructColumn(other) { o =>
+ LessThanOrEqual(expr, o.expr)
}
/**
- * Less than or equal to a literal value.
+ * Less than or equal to.
* {{{
- * // The following selects people age 21 or younger than 21.
+ * // Scala: The following selects people age 21 or younger.
* people.select( people("age") <= 21 )
+ *
+ * // Java:
+ * people.select( people("age").leq(21) );
* }}}
*/
- def <= (literal: Any): Column = this <= lit(literal)
+ def leq(other: Any): Column = this <= other
/**
* Greater than or equal to.
* {{{
- * // The following selects people age 21 or older than 21.
- * people.select( people("age") >= Literal(21) )
+ * // Scala: The following selects people age 21 or older.
+ * people.select( people("age") >= 21 )
+ *
+ * // Java:
+ * people.select( people("age").geq(21) )
* }}}
*/
- def >= (other: Column): Column = constructColumn(other) {
- GreaterThanOrEqual(expr, other.expr)
+ def >= (other: Any): Column = constructColumn(other) { o =>
+ GreaterThanOrEqual(expr, o.expr)
}
/**
- * Greater than or equal to a literal value.
+ * Greater than or equal to.
* {{{
- * // The following selects people age 21 or older than 21.
+ * // Scala: The following selects people age 21 or older.
* people.select( people("age") >= 21 )
+ *
+ * // Java:
+ * people.select( people("age").geq(21) )
* }}}
*/
- def >= (literal: Any): Column = this >= lit(literal)
+ def geq(other: Any): Column = this >= other
/**
- * Equality test with an expression that is safe for null values.
+ * Equality test that is safe for null values.
*/
- def <=> (other: Column): Column = constructColumn(other) {
- other match {
- case null => EqualNullSafe(expr, lit(null).expr)
- case _ => EqualNullSafe(expr, other.expr)
- }
+ def <=> (other: Any): Column = constructColumn(other) { o =>
+ EqualNullSafe(expr, o.expr)
}
/**
- * Equality test with a literal value that is safe for null values.
+ * Equality test that is safe for null values.
*/
- def <=> (literal: Any): Column = this <=> lit(literal)
+ def eqNullSafe(other: Any): Column = this <=> other
/**
* True if the current expression is null.
*/
- def isNull: Column = constructColumn(null) { IsNull(expr) }
+ def isNull: Column = exprToColumn(IsNull(expr))
/**
* True if the current expression is NOT null.
*/
- def isNotNull: Column = constructColumn(null) { IsNotNull(expr) }
+ def isNotNull: Column = exprToColumn(IsNotNull(expr))
/**
- * Boolean OR with an expression.
+ * Boolean OR.
* {{{
- * // The following selects people that are in school or employed.
- * people.select( people("inSchool") || people("isEmployed") )
+ * // Scala: The following selects people that are in school or employed.
+ * people.filter( people("inSchool") || people("isEmployed") )
+ *
+ * // Java:
+ * people.filter( people("inSchool").or(people("isEmployed")) );
* }}}
*/
- def || (other: Column): Column = constructColumn(other) {
- Or(expr, other.expr)
+ def || (other: Any): Column = constructColumn(other) { o =>
+ Or(expr, o.expr)
}
/**
- * Boolean OR with a literal value.
+ * Boolean OR.
* {{{
- * // The following selects everything.
- * people.select( people("inSchool") || true )
+ * // Scala: The following selects people that are in school or employed.
+ * people.filter( people("inSchool") || people("isEmployed") )
+ *
+ * // Java:
+ * people.filter( people("inSchool").or(people("isEmployed")) );
* }}}
*/
- def || (literal: Boolean): Column = this || lit(literal)
+ def or(other: Column): Column = this || other
/**
- * Boolean AND with an expression.
+ * Boolean AND.
* {{{
- * // The following selects people that are in school and employed at the same time.
+ * // Scala: The following selects people that are in school and employed at the same time.
* people.select( people("inSchool") && people("isEmployed") )
+ *
+ * // Java:
+ * people.select( people("inSchool").and(people("isEmployed")) );
* }}}
*/
- def && (other: Column): Column = constructColumn(other) {
- And(expr, other.expr)
+ def && (other: Any): Column = constructColumn(other) { o =>
+ And(expr, o.expr)
}
/**
- * Boolean AND with a literal value.
+ * Boolean AND.
* {{{
- * // The following selects people that are in school.
- * people.select( people("inSchool") && true )
+ * // Scala: The following selects people that are in school and employed at the same time.
+ * people.select( people("inSchool") && people("isEmployed") )
+ *
+ * // Java:
+ * people.select( people("inSchool").and(people("isEmployed")) );
* }}}
*/
- def && (literal: Boolean): Column = this && lit(literal)
+ def and(other: Column): Column = this && other
/**
- * Bitwise AND with an expression.
+ * Bitwise AND.
*/
- def & (other: Column): Column = constructColumn(other) {
- BitwiseAnd(expr, other.expr)
+ def & (other: Any): Column = constructColumn(other) { o =>
+ BitwiseAnd(expr, o.expr)
}
/**
- * Bitwise AND with a literal value.
- */
- def & (literal: Any): Column = this & lit(literal)
-
- /**
* Bitwise OR with an expression.
*/
- def | (other: Column): Column = constructColumn(other) {
- BitwiseOr(expr, other.expr)
+ def | (other: Any): Column = constructColumn(other) { o =>
+ BitwiseOr(expr, o.expr)
}
/**
- * Bitwise OR with a literal value.
- */
- def | (literal: Any): Column = this | lit(literal)
-
- /**
* Bitwise XOR with an expression.
*/
- def ^ (other: Column): Column = constructColumn(other) {
- BitwiseXor(expr, other.expr)
+ def ^ (other: Any): Column = constructColumn(other) { o =>
+ BitwiseXor(expr, o.expr)
}
/**
- * Bitwise XOR with a literal value.
- */
- def ^ (literal: Any): Column = this ^ lit(literal)
-
- /**
* Sum of this expression and another expression.
* {{{
- * // The following selects the sum of a person's height and weight.
+ * // Scala: The following selects the sum of a person's height and weight.
* people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").plus(people("weight")) );
* }}}
*/
- def + (other: Column): Column = constructColumn(other) {
- Add(expr, other.expr)
+ def + (other: Any): Column = constructColumn(other) { o =>
+ Add(expr, o.expr)
}
/**
* Sum of this expression and another expression.
* {{{
- * // The following selects the sum of a person's height and 10.
- * people.select( people("height") + 10 )
+ * // Scala: The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").plus(people("weight")) );
* }}}
*/
- def + (literal: Any): Column = this + lit(literal)
+ def plus(other: Any): Column = this + other
/**
* Subtraction. Subtract the other expression from this expression.
* {{{
- * // The following selects the difference between people's height and their weight.
+ * // Scala: The following selects the difference between people's height and their weight.
* people.select( people("height") - people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").minus(people("weight")) );
* }}}
*/
- def - (other: Column): Column = constructColumn(other) {
- Subtract(expr, other.expr)
+ def - (other: Any): Column = constructColumn(other) { o =>
+ Subtract(expr, o.expr)
}
/**
- * Subtraction. Subtract a literal value from this expression.
+ * Subtraction. Subtract the other expression from this expression.
* {{{
- * // The following selects a person's height and subtract it by 10.
- * people.select( people("height") - 10 )
+ * // Scala: The following selects the difference between people's height and their weight.
+ * people.select( people("height") - people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").minus(people("weight")) );
* }}}
*/
- def - (literal: Any): Column = this - lit(literal)
+ def minus(other: Any): Column = this - other
/**
* Multiplication of this expression and another expression.
* {{{
- * // The following multiplies a person's height by their weight.
+ * // Scala: The following multiplies a person's height by their weight.
* people.select( people("height") * people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").multiply(people("weight")) );
* }}}
*/
- def * (other: Column): Column = constructColumn(other) {
- Multiply(expr, other.expr)
+ def * (other: Any): Column = constructColumn(other) { o =>
+ Multiply(expr, o.expr)
}
/**
- * Multiplication this expression and a literal value.
+ * Multiplication of this expression and another expression.
* {{{
- * // The following multiplies a person's height by 10.
- * people.select( people("height") * 10 )
+ * // Scala: The following multiplies a person's height by their weight.
+ * people.select( people("height") * people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").multiply(people("weight")) );
* }}}
*/
- def * (literal: Any): Column = this * lit(literal)
+ def multiply(other: Any): Column = this * other
/**
* Division of this expression by another expression.
* {{{
- * // The following divides a person's height by their weight.
+ * // Scala: The following divides a person's height by their weight.
* people.select( people("height") / people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").divide(people("weight")) );
* }}}
*/
- def / (other: Column): Column = constructColumn(other) {
- Divide(expr, other.expr)
+ def / (other: Any): Column = constructColumn(other) { o =>
+ Divide(expr, o.expr)
}
/**
- * Division this expression by a literal value.
+ * Division of this expression by another expression.
* {{{
- * // The following divides a person's height by 10.
- * people.select( people("height") / 10 )
+ * // Scala: The following divides a person's height by their weight.
+ * people.select( people("height") / people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").divide(people("weight")) );
* }}}
*/
- def / (literal: Any): Column = this / lit(literal)
+ def divide(other: Any): Column = this / other
/**
* Modulo (a.k.a. remainder) expression.
*/
- def % (other: Column): Column = constructColumn(other) {
- Remainder(expr, other.expr)
+ def % (other: Any): Column = constructColumn(other) { o =>
+ Remainder(expr, o.expr)
}
/**
* Modulo (a.k.a. remainder) expression.
*/
- def % (literal: Any): Column = this % lit(literal)
-
+ def mod(other: Any): Column = this % other
/**
* A boolean expression that is evaluated to true if the value of this expression is contained
@@ -469,27 +508,19 @@ trait Column extends DataFrame {
new IncomputableColumn(In(expr, list.map(_.expr)))
}
- def like(literal: String): Column = constructColumn(null) {
- Like(expr, lit(literal).expr)
- }
+ def like(literal: String): Column = exprToColumn(Like(expr, lit(literal).expr))
- def rlike(literal: String): Column = constructColumn(null) {
- RLike(expr, lit(literal).expr)
- }
+ def rlike(literal: String): Column = exprToColumn(RLike(expr, lit(literal).expr))
/**
* An expression that gets an item at position `ordinal` out of an array.
*/
- def getItem(ordinal: Int): Column = constructColumn(null) {
- GetItem(expr, Literal(ordinal))
- }
+ def getItem(ordinal: Int): Column = exprToColumn(GetItem(expr, Literal(ordinal)))
/**
* An expression that gets a field by name in a [[StructField]].
*/
- def getField(fieldName: String): Column = constructColumn(null) {
- GetField(expr, fieldName)
- }
+ def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName))
/**
* An expression that returns a substring.
@@ -507,20 +538,18 @@ trait Column extends DataFrame {
*/
def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len))
- def contains(other: Column): Column = constructColumn(other) {
- Contains(expr, other.expr)
+ def contains(other: Any): Column = constructColumn(other) { o =>
+ Contains(expr, o.expr)
}
- def contains(literal: Any): Column = this.contains(lit(literal))
-
- def startsWith(other: Column): Column = constructColumn(other) {
- StartsWith(expr, other.expr)
+ def startsWith(other: Column): Column = constructColumn(other) { o =>
+ StartsWith(expr, o.expr)
}
def startsWith(literal: String): Column = this.startsWith(lit(literal))
- def endsWith(other: Column): Column = constructColumn(other) {
- EndsWith(expr, other.expr)
+ def endsWith(other: Column): Column = constructColumn(other) { o =>
+ EndsWith(expr, o.expr)
}
def endsWith(literal: String): Column = this.endsWith(lit(literal))
@@ -532,7 +561,7 @@ trait Column extends DataFrame {
* df.select($"colA".as("colB"))
* }}}
*/
- override def as(alias: String): Column = constructColumn(null) { Alias(expr, alias)() }
+ override def as(alias: String): Column = exprToColumn(Alias(expr, alias)())
/**
* Casts the column to a different data type.
@@ -545,7 +574,7 @@ trait Column extends DataFrame {
* df.select(df("colA").cast("int"))
* }}}
*/
- def cast(to: DataType): Column = constructColumn(null) { Cast(expr, to) }
+ def cast(to: DataType): Column = exprToColumn(Cast(expr, to))
/**
* Casts the column to a different data type, using the canonical string representation
@@ -556,7 +585,7 @@ trait Column extends DataFrame {
* df.select(df("colA").cast("int"))
* }}}
*/
- def cast(to: String): Column = constructColumn(null) {
+ def cast(to: String): Column = exprToColumn(
Cast(expr, to.toLowerCase match {
case "string" => StringType
case "boolean" => BooleanType
@@ -571,11 +600,11 @@ trait Column extends DataFrame {
case "timestamp" => TimestampType
case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
})
- }
+ )
- def desc: Column = constructColumn(null) { SortOrder(expr, Descending) }
+ def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false)
- def asc: Column = constructColumn(null) { SortOrder(expr, Ascending) }
+ def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false)
}
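The refactoring above hinges on lit() passing Columns through unchanged, which is what lets the old Column/Any overload pairs collapse into a single Any-accepting method. A self-contained sketch of that dispatch pattern (not the Spark source; Col and lift are illustrative):

    case class Col(name: String) {
      def ===(other: Any): Col = Col(s"($name = ${Col.lift(other).name})")
    }

    object Col {
      def lift(v: Any): Col = v match {
        case c: Col => c                    // already a column: pass through
        case other  => Col(other.toString)  // literal: wrap it
      }
    }

    // Col("a") === Col("b") and Col("a") === 15 dispatch to the same method.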
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 5920852e8c..f3bc07ae52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -25,6 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
private[sql] object DataFrame {
@@ -138,7 +139,13 @@ trait DataFrame extends RDDApi[Row] {
* a full outer join between `df1` and `df2`.
*
* {{{
+ * // Scala:
+ * import org.apache.spark.sql.Dsl._
* df1.join(df2, "outer", $"df1Key" === $"df2Key")
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df1.join(df2, "outer", col("df1Key") === col("df2Key"));
* }}}
*
* @param right Right side of the join.
@@ -185,7 +192,12 @@ trait DataFrame extends RDDApi[Row] {
/**
* Selects column based on the column name and return it as a [[Column]].
*/
- def apply(colName: String): Column
+ def apply(colName: String): Column = col(colName)
+
+ /**
+ * Selects column based on the column name and return it as a [[Column]].
+ */
+ def col(colName: String): Column
/**
* Selects a set of expressions, wrapped in a Product.
@@ -297,24 +309,41 @@ trait DataFrame extends RDDApi[Row] {
def groupBy(col1: String, cols: String*): GroupedDataFrame
/**
- * Aggregates on the entire [[DataFrame]] without groups.
+ * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. This is a shorthand for df.groupBy().agg(...).
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * )
+ * }}}
+ */
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+ groupBy().agg(aggExpr, aggExprs :_*)
+ }
+
+ /**
+ * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
* {{{
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
* df.agg(Map("age" -> "max", "salary" -> "avg"))
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}}
*/
- def agg(exprs: Map[String, String]): DataFrame
+ def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
/**
- * Aggregates on the entire [[DataFrame]] without groups.
+ * (Java-specific) Aggregates on the entire [[DataFrame]] without groups.
* {{{
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
* df.agg(Map("age" -> "max", "salary" -> "avg"))
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}}
*/
- def agg(exprs: java.util.Map[String, String]): DataFrame
+ def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs)
/**
* Aggregates on the entire [[DataFrame]] without groups.
@@ -325,7 +354,7 @@ trait DataFrame extends RDDApi[Row] {
* }}}
*/
@scala.annotation.varargs
- def agg(expr: Column, exprs: Column*): DataFrame
+ def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
/**
* Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function
@@ -366,7 +395,9 @@ trait DataFrame extends RDDApi[Row] {
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
*/
- def sample(withReplacement: Boolean, fraction: Double): DataFrame
+ def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
/////////////////////////////////////////////////////////////////////////////
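With these defaults pulled up into the trait, implementations only need to supply col(), groupBy(), and the seeded sample(). A short sketch of the calls the defaults enable (df is any DataFrame; column names are illustrative):

    df.col("amount")                           // new explicit accessor
    df("amount")                               // apply() now forwards to col()
    df.agg("age" -> "max", "salary" -> "avg")  // forwards to groupBy().agg(...)
    df.sample(withReplacement = false, fraction = 0.1)  // seed defaulted to Utils.random.nextLong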
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 49fd131534..0b0623dc1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan}
import org.apache.spark.sql.types.{NumericType, StructType}
-import org.apache.spark.util.Utils
/**
@@ -148,7 +147,7 @@ private[sql] class DataFrameImpl protected[sql](
sort(sortExpr, sortExprs :_*)
}
- override def apply(colName: String): Column = colName match {
+ override def col(colName: String): Column = colName match {
case "*" =>
Column(ResolvedStar(schema.fieldNames.map(resolve)))
case _ =>
@@ -201,18 +200,6 @@ private[sql] class DataFrameImpl protected[sql](
new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
}
- override def agg(exprs: Map[String, String]): DataFrame = {
- groupBy().agg(exprs)
- }
-
- override def agg(exprs: java.util.Map[String, String]): DataFrame = {
- agg(exprs.toMap)
- }
-
- override def agg(expr: Column, exprs: Column*): DataFrame = {
- groupBy().agg(expr, exprs :_*)
- }
-
override def limit(n: Int): DataFrame = {
Limit(Literal(n), logicalPlan)
}
@@ -233,10 +220,6 @@ private[sql] class DataFrameImpl protected[sql](
Sample(fraction, withReplacement, seed, logicalPlan)
}
- override def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
- sample(withReplacement, fraction, Utils.random.nextLong)
- }
-
/////////////////////////////////////////////////////////////////////////////
override def addColumn(colName: String, col: Column): DataFrame = {
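After this change, col() is the single resolution point in the implementation. A hedged sketch of the behavior the case analysis above implies (column names are illustrative):

    df.col("*")     // expands to every field via ResolvedStar
    df.col("age")   // resolves one named column against the schema
    df("age")       // the trait's apply() simply forwards here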
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index b4279a32ff..71365c776d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -36,21 +36,6 @@ object Dsl {
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
- // /**
- // * An implicit conversion that turns a RDD of product into a [[DataFrame]].
- // *
- // * This method requires an implicit SQLContext in scope. For example:
- // * {{{
- // * implicit val sqlContext: SQLContext = ...
- // * val rdd: RDD[(Int, String)] = ...
- // * rdd.toDataFrame // triggers the implicit here
- // * }}}
- // */
- // implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
- // : DataFrame = {
- // context.createDataFrame(rdd)
- // }
-
/** Converts $"col name" into an [[Column]]. */
implicit class StringToColumn(val sc: StringContext) extends AnyVal {
def $(args: Any*): ColumnName = {
@@ -72,10 +57,16 @@ object Dsl {
/**
* Creates a [[Column]] of literal value.
+ *
+ * The passed-in object is returned directly if it is already a [[Column]].
+ * If the object is a Scala Symbol, it is also converted into a [[Column]].
+ * Otherwise, a new [[Column]] is created to represent the literal value.
*/
def lit(literal: Any): Column = {
- if (literal.isInstanceOf[Symbol]) {
- return new ColumnName(literal.asInstanceOf[Symbol].name)
+ literal match {
+ case c: Column => return c
+ case s: Symbol => return new ColumnName(s.name)
+ case _ => // continue
}
val literalExpr = literal match {
@@ -100,27 +91,82 @@ object Dsl {
Column(literalExpr)
}
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /** Aggregate function: returns the sum of all values in the expression. */
def sum(e: Column): Column = Sum(e.expr)
+
+ /** Aggregate function: returns the sum of distinct values in the expression. */
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+
+ /** Aggregate function: returns the number of items in a group. */
def count(e: Column): Column = Count(e.expr)
+ /** Aggregate function: returns the number of distinct items in a group. */
@scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
+ /** Aggregate function: returns the approximate number of distinct items in a group. */
def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
- def approxCountDistinct(e: Column, rsd: Double): Column =
- ApproxCountDistinct(e.expr, rsd)
+ /** Aggregate function: returns the approximate number of distinct items in a group. */
+ def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)
+
+ /** Aggregate function: returns the average of the values in a group. */
def avg(e: Column): Column = Average(e.expr)
+
+ /** Aggregate function: returns the first value in a group. */
def first(e: Column): Column = First(e.expr)
+
+ /** Aggregate function: returns the last value in a group. */
def last(e: Column): Column = Last(e.expr)
+
+ /** Aggregate function: returns the minimum value of the expression in a group. */
def min(e: Column): Column = Min(e.expr)
+
+ /** Aggregate function: returns the maximum value of the expression in a group. */
def max(e: Column): Column = Max(e.expr)
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Unary minus, i.e. negate the expression.
+ * {{{
+ * // Scala: select the amount column and negate all values.
+ * df.select( -df("amount") )
+ *
+ * // Java:
+ * df.select( negate(df.col("amount")) );
+ * }}}
+ */
+ def negate(e: Column): Column = -e
+
+ /**
+ * Inversion of boolean expression, i.e. NOT.
+ * {{{
+ * // Scala: select rows that are not active (isActive === false)
+ * df.filter( !df("isActive") )
+ *
+ * // Java:
+ * df.filter( not(df.col("isActive")) );
+ * }}}
+ */
+ def not(e: Column): Column = !e
+
+ /** Converts a string expression to upper case. */
def upper(e: Column): Column = Upper(e.expr)
+
+ /** Converts a string expression to lower case. */
def lower(e: Column): Column = Lower(e.expr)
+
+ /** Computes the square root of the specified value. */
def sqrt(e: Column): Column = Sqrt(e.expr)
+
+ /** Computes the absolute value. */
def abs(e: Column): Column = Abs(e.expr)
/**
@@ -131,6 +177,9 @@ object Dsl {
cols.toList.toSeq
}
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
// scalastyle:off
/* Use the following code to generate:
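A sketch of how the revised Dsl reads from user code, assuming the imports shown in the JavaDoc above (df and its columns are illustrative):

    import org.apache.spark.sql.Dsl._

    // lit() now short-circuits: Columns pass through, Symbols become
    // ColumnNames, and anything else becomes a literal Column.
    lit(df("amount"))   // returned unchanged
    lit('amount)        // Scala Symbol -> ColumnName
    lit(42)             // wrapped as a literal

    df.select(negate(col("amount")))   // Java-friendly unary minus
    df.filter(not(col("isActive")))    // Java-friendly NOT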
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 6d0f3e8ce3..7963cb0312 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql
-import java.util.{List => JList}
-
import scala.language.implicitConversions
import scala.collection.JavaConversions._
@@ -59,15 +57,32 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
- * [[DataFrame]] will also contain the grouping columns.
+ * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * )
+ * }}}
+ */
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+ agg((aggExpr +: aggExprs).toMap)
+ }
+
+ /**
+ * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* df.groupBy("department").agg(Map(
- * "age" -> "max"
- * "sum" -> "expense"
+ * "age" -> "max",
+ * "expense" -> "sum"
* ))
* }}}
*/
@@ -79,16 +94,17 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
- * [[DataFrame]] will also contain the grouping columns.
+ * (Java-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
- * df.groupBy("department").agg(Map(
- * "age" -> "max"
- * "sum" -> "expense"
- * ))
+ * import com.google.common.collect.ImmutableMap;
+ * df.groupBy("department").agg(ImmutableMap.<String, String>builder()
+ * .put("age", "max")
+ * .put("expense", "sum")
+ * .build());
* }}}
*/
def agg(exprs: java.util.Map[String, String]): DataFrame = {
@@ -103,8 +119,14 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
+ *
+ * // Scala:
* import org.apache.spark.sql.Dsl._
* df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense")));
* }}}
*/
@scala.annotation.varargs
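For reference, the three agg flavors this leaves on GroupedDataFrame, sketched side by side (df has illustrative columns department, age, expense):

    import org.apache.spark.sql.Dsl._

    df.groupBy("department").agg("age" -> "max", "expense" -> "sum")       // (String, String) varargs
    df.groupBy("department").agg(Map("age" -> "max", "expense" -> "sum"))  // Scala Map
    df.groupBy("department").agg(max($"age"), sum($"expense"))             // Column expressions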
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 9b051de68f..ba5c7355b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -72,7 +72,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
- override def apply(colName: String): Column = err()
+ override def col(colName: String): Column = err()
override def apply(projection: Product): DataFrame = err()
@@ -90,12 +90,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def groupBy(col1: String, cols: String*): GroupedDataFrame = err()
- override def agg(exprs: Map[String, String]): DataFrame = err()
-
- override def agg(exprs: java.util.Map[String, String]): DataFrame = err()
-
- override def agg(expr: Column, exprs: Column*): DataFrame = err()
-
override def limit(n: Int): DataFrame = err()
override def unionAll(other: DataFrame): DataFrame = err()
@@ -106,8 +100,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
- override def sample(withReplacement: Boolean, fraction: Double): DataFrame = err()
-
/////////////////////////////////////////////////////////////////////////////
override def addColumn(colName: String, col: Column): DataFrame = err()
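The overrides removed here are exactly the methods that gained defaults in the DataFrame trait; what remains is the fail-fast behavior for incomputable columns. A hedged sketch of the split this class enforces (illustrative usage):

    val sorted = df.sort(df("age").desc)  // fine: desc is used as a sort expression
    val c = df("age").desc                // an IncomputableColumn under the hood
    // c.limit(5)                         // DataFrame ops on it call err() and throw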