-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala              439
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala            45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala        19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala                  87
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala     46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala   10
6 files changed, 376 insertions, 270 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 0d6055ff23..4aa37219e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -65,7 +65,15 @@ trait Column extends DataFrame {
*/
def isComputable: Boolean
- private def constructColumn(other: Column)(newExpr: Expression): Column = {
+ private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
+ val plan = Project(Seq(expr match {
+ case named: NamedExpression => named
+ case unnamed: Expression => Alias(unnamed, "col")()
+ }), baseCol.plan)
+ Column(baseCol.sqlContext, plan, expr)
+ }
+
+ private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
// Removes all the top level projection and subquery so we can get to the underlying plan.
@tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
case Project(_, child) => stripProject(child)
@@ -73,392 +81,423 @@ trait Column extends DataFrame {
case _ => p
}
- def computableCol(baseCol: ComputableColumn, expr: Expression) = {
- val plan = Project(Seq(expr match {
- case named: NamedExpression => named
- case unnamed: Expression => Alias(unnamed, "col")()
- }), baseCol.plan)
- Column(baseCol.sqlContext, plan, expr)
- }
-
- (this, other) match {
+ (this, lit(otherValue)) match {
case (left: ComputableColumn, right: ComputableColumn) =>
if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
- computableCol(right, newExpr)
+ computableCol(right, newExpr(right))
} else {
- Column(newExpr)
+ Column(newExpr(right))
}
- case (left: ComputableColumn, _) => computableCol(left, newExpr)
- case (_, right: ComputableColumn) => computableCol(right, newExpr)
- case (_, _) => Column(newExpr)
+ case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
+ case (_, right: ComputableColumn) => computableCol(right, newExpr(right))
+ case (_, right) => Column(newExpr(right))
+ }
+ }
+
+ /** Creates a column based on the given expression. */
+ private def exprToColumn(newExpr: Expression, computable: Boolean = true): Column = {
+ this match {
+ case c: ComputableColumn if computable => computableCol(c, newExpr)
+ case _ => Column(newExpr)
}
}
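With this refactor a single overload covers both column and literal operands: the argument is passed through lit() before the expression is built, so a Column flows through unchanged while anything else becomes a literal. A minimal sketch of what this enables, assuming a DataFrame df with an "age" column:

  // Both forms resolve through the same Any overload:
  df("age") > 21        // 21 is coerced via lit(21)
  df("age") > lit(21)   // an existing Column passes through lit() untouched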
/**
* Unary minus, i.e. negate the expression.
* {{{
- * // Select the amount column and negates all values.
+ * // Scala: select the amount column and negate all values.
* df.select( -df("amount") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.select( negate(col("amount")) );
* }}}
*/
- def unary_- : Column = constructColumn(null) { UnaryMinus(expr) }
+ def unary_- : Column = exprToColumn(UnaryMinus(expr))
/**
* Bitwise NOT.
* {{{
- * // Select the flags column and negate every bit.
+ * // Scala: select the flags column and negate every bit.
* df.select( ~df("flags") )
* }}}
*/
- def unary_~ : Column = constructColumn(null) { BitwiseNot(expr) }
+ def unary_~ : Column = exprToColumn(BitwiseNot(expr))
/**
* Inversion of boolean expression, i.e. NOT.
* {{{
- * // Select rows that are not active (isActive === false)
- * df.select( !df("isActive") )
+ * // Scala: select rows that are not active (isActive === false)
+ * df.filter( !df("isActive") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( not(df.col("isActive")) );
* }}}
*/
- def unary_! : Column = constructColumn(null) { Not(expr) }
+ def unary_! : Column = exprToColumn(Not(expr))
/**
- * Equality test with an expression.
+ * Equality test.
* {{{
- * // The following two both select rows in which colA equals colB.
- * df.select( df("colA") === df("colB") )
- * df.select( df("colA".equalTo(df("colB")) )
+ * // Scala:
+ * df.filter( df("colA") === df("colB") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
- def === (other: Column): Column = constructColumn(other) {
- EqualTo(expr, other.expr)
+ def === (other: Any): Column = constructColumn(other) { o =>
+ EqualTo(expr, o.expr)
}
/**
- * Equality test with a literal value.
- * {{{
- * // The following two both select rows in which colA is "Zaharia".
- * df.select( df("colA") === "Zaharia")
- * df.select( df("colA".equalTo("Zaharia") )
- * }}}
- */
- def === (literal: Any): Column = this === lit(literal)
-
- /**
- * Equality test with an expression.
- * {{{
- * // The following two both select rows in which colA equals colB.
- * df.select( df("colA") === df("colB") )
- * df.select( df("colA".equalTo(df("colB")) )
- * }}}
- */
- def equalTo(other: Column): Column = this === other
-
- /**
- * Equality test with a literal value.
+ * Equality test.
* {{{
- * // The following two both select rows in which colA is "Zaharia".
- * df.select( df("colA") === "Zaharia")
- * df.select( df("colA".equalTo("Zaharia") )
+ * // Scala:
+ * df.filter( df("colA") === df("colB") )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
- def equalTo(literal: Any): Column = this === literal
+ def equalTo(other: Any): Column = this === other
/**
- * Inequality test with an expression.
+ * Inequality test.
* {{{
- * // The following two both select rows in which colA does not equal colB.
+ * // Scala:
* df.select( df("colA") !== df("colB") )
* df.select( !(df("colA") === df("colB")) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.filter( not(col("colA").equalTo(col("colB"))) );
* }}}
*/
- def !== (other: Column): Column = constructColumn(other) {
- Not(EqualTo(expr, other.expr))
+ def !== (other: Any): Column = constructColumn(other) { o =>
+ Not(EqualTo(expr, o.expr))
}
/**
- * Inequality test with a literal value.
- * {{{
- * // The following two both select rows in which colA does not equal equal 15.
- * df.select( df("colA") !== 15 )
- * df.select( !(df("colA") === 15) )
- * }}}
- */
- def !== (literal: Any): Column = this !== lit(literal)
-
- /**
- * Greater than an expression.
+ * Greater than.
* {{{
- * // The following selects people older than 21.
- * people.select( people("age") > Literal(21) )
+ * // Scala: The following selects people older than 21.
+ * people.select( people("age") > 21 )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * people.select( people("age").gt(21) );
* }}}
*/
- def > (other: Column): Column = constructColumn(other) {
- GreaterThan(expr, other.expr)
+ def > (other: Any): Column = constructColumn(other) { o =>
+ GreaterThan(expr, o.expr)
}
/**
- * Greater than a literal value.
+ * Greater than.
* {{{
- * // The following selects people older than 21.
- * people.select( people("age") > 21 )
+ * // Scala: The following selects people older than 21.
+ * people.select( people("age") > lit(21) )
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * people.select( people("age").gt(21) );
* }}}
*/
- def > (literal: Any): Column = this > lit(literal)
+ def gt(other: Any): Column = this > other
/**
- * Less than an expression.
+ * Less than.
* {{{
- * // The following selects people younger than 21.
- * people.select( people("age") < Literal(21) )
+ * // Scala: The following selects people younger than 21.
+ * people.select( people("age") < 21 )
+ *
+ * // Java:
+ * people.select( people("age").lt(21) );
* }}}
*/
- def < (other: Column): Column = constructColumn(other) {
- LessThan(expr, other.expr)
+ def < (other: Any): Column = constructColumn(other) { o =>
+ LessThan(expr, o.expr)
}
/**
- * Less than a literal value.
+ * Less than.
* {{{
- * // The following selects people younger than 21.
+ * // Scala: The following selects people younger than 21.
* people.select( people("age") < 21 )
+ *
+ * // Java:
+ * people.select( people("age").lt(21) );
* }}}
*/
- def < (literal: Any): Column = this < lit(literal)
+ def lt(other: Any): Column = this < other
/**
- * Less than or equal to an expression.
+ * Less than or equal to.
* {{{
- * // The following selects people age 21 or younger than 21.
- * people.select( people("age") <= Literal(21) )
+ * // Scala: The following selects people age 21 or younger.
+ * people.select( people("age") <= 21 )
+ *
+ * // Java:
+ * people.select( people("age").leq(21) );
* }}}
*/
- def <= (other: Column): Column = constructColumn(other) {
- LessThanOrEqual(expr, other.expr)
+ def <= (other: Any): Column = constructColumn(other) { o =>
+ LessThanOrEqual(expr, o.expr)
}
/**
- * Less than or equal to a literal value.
+ * Less than or equal to.
* {{{
- * // The following selects people age 21 or younger than 21.
+ * // Scala: The following selects people age 21 or younger.
* people.select( people("age") <= 21 )
+ *
+ * // Java:
+ * people.select( people("age").leq(21) );
* }}}
*/
- def <= (literal: Any): Column = this <= lit(literal)
+ def leq(other: Any): Column = this <= other
/**
* Greater than or equal to.
* {{{
- * // The following selects people age 21 or older than 21.
- * people.select( people("age") >= Literal(21) )
+ * // Scala: The following selects people age 21 or older.
+ * people.select( people("age") >= 21 )
+ *
+ * // Java:
+ * people.select( people("age").geq(21) )
* }}}
*/
- def >= (other: Column): Column = constructColumn(other) {
- GreaterThanOrEqual(expr, other.expr)
+ def >= (other: Any): Column = constructColumn(other) { o =>
+ GreaterThanOrEqual(expr, o.expr)
}
/**
- * Greater than or equal to a literal value.
+ * Greater than or equal to.
* {{{
- * // The following selects people age 21 or older than 21.
+ * // Scala: The following selects people age 21 or older.
* people.select( people("age") >= 21 )
+ *
+ * // Java:
+ * people.select( people("age").geq(21) )
* }}}
*/
- def >= (literal: Any): Column = this >= lit(literal)
+ def geq(other: Any): Column = this >= other
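Each named method (gt, lt, leq, geq) is a plain alias for its symbolic operator, added so Java code, which cannot call symbolic methods cleanly, gets the same functionality. The two lines below build identical expressions, assuming a hypothetical people DataFrame:

  people.filter(people("age") >= 21)      // Scala operator syntax
  people.filter(people("age").geq(21))    // named alias, callable from Java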
/**
- * Equality test with an expression that is safe for null values.
+ * Equality test that is safe for null values.
*/
- def <=> (other: Column): Column = constructColumn(other) {
- other match {
- case null => EqualNullSafe(expr, lit(null).expr)
- case _ => EqualNullSafe(expr, other.expr)
- }
+ def <=> (other: Any): Column = constructColumn(other) { o =>
+ EqualNullSafe(expr, o.expr)
}
/**
- * Equality test with a literal value that is safe for null values.
+ * Equality test that is safe for null values.
*/
- def <=> (literal: Any): Column = this <=> lit(literal)
+ def eqNullSafe(other: Any): Column = this <=> other
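Unlike ===, the null-safe comparison treats two NULLs as equal rather than producing NULL. A short sketch of the difference, assuming df has a nullable "manager" column:

  df.filter(df("manager") === null)   // matches nothing: NULL = NULL is NULL in SQL
  df.filter(df("manager") <=> null)   // matches exactly the rows where manager is NULL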
/**
* True if the current expression is null.
*/
- def isNull: Column = constructColumn(null) { IsNull(expr) }
+ def isNull: Column = exprToColumn(IsNull(expr))
/**
* True if the current expression is NOT null.
*/
- def isNotNull: Column = constructColumn(null) { IsNotNull(expr) }
+ def isNotNull: Column = exprToColumn(IsNotNull(expr))
/**
- * Boolean OR with an expression.
+ * Boolean OR.
* {{{
- * // The following selects people that are in school or employed.
- * people.select( people("inSchool") || people("isEmployed") )
+ * // Scala: The following selects people that are in school or employed.
+ * people.filter( people("inSchool") || people("isEmployed") )
+ *
+ * // Java:
+ * people.filter( people("inSchool").or(people("isEmployed")) );
* }}}
*/
- def || (other: Column): Column = constructColumn(other) {
- Or(expr, other.expr)
+ def || (other: Any): Column = constructColumn(other) { o =>
+ Or(expr, o.expr)
}
/**
- * Boolean OR with a literal value.
+ * Boolean OR.
* {{{
- * // The following selects everything.
- * people.select( people("inSchool") || true )
+ * // Scala: The following selects people that are in school or employed.
+ * people.filter( people("inSchool") || people("isEmployed") )
+ *
+ * // Java:
+ * people.filter( people("inSchool").or(people("isEmployed")) );
* }}}
*/
- def || (literal: Boolean): Column = this || lit(literal)
+ def or(other: Column): Column = this || other
/**
- * Boolean AND with an expression.
+ * Boolean AND.
* {{{
- * // The following selects people that are in school and employed at the same time.
+ * // Scala: The following selects people that are in school and employed at the same time.
* people.select( people("inSchool") && people("isEmployed") )
+ *
+ * // Java:
+ * people.select( people("inSchool").and(people("isEmployed")) );
* }}}
*/
- def && (other: Column): Column = constructColumn(other) {
- And(expr, other.expr)
+ def && (other: Any): Column = constructColumn(other) { o =>
+ And(expr, o.expr)
}
/**
- * Boolean AND with a literal value.
+ * Boolean AND.
* {{{
- * // The following selects people that are in school.
- * people.select( people("inSchool") && true )
+ * // Scala: The following selects people that are in school and employed at the same time.
+ * people.select( people("inSchool") && people("isEmployed") )
+ *
+ * // Java:
+ * people.select( people("inSchool").and(people("isEmployed")) );
* }}}
*/
- def && (literal: Boolean): Column = this && lit(literal)
+ def and(other: Column): Column = this && other
/**
- * Bitwise AND with an expression.
+ * Bitwise AND.
*/
- def & (other: Column): Column = constructColumn(other) {
- BitwiseAnd(expr, other.expr)
+ def & (other: Any): Column = constructColumn(other) { o =>
+ BitwiseAnd(expr, o.expr)
}
/**
- * Bitwise AND with a literal value.
- */
- def & (literal: Any): Column = this & lit(literal)
-
- /**
* Bitwise OR.
*/
- def | (other: Column): Column = constructColumn(other) {
- BitwiseOr(expr, other.expr)
+ def | (other: Any): Column = constructColumn(other) { o =>
+ BitwiseOr(expr, o.expr)
}
/**
- * Bitwise OR with a literal value.
- */
- def | (literal: Any): Column = this | lit(literal)
-
- /**
* Bitwise XOR.
*/
- def ^ (other: Column): Column = constructColumn(other) {
- BitwiseXor(expr, other.expr)
+ def ^ (other: Any): Column = constructColumn(other) { o =>
+ BitwiseXor(expr, o.expr)
}
/**
- * Bitwise XOR with a literal value.
- */
- def ^ (literal: Any): Column = this ^ lit(literal)
-
- /**
* Sum of this expression and another expression.
* {{{
- * // The following selects the sum of a person's height and weight.
+ * // Scala: The following selects the sum of a person's height and weight.
* people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").plus(people("weight")) );
* }}}
*/
- def + (other: Column): Column = constructColumn(other) {
- Add(expr, other.expr)
+ def + (other: Any): Column = constructColumn(other) { o =>
+ Add(expr, o.expr)
}
/**
* Sum of this expression and another expression.
* {{{
- * // The following selects the sum of a person's height and 10.
- * people.select( people("height") + 10 )
+ * // Scala: The following selects the sum of a person's height and weight.
+ * people.select( people("height") + people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").plus(people("weight")) );
* }}}
*/
- def + (literal: Any): Column = this + lit(literal)
+ def plus(other: Any): Column = this + other
/**
* Subtraction. Subtract the other expression from this expression.
* {{{
- * // The following selects the difference between people's height and their weight.
+ * // Scala: The following selects the difference between people's height and their weight.
* people.select( people("height") - people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").minus(people("weight")) );
* }}}
*/
- def - (other: Column): Column = constructColumn(other) {
- Subtract(expr, other.expr)
+ def - (other: Any): Column = constructColumn(other) { o =>
+ Subtract(expr, o.expr)
}
/**
- * Subtraction. Subtract a literal value from this expression.
+ * Subtraction. Subtract the other expression from this expression.
* {{{
- * // The following selects a person's height and subtract it by 10.
- * people.select( people("height") - 10 )
+ * // Scala: The following selects the difference between people's height and their weight.
+ * people.select( people("height") - people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").minus(people("weight")) );
* }}}
*/
- def - (literal: Any): Column = this - lit(literal)
+ def minus(other: Any): Column = this - other
/**
* Multiplication of this expression and another expression.
* {{{
- * // The following multiplies a person's height by their weight.
+ * // Scala: The following multiplies a person's height by their weight.
* people.select( people("height") * people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").multiply(people("weight")) );
* }}}
*/
- def * (other: Column): Column = constructColumn(other) {
- Multiply(expr, other.expr)
+ def * (other: Any): Column = constructColumn(other) { o =>
+ Multiply(expr, o.expr)
}
/**
- * Multiplication this expression and a literal value.
+ * Multiplication of this expression and another expression.
* {{{
- * // The following multiplies a person's height by 10.
- * people.select( people("height") * 10 )
+ * // Scala: The following multiplies a person's height by their weight.
+ * people.select( people("height") * people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").multiply(people("weight")) );
* }}}
*/
- def * (literal: Any): Column = this * lit(literal)
+ def multiply(other: Any): Column = this * other
/**
* Division of this expression by another expression.
* {{{
- * // The following divides a person's height by their weight.
+ * // Scala: The following divides a person's height by their weight.
* people.select( people("height") / people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").divide(people("weight")) );
* }}}
*/
- def / (other: Column): Column = constructColumn(other) {
- Divide(expr, other.expr)
+ def / (other: Any): Column = constructColumn(other) { o =>
+ Divide(expr, o.expr)
}
/**
- * Division this expression by a literal value.
+ * Division of this expression by another expression.
* {{{
- * // The following divides a person's height by 10.
- * people.select( people("height") / 10 )
+ * // Scala: The following divides a person's height by their weight.
+ * people.select( people("height") / people("weight") )
+ *
+ * // Java:
+ * people.select( people("height").divide(people("weight")) );
* }}}
*/
- def / (literal: Any): Column = this / lit(literal)
+ def divide(other: Any): Column = this / other
/**
* Modulo (a.k.a. remainder) expression.
*/
- def % (other: Column): Column = constructColumn(other) {
- Remainder(expr, other.expr)
+ def % (other: Any): Column = constructColumn(other) { o =>
+ Remainder(expr, o.expr)
}
/**
* Modulo (a.k.a. remainder) expression.
*/
- def % (literal: Any): Column = this % lit(literal)
-
+ def mod(other: Any): Column = this % other
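As with the comparisons, every arithmetic operator now has a named alias, so a chain of operations can be written either way. For example, with hypothetical "price", "qty" and "tax" columns:

  df.select(df("price") * df("qty") + df("tax"))
  df.select(df("price").multiply(df("qty")).plus(df("tax")))   // same expression tree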
/**
* A boolean expression that is evaluated to true if the value of this expression is contained
@@ -469,27 +508,19 @@ trait Column extends DataFrame {
new IncomputableColumn(In(expr, list.map(_.expr)))
}
- def like(literal: String): Column = constructColumn(null) {
- Like(expr, lit(literal).expr)
- }
+ def like(literal: String): Column = exprToColumn(Like(expr, lit(literal).expr))
- def rlike(literal: String): Column = constructColumn(null) {
- RLike(expr, lit(literal).expr)
- }
+ def rlike(literal: String): Column = exprToColumn(RLike(expr, lit(literal).expr))
/**
* An expression that gets an item at position `ordinal` out of an array.
*/
- def getItem(ordinal: Int): Column = constructColumn(null) {
- GetItem(expr, Literal(ordinal))
- }
+ def getItem(ordinal: Int): Column = exprToColumn(GetItem(expr, Literal(ordinal)))
/**
* An expression that gets a field by name in a [[StructField]].
*/
- def getField(fieldName: String): Column = constructColumn(null) {
- GetField(expr, fieldName)
- }
+ def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName))
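getItem and getField are the entry points for complex column types. A sketch, assuming "scores" is an array column and "address" a struct column:

  df.select(df("scores").getItem(0))          // element at position 0 of the array
  df.select(df("address").getField("city"))   // named field of the struct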
/**
* An expression that returns a substring.
@@ -507,20 +538,18 @@ trait Column extends DataFrame {
*/
def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len))
- def contains(other: Column): Column = constructColumn(other) {
- Contains(expr, other.expr)
+ def contains(other: Any): Column = constructColumn(other) { o =>
+ Contains(expr, o.expr)
}
- def contains(literal: Any): Column = this.contains(lit(literal))
-
- def startsWith(other: Column): Column = constructColumn(other) {
- StartsWith(expr, other.expr)
+ def startsWith(other: Column): Column = constructColumn(other) { o =>
+ StartsWith(expr, o.expr)
}
def startsWith(literal: String): Column = this.startsWith(lit(literal))
- def endsWith(other: Column): Column = constructColumn(other) {
- EndsWith(expr, other.expr)
+ def endsWith(other: Column): Column = constructColumn(other) { o =>
+ EndsWith(expr, o.expr)
}
def endsWith(literal: String): Column = this.endsWith(lit(literal))
@@ -532,7 +561,7 @@ trait Column extends DataFrame {
* df.select($"colA".as("colB"))
* }}}
*/
- override def as(alias: String): Column = constructColumn(null) { Alias(expr, alias)() }
+ override def as(alias: String): Column = exprToColumn(Alias(expr, alias)())
/**
* Casts the column to a different data type.
@@ -545,7 +574,7 @@ trait Column extends DataFrame {
* df.select(df("colA").cast("int"))
* }}}
*/
- def cast(to: DataType): Column = constructColumn(null) { Cast(expr, to) }
+ def cast(to: DataType): Column = exprToColumn(Cast(expr, to))
/**
* Casts the column to a different data type, using the canonical string representation
@@ -556,7 +585,7 @@ trait Column extends DataFrame {
* df.select(df("colA").cast("int"))
* }}}
*/
- def cast(to: String): Column = constructColumn(null) {
+ def cast(to: String): Column = exprToColumn(
Cast(expr, to.toLowerCase match {
case "string" => StringType
case "boolean" => BooleanType
@@ -571,11 +600,11 @@ trait Column extends DataFrame {
case "timestamp" => TimestampType
case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
})
- }
+ )
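The string form simply maps a lowercase type name to the corresponding DataType, so the two calls below are interchangeable, assuming a castable "colA":

  import org.apache.spark.sql.types.IntegerType
  df.select(df("colA").cast(IntegerType))   // DataType form
  df.select(df("colA").cast("int"))         // string form, same result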
- def desc: Column = constructColumn(null) { SortOrder(expr, Descending) }
+ def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false)
- def asc: Column = constructColumn(null) { SortOrder(expr, Ascending) }
+ def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false)
}
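desc and asc are built with computable = false because a bare SortOrder is only meaningful inside an ordering, not as a projected value. Typical use, assuming "name" and "age" columns:

  df.sort(df("name").asc, df("age").desc)
  df.orderBy(df("age").desc)   // orderBy delegates to sort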
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 5920852e8c..f3bc07ae52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -25,6 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
private[sql] object DataFrame {
@@ -138,7 +139,13 @@ trait DataFrame extends RDDApi[Row] {
* a full outer join between `df1` and `df2`.
*
* {{{
+ * // Scala:
+ * import org.apache.spark.sql.Dsl._
* df1.join(df2, "outer", $"df1Key" === $"df2Key")
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df1.join(df2, "outer", col("df1Key").equalTo(col("df2Key")));
* }}}
*
* @param right Right side of the join.
@@ -185,7 +192,12 @@ trait DataFrame extends RDDApi[Row] {
/**
* Selects column based on the column name and return it as a [[Column]].
*/
- def apply(colName: String): Column
+ def apply(colName: String): Column = col(colName)
+
+ /**
+ * Selects column based on the column name and return it as a [[Column]].
+ */
+ def col(colName: String): Column
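Splitting the lookup out of apply gives Java callers a named entry point while Scala keeps the function-application sugar. Both lines below return the same Column, for a hypothetical df:

  val c1 = df("department")       // apply(), Scala sugar
  val c2 = df.col("department")   // equivalent, callable from Java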
/**
* Selects a set of expressions, wrapped in a Product.
@@ -297,24 +309,41 @@ trait DataFrame extends RDDApi[Row] {
def groupBy(col1: String, cols: String*): GroupedDataFrame
/**
- * Aggregates on the entire [[DataFrame]] without groups.
+ * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * )
+ * }}}
+ */
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+ groupBy().agg(aggExpr, aggExprs :_*)
+ }
+
+ /**
+ * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
* {{{
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
* df.agg(Map("age" -> "max", "salary" -> "avg"))
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}}
*/
- def agg(exprs: Map[String, String]): DataFrame
+ def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
/**
- * Aggregates on the entire [[DataFrame]] without groups.
+ * (Java-specific) Aggregates on the entire [[DataFrame]] without groups.
* {{{
* // df.agg(...) is a shorthand for df.groupBy().agg(...)
* df.agg(Map("age" -> "max", "salary" -> "avg"))
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}}
*/
- def agg(exprs: java.util.Map[String, String]): DataFrame
+ def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs)
/**
* Aggregates on the entire [[DataFrame]] without groups.
@@ -325,7 +354,7 @@ trait DataFrame extends RDDApi[Row] {
* }}
*/
@scala.annotation.varargs
- def agg(expr: Column, exprs: Column*): DataFrame
+ def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
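These DataFrame.agg overloads are now trait-level shorthands that delegate through an empty groupBy(), so global and grouped aggregation share one code path. For example:

  df.agg("age" -> "max", "salary" -> "avg")                  // pair-based shorthand
  df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))   // what it expands to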
/**
* Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function
@@ -366,7 +395,9 @@ trait DataFrame extends RDDApi[Row] {
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
*/
- def sample(withReplacement: Boolean, fraction: Double): DataFrame
+ def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
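With the random seed defaulted in the trait, implementations only need to provide the three-argument variant. A usage sketch:

  df.sample(withReplacement = false, fraction = 0.1)              // fresh seed each call
  df.sample(withReplacement = false, fraction = 0.1, seed = 42L)  // reproducible sample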
/////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 49fd131534..0b0623dc1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan}
import org.apache.spark.sql.types.{NumericType, StructType}
-import org.apache.spark.util.Utils
/**
@@ -148,7 +147,7 @@ private[sql] class DataFrameImpl protected[sql](
sort(sortExpr, sortExprs :_*)
}
- override def apply(colName: String): Column = colName match {
+ override def col(colName: String): Column = colName match {
case "*" =>
Column(ResolvedStar(schema.fieldNames.map(resolve)))
case _ =>
@@ -201,18 +200,6 @@ private[sql] class DataFrameImpl protected[sql](
new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
}
- override def agg(exprs: Map[String, String]): DataFrame = {
- groupBy().agg(exprs)
- }
-
- override def agg(exprs: java.util.Map[String, String]): DataFrame = {
- agg(exprs.toMap)
- }
-
- override def agg(expr: Column, exprs: Column*): DataFrame = {
- groupBy().agg(expr, exprs :_*)
- }
-
override def limit(n: Int): DataFrame = {
Limit(Literal(n), logicalPlan)
}
@@ -233,10 +220,6 @@ private[sql] class DataFrameImpl protected[sql](
Sample(fraction, withReplacement, seed, logicalPlan)
}
- override def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
- sample(withReplacement, fraction, Utils.random.nextLong)
- }
-
/////////////////////////////////////////////////////////////////////////////
override def addColumn(colName: String, col: Column): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index b4279a32ff..71365c776d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -36,21 +36,6 @@ object Dsl {
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
- // /**
- // * An implicit conversion that turns a RDD of product into a [[DataFrame]].
- // *
- // * This method requires an implicit SQLContext in scope. For example:
- // * {{{
- // * implicit val sqlContext: SQLContext = ...
- // * val rdd: RDD[(Int, String)] = ...
- // * rdd.toDataFrame // triggers the implicit here
- // * }}}
- // */
- // implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
- // : DataFrame = {
- // context.createDataFrame(rdd)
- // }
-
/** Converts $"col name" into an [[Column]]. */
implicit class StringToColumn(val sc: StringContext) extends AnyVal {
def $(args: Any*): ColumnName = {
@@ -72,10 +57,16 @@ object Dsl {
/**
* Creates a [[Column]] of literal value.
+ *
+ * The passed-in object is returned directly if it is already a [[Column]].
+ * If the object is a Scala Symbol, it is also converted into a [[Column]].
+ * Otherwise, a new [[Column]] is created to represent the literal value.
*/
def lit(literal: Any): Column = {
- if (literal.isInstanceOf[Symbol]) {
- return new ColumnName(literal.asInstanceOf[Symbol].name)
+ literal match {
+ case c: Column => return c
+ case s: Symbol => return new ColumnName(s.name)
+ case _ => // continue
}
val literalExpr = literal match {
@@ -100,27 +91,82 @@ object Dsl {
Column(literalExpr)
}
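These pass-through cases are what allow every Column operator to accept Any: a Column is returned unchanged, a Symbol becomes a ColumnName, and any other value is wrapped as a literal. For instance:

  lit(1)         // new Column wrapping Literal(1)
  lit($"name")   // already a Column, returned as-is
  lit('name)     // Symbol, converted to ColumnName("name")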
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /** Aggregate function: returns the sum of all values in the expression. */
def sum(e: Column): Column = Sum(e.expr)
+
+ /** Aggregate function: returns the sum of distinct values in the expression. */
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+
+ /** Aggregate function: returns the number of items in a group. */
def count(e: Column): Column = Count(e.expr)
+ /** Aggregate function: returns the number of distinct items in a group. */
@scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
+ /** Aggregate function: returns the approximate number of distinct items in a group. */
def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
- def approxCountDistinct(e: Column, rsd: Double): Column =
- ApproxCountDistinct(e.expr, rsd)
+ /** Aggregate function: returns the approximate number of distinct items in a group. */
+ def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)
+
+ /** Aggregate function: returns the average of the values in a group. */
def avg(e: Column): Column = Average(e.expr)
+
+ /** Aggregate function: returns the first value in a group. */
def first(e: Column): Column = First(e.expr)
+
+ /** Aggregate function: returns the last value in a group. */
def last(e: Column): Column = Last(e.expr)
+
+ /** Aggregate function: returns the minimum value of the expression in a group. */
def min(e: Column): Column = Min(e.expr)
+
+ /** Aggregate function: returns the maximum value of the expression in a group. */
def max(e: Column): Column = Max(e.expr)
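A quick usage sketch combining these aggregate functions with groupBy, assuming hypothetical "department" and "salary" columns:

  import org.apache.spark.sql.Dsl._
  df.groupBy("department").agg(max($"salary"), avg($"salary"), countDistinct($"salary"))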
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Unary minus, i.e. negate the expression.
+ * {{{
+ * // Scala: select the amount column and negate all values.
+ * df.select( -df("amount") )
+ *
+ * // Java:
+ * df.select( negate(df.col("amount")) );
+ * }}}
+ */
+ def negate(e: Column): Column = -e
+
+ /**
+ * Inversion of boolean expression, i.e. NOT.
+ * {{{
+ * // Scala: select rows that are not active (isActive === false)
+ * df.filter( !df("isActive") )
+ *
+ * // Java:
+ * df.filter( not(df.col("isActive")) );
+ * }}}
+ */
+ def not(e: Column): Column = !e
+
+ /** Converts a string expression to upper case. */
def upper(e: Column): Column = Upper(e.expr)
+
+ /** Converts a string expression to lower case. */
def lower(e: Column): Column = Lower(e.expr)
+
+ /** Computes the square root of the specified float value. */
def sqrt(e: Column): Column = Sqrt(e.expr)
+
+ /** Computes the absolute value. */
def abs(e: Column): Column = Abs(e.expr)
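The non-aggregate helpers mirror Column's own methods so they can also be used through Dsl's static import from Java. A sketch with hypothetical columns:

  import org.apache.spark.sql.Dsl._
  df.select(upper($"name"), sqrt(abs($"delta")))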
/**
@@ -131,6 +177,9 @@ object Dsl {
cols.toList.toSeq
}
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
// scalastyle:off
/* Use the following code to generate:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 6d0f3e8ce3..7963cb0312 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql
-import java.util.{List => JList}
-
import scala.language.implicitConversions
import scala.collection.JavaConversions._
@@ -59,15 +57,32 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
- * [[DataFrame]] will also contain the grouping columns.
+ * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
+ *
+ * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
+ * {{{
+ * // Selects the age of the oldest employee and the aggregate expense for each department
+ * df.groupBy("department").agg(
+ * "age" -> "max",
+ * "expense" -> "sum"
+ * )
+ * }}}
+ */
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
+ agg((aggExpr +: aggExprs).toMap)
+ }
+
+ /**
+ * (Scala-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* df.groupBy("department").agg(Map(
- * "age" -> "max"
- * "sum" -> "expense"
+ * "age" -> "max",
+ * "expense" -> "sum"
* ))
* }}}
*/
@@ -79,16 +94,17 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr
}
/**
- * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
- * [[DataFrame]] will also contain the grouping columns.
+ * (Java-specific) Compute aggregates by specifying a map from column name to
+ * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
- * df.groupBy("department").agg(Map(
- * "age" -> "max"
- * "sum" -> "expense"
- * ))
+ * import com.google.common.collect.ImmutableMap;
+ * df.groupBy("department").agg(ImmutableMap.<String, String>builder()
+ * .put("age", "max")
+ * .put("expense", "sum")
+ * .build());
* }}}
*/
def agg(exprs: java.util.Map[String, String]): DataFrame = {
@@ -103,8 +119,14 @@ class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expr
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
+ *
+ * // Scala:
* import org.apache.spark.sql.Dsl._
* df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
+ *
+ * // Java:
+ * import static org.apache.spark.sql.Dsl.*;
+ * df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense")));
* }}}
*/
@scala.annotation.varargs
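The new pair-based overload drops the Map(...) boilerplate from the common Scala case; the two calls below are equivalent:

  df.groupBy("department").agg("age" -> "max", "expense" -> "sum")
  df.groupBy("department").agg(Map("age" -> "max", "expense" -> "sum"))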
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 9b051de68f..ba5c7355b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -72,7 +72,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
- override def apply(colName: String): Column = err()
+ override def col(colName: String): Column = err()
override def apply(projection: Product): DataFrame = err()
@@ -90,12 +90,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def groupBy(col1: String, cols: String*): GroupedDataFrame = err()
- override def agg(exprs: Map[String, String]): DataFrame = err()
-
- override def agg(exprs: java.util.Map[String, String]): DataFrame = err()
-
- override def agg(expr: Column, exprs: Column*): DataFrame = err()
-
override def limit(n: Int): DataFrame = err()
override def unionAll(other: DataFrame): DataFrame = err()
@@ -106,8 +100,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
- override def sample(withReplacement: Boolean, fraction: Double): DataFrame = err()
-
/////////////////////////////////////////////////////////////////////////////
override def addColumn(colName: String, col: Column): DataFrame = err()