author     Reynold Xin <rxin@databricks.com>   2015-02-02 19:01:47 -0800
committer  Reynold Xin <rxin@databricks.com>   2015-02-02 19:01:47 -0800
commit     554403fd913685da879cf6a280c58a9fad19448a (patch)
tree       b3a63382e7385fa1480b54707b348b0bde02190d
parent     eccb9fbb2d1bf6f7c65fb4f017e9205bb3034ec6 (diff)
[SQL] Improve DataFrame API error reporting
1. Throw UnsupportedOperationException if a Column is not computable.
2. Perform eager analysis on DataFrame so we can catch errors when they happen (not when an action is run).

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4296 from rxin/col-computability and squashes the following commits:

6527b86 [Reynold Xin] Merge pull request #8 from davies/col-computability
fd92bc7 [Reynold Xin] Merge branch 'master' into col-computability
f79034c [Davies Liu] fix python tests
5afe1ff [Reynold Xin] Fix scala test.
17f6bae [Reynold Xin] Various fixes.
b932e86 [Reynold Xin] Added eager analysis for error reporting.
e6f00b8 [Reynold Xin] [SQL][API] ComputableColumn vs IncomputableColumn
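A minimal, hypothetical Scala sketch (not part of the patch) of the behaviour described above, assuming a SQLContext named sqlContext and a registered table "people" with columns name and age:

import org.apache.spark.sql.Dsl._

val people = sqlContext.table("people")

// With eager analysis, building the query already triggers resolution
// (DataFrameImpl's constructor forces qe.analyzed), so a misspelled column
// fails on this line instead of later, when collect() or another action runs.
val bad = people.select(col("naem"))   // typo for "name": error reported here

// A Column created without an attached plan (e.g. via Dsl.col on its own) is
// an IncomputableColumn; operations that need to actually compute it throw
// UnsupportedOperationException.
val standalone = col("age")
standalone.isComputable                // false: no SQLContext/plan attached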
-rw-r--r--  python/pyspark/sql.py                                                          |  75
-rw-r--r--  python/pyspark/tests.py                                                        |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala         |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                      | 241
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala            |  33
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                   | 292
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala               | 331
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala                         |  21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala            |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala          | 160
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                  |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala          |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala                 |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala         |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala       |  39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala              |  13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala               |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala  |   6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala            |   3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala         |  13
20 files changed, 896 insertions(+), 381 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3f2d7ac825..32bff0c7e8 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2124,6 +2124,10 @@ class DataFrame(object):
return rs[0] if rs else None
return self.take(n)
+ def first(self):
+ """ Return the first row. """
+ return self.head()
+
def tail(self):
raise NotImplemented
@@ -2159,7 +2163,7 @@ class DataFrame(object):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
- jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+ jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)
def filter(self, condition):
@@ -2189,7 +2193,7 @@ class DataFrame(object):
else:
cols = [c._jc for c in cols]
jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
- jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+ jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)
def agg(self, *exprs):
@@ -2278,14 +2282,17 @@ class GroupedDataFrame(object):
:param exprs: list or aggregate columns or a map from column
name to agregate methods.
"""
+ assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
jmap = MapConverter().convert(exprs[0],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(jmap)
else:
# Columns
- assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
- jdf = self._jdf.agg(*exprs)
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+ jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)
@dfapi
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
- return sc._jvm.Column(name)
+ return sc._jvm.IncomputableColumn(name)
def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _unary_op(name):
return _
-def _bin_op(name, pass_literal_through=False):
+def _bin_op(name, pass_literal_through=True):
""" Create a method for given binary operator
Keyword arguments:
@@ -2465,10 +2472,10 @@ class Column(DataFrame):
# __getattr__ = _bin_op("getField")
# string methods
- rlike = _bin_op("rlike", pass_literal_through=True)
- like = _bin_op("like", pass_literal_through=True)
- startswith = _bin_op("startsWith", pass_literal_through=True)
- endswith = _bin_op("endsWith", pass_literal_through=True)
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
upper = _unary_op("upper")
lower = _unary_op("lower")
@@ -2476,7 +2483,6 @@ class Column(DataFrame):
if type(startPos) != type(pos):
raise TypeError("Can not mix the type")
if isinstance(startPos, (int, long)):
-
jc = self._jc.substr(startPos, pos)
elif isinstance(startPos, Column):
jc = self._jc.substr(startPos._jc, pos._jc)
@@ -2507,16 +2513,21 @@ class Column(DataFrame):
return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+def _to_java_column(col):
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ return jcol
+
+
def _aggregate_func(name):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
- if isinstance(col, Column):
- jcol = col._jc
- else:
- jcol = _create_column_from_name(col)
- jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
+ jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
return Column(jc)
+
return staticmethod(_)
@@ -2524,13 +2535,31 @@ class Aggregator(object):
"""
A collections of builtin aggregators
"""
- max = _aggregate_func("max")
- min = _aggregate_func("min")
- avg = mean = _aggregate_func("mean")
- sum = _aggregate_func("sum")
- first = _aggregate_func("first")
- last = _aggregate_func("last")
- count = _aggregate_func("count")
+ AGGS = [
+ 'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
+ 'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
+ ]
+ for _name in AGGS:
+ locals()[_name] = _aggregate_func(_name)
+ del _name
+
+ @staticmethod
+ def countDistinct(col, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+ sc._jvm.Dsl.toColumns(jcols))
+ return Column(jc)
+
+ @staticmethod
+ def approxCountDistinct(col, rsd=None):
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
def _test():
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bec1961f26..fef6c92875 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1029,9 +1029,11 @@ class SQLTests(ReusedPySparkTestCase):
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
- # TODO(davies): fix aggregators
+
from pyspark.sql import Aggregator as Agg
- # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+ self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
def test_help_command(self):
# Regression test for SPARK-5464
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index 6ab99aa388..defdcb2b70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -822,7 +822,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* have a name matching the given name, `null` will be returned.
*/
def apply(name: String): StructField = {
- nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field $name does not exist."))
+ nameToField.getOrElse(name,
+ throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 174c403059..6f48d7c3fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,23 +17,26 @@
package org.apache.spark.sql
+import scala.annotation.tailrec
import scala.language.implicitConversions
import org.apache.spark.sql.Dsl.lit
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedStar, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
import org.apache.spark.sql.types._
-object Column {
- /**
- * Creates a [[Column]] based on the given column name. Same as [[Dsl.col]].
- */
- def apply(colName: String): Column = new Column(colName)
+private[sql] object Column {
+
+ def apply(colName: String): Column = new IncomputableColumn(colName)
+
+ def apply(expr: Expression): Column = new IncomputableColumn(expr)
+
+ def apply(sqlContext: SQLContext, plan: LogicalPlan, expr: Expression): Column = {
+ new ComputableColumn(sqlContext, plan, expr)
+ }
- /** For internal pattern matching. */
- private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
+ def unapply(col: Column): Option[Expression] = Some(col.expr)
}
@@ -53,44 +56,42 @@ object Column {
*
*/
// TODO: Improve documentation.
-class Column(
- sqlContext: Option[SQLContext],
- plan: Option[LogicalPlan],
- protected[sql] val expr: Expression)
- extends DataFrame(sqlContext, plan) with ExpressionApi {
+trait Column extends DataFrame with ExpressionApi {
- /** Turns a Catalyst expression into a `Column`. */
- protected[sql] def this(expr: Expression) = this(None, None, expr)
+ protected[sql] def expr: Expression
/**
- * Creates a new `Column` expression based on a column or attribute name.
- * The resolution of this is the same as SQL. For example:
- *
- * - "colName" becomes an expression selecting the column named "colName".
- * - "*" becomes an expression selecting all columns.
- * - "df.*" becomes an expression selecting all columns in data frame "df".
+ * Returns true iff the [[Column]] is computable.
*/
- def this(name: String) = this(name match {
- case "*" => UnresolvedStar(None)
- case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
- case _ => UnresolvedAttribute(name)
- })
+ def isComputable: Boolean
- override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined
+ private def constructColumn(other: Column)(newExpr: Expression): Column = {
+ // Removes all the top level projection and subquery so we can get to the underlying plan.
+ @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
+ case Project(_, child) => stripProject(child)
+ case Subquery(_, child) => stripProject(child)
+ case _ => p
+ }
- /**
- * An implicit conversion function internal to this class. This function creates a new Column
- * based on an expression. If the expression itself is not named, it aliases the expression
- * by calling it "col".
- */
- private[this] implicit def toColumn(expr: Expression): Column = {
- val projectedPlan = plan.map { p =>
- Project(Seq(expr match {
+ def computableCol(baseCol: ComputableColumn, expr: Expression) = {
+ val plan = Project(Seq(expr match {
case named: NamedExpression => named
case unnamed: Expression => Alias(unnamed, "col")()
- }), p)
+ }), baseCol.plan)
+ Column(baseCol.sqlContext, plan, expr)
+ }
+
+ (this, other) match {
+ case (left: ComputableColumn, right: ComputableColumn) =>
+ if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
+ computableCol(right, newExpr)
+ } else {
+ Column(newExpr)
+ }
+ case (left: ComputableColumn, _) => computableCol(left, newExpr)
+ case (_, right: ComputableColumn) => computableCol(right, newExpr)
+ case (_, _) => Column(newExpr)
}
- new Column(sqlContext, projectedPlan, expr)
}
/**
@@ -100,7 +101,7 @@ class Column(
* df.select( -df("amount") )
* }}}
*/
- override def unary_- : Column = UnaryMinus(expr)
+ override def unary_- : Column = constructColumn(null) { UnaryMinus(expr) }
/**
* Bitwise NOT.
@@ -109,7 +110,7 @@ class Column(
* df.select( ~df("flags") )
* }}}
*/
- override def unary_~ : Column = BitwiseNot(expr)
+ override def unary_~ : Column = constructColumn(null) { BitwiseNot(expr) }
/**
* Inversion of boolean expression, i.e. NOT.
@@ -118,7 +119,7 @@ class Column(
* df.select( !df("isActive") )
* }}
*/
- override def unary_! : Column = Not(expr)
+ override def unary_! : Column = constructColumn(null) { Not(expr) }
/**
@@ -129,7 +130,9 @@ class Column(
* df.select( df("colA".equalTo(df("colB")) )
* }}}
*/
- override def === (other: Column): Column = EqualTo(expr, other.expr)
+ override def === (other: Column): Column = constructColumn(other) {
+ EqualTo(expr, other.expr)
+ }
/**
* Equality test with a literal value.
@@ -169,7 +172,9 @@ class Column(
* df.select( !(df("colA") === df("colB")) )
* }}}
*/
- override def !== (other: Column): Column = Not(EqualTo(expr, other.expr))
+ override def !== (other: Column): Column = constructColumn(other) {
+ Not(EqualTo(expr, other.expr))
+ }
/**
* Inequality test with a literal value.
@@ -188,7 +193,9 @@ class Column(
* people.select( people("age") > Literal(21) )
* }}}
*/
- override def > (other: Column): Column = GreaterThan(expr, other.expr)
+ override def > (other: Column): Column = constructColumn(other) {
+ GreaterThan(expr, other.expr)
+ }
/**
* Greater than a literal value.
@@ -206,7 +213,9 @@ class Column(
* people.select( people("age") < Literal(21) )
* }}}
*/
- override def < (other: Column): Column = LessThan(expr, other.expr)
+ override def < (other: Column): Column = constructColumn(other) {
+ LessThan(expr, other.expr)
+ }
/**
* Less than a literal value.
@@ -224,7 +233,9 @@ class Column(
* people.select( people("age") <= Literal(21) )
* }}}
*/
- override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr)
+ override def <= (other: Column): Column = constructColumn(other) {
+ LessThanOrEqual(expr, other.expr)
+ }
/**
* Less than or equal to a literal value.
@@ -242,7 +253,9 @@ class Column(
* people.select( people("age") >= Literal(21) )
* }}}
*/
- override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr)
+ override def >= (other: Column): Column = constructColumn(other) {
+ GreaterThanOrEqual(expr, other.expr)
+ }
/**
* Greater than or equal to a literal value.
@@ -256,9 +269,11 @@ class Column(
/**
* Equality test with an expression that is safe for null values.
*/
- override def <=> (other: Column): Column = other match {
- case null => EqualNullSafe(expr, lit(null).expr)
- case _ => EqualNullSafe(expr, other.expr)
+ override def <=> (other: Column): Column = constructColumn(other) {
+ other match {
+ case null => EqualNullSafe(expr, lit(null).expr)
+ case _ => EqualNullSafe(expr, other.expr)
+ }
}
/**
@@ -269,12 +284,12 @@ class Column(
/**
* True if the current expression is null.
*/
- override def isNull: Column = IsNull(expr)
+ override def isNull: Column = constructColumn(null) { IsNull(expr) }
/**
* True if the current expression is NOT null.
*/
- override def isNotNull: Column = IsNotNull(expr)
+ override def isNotNull: Column = constructColumn(null) { IsNotNull(expr) }
/**
* Boolean OR with an expression.
@@ -283,7 +298,9 @@ class Column(
* people.select( people("inSchool") || people("isEmployed") )
* }}}
*/
- override def || (other: Column): Column = Or(expr, other.expr)
+ override def || (other: Column): Column = constructColumn(other) {
+ Or(expr, other.expr)
+ }
/**
* Boolean OR with a literal value.
@@ -301,7 +318,9 @@ class Column(
* people.select( people("inSchool") && people("isEmployed") )
* }}}
*/
- override def && (other: Column): Column = And(expr, other.expr)
+ override def && (other: Column): Column = constructColumn(other) {
+ And(expr, other.expr)
+ }
/**
* Boolean AND with a literal value.
@@ -315,7 +334,9 @@ class Column(
/**
* Bitwise AND with an expression.
*/
- override def & (other: Column): Column = BitwiseAnd(expr, other.expr)
+ override def & (other: Column): Column = constructColumn(other) {
+ BitwiseAnd(expr, other.expr)
+ }
/**
* Bitwise AND with a literal value.
@@ -325,7 +346,9 @@ class Column(
/**
* Bitwise OR with an expression.
*/
- override def | (other: Column): Column = BitwiseOr(expr, other.expr)
+ override def | (other: Column): Column = constructColumn(other) {
+ BitwiseOr(expr, other.expr)
+ }
/**
* Bitwise OR with a literal value.
@@ -335,7 +358,9 @@ class Column(
/**
* Bitwise XOR with an expression.
*/
- override def ^ (other: Column): Column = BitwiseXor(expr, other.expr)
+ override def ^ (other: Column): Column = constructColumn(other) {
+ BitwiseXor(expr, other.expr)
+ }
/**
* Bitwise XOR with a literal value.
@@ -349,7 +374,9 @@ class Column(
* people.select( people("height") + people("weight") )
* }}}
*/
- override def + (other: Column): Column = Add(expr, other.expr)
+ override def + (other: Column): Column = constructColumn(other) {
+ Add(expr, other.expr)
+ }
/**
* Sum of this expression and another expression.
@@ -367,7 +394,9 @@ class Column(
* people.select( people("height") - people("weight") )
* }}}
*/
- override def - (other: Column): Column = Subtract(expr, other.expr)
+ override def - (other: Column): Column = constructColumn(other) {
+ Subtract(expr, other.expr)
+ }
/**
* Subtraction. Subtract a literal value from this expression.
@@ -385,7 +414,9 @@ class Column(
* people.select( people("height") * people("weight") )
* }}}
*/
- override def * (other: Column): Column = Multiply(expr, other.expr)
+ override def * (other: Column): Column = constructColumn(other) {
+ Multiply(expr, other.expr)
+ }
/**
* Multiplication this expression and a literal value.
@@ -403,7 +434,9 @@ class Column(
* people.select( people("height") / people("weight") )
* }}}
*/
- override def / (other: Column): Column = Divide(expr, other.expr)
+ override def / (other: Column): Column = constructColumn(other) {
+ Divide(expr, other.expr)
+ }
/**
* Division this expression by a literal value.
@@ -417,7 +450,9 @@ class Column(
/**
* Modulo (a.k.a. remainder) expression.
*/
- override def % (other: Column): Column = Remainder(expr, other.expr)
+ override def % (other: Column): Column = constructColumn(other) {
+ Remainder(expr, other.expr)
+ }
/**
* Modulo (a.k.a. remainder) expression.
@@ -430,29 +465,40 @@ class Column(
* by the evaluated values of the arguments.
*/
@scala.annotation.varargs
- override def in(list: Column*): Column = In(expr, list.map(_.expr))
+ override def in(list: Column*): Column = {
+ new IncomputableColumn(In(expr, list.map(_.expr)))
+ }
- override def like(literal: String): Column = Like(expr, lit(literal).expr)
+ override def like(literal: String): Column = constructColumn(null) {
+ Like(expr, lit(literal).expr)
+ }
- override def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
+ override def rlike(literal: String): Column = constructColumn(null) {
+ RLike(expr, lit(literal).expr)
+ }
/**
* An expression that gets an item at position `ordinal` out of an array.
*/
- override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
+ override def getItem(ordinal: Int): Column = constructColumn(null) {
+ GetItem(expr, Literal(ordinal))
+ }
/**
* An expression that gets a field by name in a [[StructField]].
*/
- override def getField(fieldName: String): Column = GetField(expr, fieldName)
+ override def getField(fieldName: String): Column = constructColumn(null) {
+ GetField(expr, fieldName)
+ }
/**
* An expression that returns a substring.
* @param startPos expression for the starting position.
* @param len expression for the length of the substring.
*/
- override def substr(startPos: Column, len: Column): Column =
- Substring(expr, startPos.expr, len.expr)
+ override def substr(startPos: Column, len: Column): Column = {
+ new IncomputableColumn(Substring(expr, startPos.expr, len.expr))
+ }
/**
* An expression that returns a substring.
@@ -461,16 +507,21 @@ class Column(
*/
override def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len))
- override def contains(other: Column): Column = Contains(expr, other.expr)
+ override def contains(other: Column): Column = constructColumn(other) {
+ Contains(expr, other.expr)
+ }
override def contains(literal: Any): Column = this.contains(lit(literal))
-
- override def startsWith(other: Column): Column = StartsWith(expr, other.expr)
+ override def startsWith(other: Column): Column = constructColumn(other) {
+ StartsWith(expr, other.expr)
+ }
override def startsWith(literal: String): Column = this.startsWith(lit(literal))
- override def endsWith(other: Column): Column = EndsWith(expr, other.expr)
+ override def endsWith(other: Column): Column = constructColumn(other) {
+ EndsWith(expr, other.expr)
+ }
override def endsWith(literal: String): Column = this.endsWith(lit(literal))
@@ -481,7 +532,7 @@ class Column(
* df.select($"colA".as("colB"))
* }}}
*/
- override def as(alias: String): Column = Alias(expr, alias)()
+ override def as(alias: String): Column = constructColumn(null) { Alias(expr, alias)() }
/**
* Casts the column to a different data type.
@@ -494,7 +545,7 @@ class Column(
* df.select(df("colA").cast("int"))
* }}}
*/
- override def cast(to: DataType): Column = Cast(expr, to)
+ override def cast(to: DataType): Column = constructColumn(null) { Cast(expr, to) }
/**
* Casts the column to a different data type, using the canonical string representation
@@ -505,28 +556,30 @@ class Column(
* df.select(df("colA").cast("int"))
* }}}
*/
- override def cast(to: String): Column = Cast(expr, to.toLowerCase match {
- case "string" => StringType
- case "boolean" => BooleanType
- case "byte" => ByteType
- case "short" => ShortType
- case "int" => IntegerType
- case "long" => LongType
- case "float" => FloatType
- case "double" => DoubleType
- case "decimal" => DecimalType.Unlimited
- case "date" => DateType
- case "timestamp" => TimestampType
- case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
- })
-
- override def desc: Column = SortOrder(expr, Descending)
-
- override def asc: Column = SortOrder(expr, Ascending)
+ override def cast(to: String): Column = constructColumn(null) {
+ Cast(expr, to.toLowerCase match {
+ case "string" => StringType
+ case "boolean" => BooleanType
+ case "byte" => ByteType
+ case "short" => ShortType
+ case "int" => IntegerType
+ case "long" => LongType
+ case "float" => FloatType
+ case "double" => DoubleType
+ case "decimal" => DecimalType.Unlimited
+ case "date" => DateType
+ case "timestamp" => TimestampType
+ case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+ })
+ }
+
+ override def desc: Column = constructColumn(null) { SortOrder(expr, Descending) }
+
+ override def asc: Column = constructColumn(null) { SortOrder(expr, Ascending) }
}
-class ColumnName(name: String) extends Column(name) {
+class ColumnName(name: String) extends IncomputableColumn(name) {
/** Creates a new AttributeReference of type boolean */
def boolean: StructField = StructField(name, BooleanType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala
new file mode 100644
index 0000000000..ac479b26a7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala
@@ -0,0 +1,33 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+
+private[sql] class ComputableColumn protected[sql](
+ sqlContext: SQLContext,
+ protected[sql] val plan: LogicalPlan,
+ protected[sql] val expr: Expression)
+ extends DataFrameImpl(sqlContext, plan) with Column {
+
+ override def isComputable: Boolean = true
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 5d42d4428d..385e1ec74f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,26 +19,21 @@ package org.apache.spark.sql
import java.util.{List => JList}
-import scala.language.implicitConversions
import scala.reflect.ClassTag
-import scala.collection.JavaConversions._
-import com.fasterxml.jackson.core.JsonFactory
-
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
-import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.StructType
+
+
+private[sql] object DataFrame {
+ def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
+ new DataFrameImpl(sqlContext, logicalPlan)
+ }
+}
/**
@@ -78,50 +73,14 @@ import org.apache.spark.util.Utils
* }}}
*/
// TODO: Improve documentation.
-class DataFrame protected[sql](
- val sqlContext: SQLContext,
- private val baseLogicalPlan: LogicalPlan,
- operatorsEnabled: Boolean)
- extends DataFrameSpecificApi with RDDApi[Row] {
-
- protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) =
- this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined)
-
- protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true)
-
- @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan)
-
- @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match {
- // For various commands (like DDL) and queries with side effects, we force query optimization to
- // happen right away to let these side effects take place eagerly.
- case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
- case _ =>
- baseLogicalPlan
- }
+trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] {
- /**
- * An implicit conversion function internal to this class for us to avoid doing
- * "new DataFrame(...)" everywhere.
- */
- private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
- new DataFrame(sqlContext, logicalPlan, true)
- }
+ val sqlContext: SQLContext
- /** Returns the list of numeric columns, useful for doing aggregation. */
- protected[sql] def numericColumns: Seq[Expression] = {
- schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
- queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
- }
- }
+ @DeveloperApi
+ def queryExecution: SQLContext#QueryExecution
- /** Resolves a column name into a Catalyst [[NamedExpression]]. */
- protected[sql] def resolve(colName: String): NamedExpression = {
- queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
- throw new RuntimeException(
- s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
- }
- }
+ protected[sql] def logicalPlan: LogicalPlan
/** Left here for compatibility reasons. */
@deprecated("1.3.0", "use toDataFrame")
@@ -142,32 +101,19 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- def toDataFrame(colName: String, colNames: String*): DataFrame = {
- val newNames = colName +: colNames
- require(schema.size == newNames.size,
- "The number of columns doesn't match.\n" +
- "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
- "New column names: " + newNames.mkString(", "))
-
- val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
- apply(oldName).as(newName)
- }
- select(newCols :_*)
- }
+ def toDataFrame(colName: String, colNames: String*): DataFrame
/** Returns the schema of this [[DataFrame]]. */
- override def schema: StructType = queryExecution.analyzed.schema
+ override def schema: StructType
/** Returns all column names and their data types as an array. */
- override def dtypes: Array[(String, String)] = schema.fields.map { field =>
- (field.name, field.dataType.toString)
- }
+ override def dtypes: Array[(String, String)]
/** Returns all column names as an array. */
override def columns: Array[String] = schema.fields.map(_.name)
/** Prints the schema to the console in a nice tree format. */
- override def printSchema(): Unit = println(schema.treeString)
+ override def printSchema(): Unit
/**
* Cartesian join with another [[DataFrame]].
@@ -176,9 +122,7 @@ class DataFrame protected[sql](
*
* @param right Right side of the join operation.
*/
- override def join(right: DataFrame): DataFrame = {
- Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
- }
+ override def join(right: DataFrame): DataFrame
/**
* Inner join with another [[DataFrame]], using the given join expression.
@@ -189,9 +133,7 @@ class DataFrame protected[sql](
* df1.join(df2).where($"df1Key" === $"df2Key")
* }}}
*/
- override def join(right: DataFrame, joinExprs: Column): DataFrame = {
- Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr))
- }
+ override def join(right: DataFrame, joinExprs: Column): DataFrame
/**
* Join with another [[DataFrame]], using the given join expression. The following performs
@@ -205,9 +147,7 @@ class DataFrame protected[sql](
* @param joinExprs Join expression.
* @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
*/
- override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
- Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
- }
+ override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame
/**
* Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
@@ -219,9 +159,7 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- override def sort(sortCol: String, sortCols: String*): DataFrame = {
- orderBy(apply(sortCol), sortCols.map(apply) :_*)
- }
+ override def sort(sortCol: String, sortCols: String*): DataFrame
/**
* Returns a new [[DataFrame]] sorted by the given expressions. For example:
@@ -230,46 +168,26 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
- val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
- col.expr match {
- case expr: SortOrder =>
- expr
- case expr: Expression =>
- SortOrder(expr, Ascending)
- }
- }
- Sort(sortOrder, global = true, logicalPlan)
- }
+ override def sort(sortExpr: Column, sortExprs: Column*): DataFrame
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
* This is an alias of the `sort` function.
*/
@scala.annotation.varargs
- override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
- sort(sortCol, sortCols :_*)
- }
+ override def orderBy(sortCol: String, sortCols: String*): DataFrame
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
* This is an alias of the `sort` function.
*/
@scala.annotation.varargs
- override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
- sort(sortExpr, sortExprs :_*)
- }
+ override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
/**
* Selects column based on the column name and return it as a [[Column]].
*/
- override def apply(colName: String): Column = colName match {
- case "*" =>
- new Column(ResolvedStar(schema.fieldNames.map(resolve)))
- case _ =>
- val expr = resolve(colName)
- new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr)
- }
+ override def apply(colName: String): Column
/**
* Selects a set of expressions, wrapped in a Product.
@@ -279,18 +197,12 @@ class DataFrame protected[sql](
* df.select($"colA", $"colB" + 1)
* }}}
*/
- override def apply(projection: Product): DataFrame = {
- require(projection.productArity >= 1)
- select(projection.productIterator.map {
- case c: Column => c
- case o: Any => new Column(Some(sqlContext), None, Literal(o))
- }.toSeq :_*)
- }
+ override def apply(projection: Product): DataFrame
/**
* Returns a new [[DataFrame]] with an alias set.
*/
- override def as(name: String): DataFrame = Subquery(name, logicalPlan)
+ override def as(name: String): DataFrame
/**
* Selects a set of expressions.
@@ -299,15 +211,7 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- override def select(cols: Column*): DataFrame = {
- val exprs = cols.zipWithIndex.map {
- case (Column(expr: NamedExpression), _) =>
- expr
- case (Column(expr: Expression), _) =>
- Alias(expr, expr.toString)()
- }
- Project(exprs.toSeq, logicalPlan)
- }
+ override def select(cols: Column*): DataFrame
/**
* Selects a set of columns. This is a variant of `select` that can only select
@@ -320,9 +224,7 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- override def select(col: String, cols: String*): DataFrame = {
- select((col +: cols).map(new Column(_)) :_*)
- }
+ override def select(col: String, cols: String*): DataFrame
/**
* Filters rows using the given condition.
@@ -333,9 +235,7 @@ class DataFrame protected[sql](
* peopleDf($"age" > 15)
* }}}
*/
- override def filter(condition: Column): DataFrame = {
- Filter(condition.expr, logicalPlan)
- }
+ override def filter(condition: Column): DataFrame
/**
* Filters rows using the given condition. This is an alias for `filter`.
@@ -346,7 +246,7 @@ class DataFrame protected[sql](
* peopleDf($"age" > 15)
* }}}
*/
- override def where(condition: Column): DataFrame = filter(condition)
+ override def where(condition: Column): DataFrame
/**
* Filters rows using the given condition. This is a shorthand meant for Scala.
@@ -357,7 +257,7 @@ class DataFrame protected[sql](
* peopleDf($"age" > 15)
* }}}
*/
- override def apply(condition: Column): DataFrame = filter(condition)
+ override def apply(condition: Column): DataFrame
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -375,9 +275,7 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- override def groupBy(cols: Column*): GroupedDataFrame = {
- new GroupedDataFrame(this, cols.map(_.expr))
- }
+ override def groupBy(cols: Column*): GroupedDataFrame
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -398,10 +296,7 @@ class DataFrame protected[sql](
* }}}
*/
@scala.annotation.varargs
- override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
- val colNames: Seq[String] = col1 +: cols
- new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
- }
+ override def groupBy(col1: String, cols: String*): GroupedDataFrame
/**
* Aggregates on the entire [[DataFrame]] without groups.
@@ -411,7 +306,7 @@ class DataFrame protected[sql](
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}
*/
- override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
+ override def agg(exprs: Map[String, String]): DataFrame
/**
* Aggregates on the entire [[DataFrame]] without groups.
@@ -421,7 +316,7 @@ class DataFrame protected[sql](
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}
*/
- override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap)
+ override def agg(exprs: java.util.Map[String, String]): DataFrame
/**
* Aggregates on the entire [[DataFrame]] without groups.
@@ -432,31 +327,31 @@ class DataFrame protected[sql](
* }}
*/
@scala.annotation.varargs
- override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
+ override def agg(expr: Column, exprs: Column*): DataFrame
/**
* Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function
* and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]].
*/
- override def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan)
+ override def limit(n: Int): DataFrame
/**
* Returns a new [[DataFrame]] containing union of rows in this frame and another frame.
* This is equivalent to `UNION ALL` in SQL.
*/
- override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
+ override def unionAll(other: DataFrame): DataFrame
/**
* Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
* This is equivalent to `INTERSECT` in SQL.
*/
- override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
+ override def intersect(other: DataFrame): DataFrame
/**
* Returns a new [[DataFrame]] containing rows in this frame but not in another frame.
* This is equivalent to `EXCEPT` in SQL.
*/
- override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
+ override def except(other: DataFrame): DataFrame
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows.
@@ -465,9 +360,7 @@ class DataFrame protected[sql](
* @param fraction Fraction of rows to generate.
* @param seed Seed for sampling.
*/
- override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
- Sample(fraction, withReplacement, seed, logicalPlan)
- }
+ override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
@@ -475,105 +368,85 @@ class DataFrame protected[sql](
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
*/
- override def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
- sample(withReplacement, fraction, Utils.random.nextLong)
- }
+ override def sample(withReplacement: Boolean, fraction: Double): DataFrame
/////////////////////////////////////////////////////////////////////////////
/**
* Returns a new [[DataFrame]] by adding a column.
*/
- override def addColumn(colName: String, col: Column): DataFrame = {
- select(Column("*"), col.as(colName))
- }
+ override def addColumn(colName: String, col: Column): DataFrame
/**
* Returns the first `n` rows.
*/
- override def head(n: Int): Array[Row] = limit(n).collect()
+ override def head(n: Int): Array[Row]
/**
* Returns the first row.
*/
- override def head(): Row = head(1).head
+ override def head(): Row
/**
* Returns the first row. Alias for head().
*/
- override def first(): Row = head()
+ override def first(): Row
/**
* Returns a new RDD by applying a function to all rows of this DataFrame.
*/
- override def map[R: ClassTag](f: Row => R): RDD[R] = {
- rdd.map(f)
- }
+ override def map[R: ClassTag](f: Row => R): RDD[R]
/**
* Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
* and then flattening the results.
*/
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R]
/**
* Returns a new RDD by applying a function to each partition of this DataFrame.
*/
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
- rdd.mapPartitions(f)
- }
-
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R]
/**
* Applies a function `f` to all rows.
*/
- override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+ override def foreach(f: Row => Unit): Unit
/**
* Applies a function f to each partition of this [[DataFrame]].
*/
- override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit
/**
* Returns the first `n` rows in the [[DataFrame]].
*/
- override def take(n: Int): Array[Row] = head(n)
+ override def take(n: Int): Array[Row]
/**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
*/
- override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
+ override def collect(): Array[Row]
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
*/
- override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
+ override def collectAsList(): java.util.List[Row]
/**
* Returns the number of rows in the [[DataFrame]].
*/
- override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
+ override def count(): Long
/**
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
*/
- override def repartition(numPartitions: Int): DataFrame = {
- sqlContext.applySchema(rdd.repartition(numPartitions), schema)
- }
+ override def repartition(numPartitions: Int): DataFrame
- override def persist(): this.type = {
- sqlContext.cacheManager.cacheQuery(this)
- this
- }
+ override def persist(): this.type
- override def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheManager.cacheQuery(this, None, newLevel)
- this
- }
+ override def persist(newLevel: StorageLevel): this.type
- override def unpersist(blocking: Boolean): this.type = {
- sqlContext.cacheManager.tryUncacheQuery(this, blocking)
- this
- }
+ override def unpersist(blocking: Boolean): this.type
/////////////////////////////////////////////////////////////////////////////
// I/O
@@ -582,10 +455,7 @@ class DataFrame protected[sql](
/**
* Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
*/
- override def rdd: RDD[Row] = {
- val schema = this.schema
- queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
- }
+ override def rdd: RDD[Row]
/**
* Registers this RDD as a temporary table using the given name. The lifetime of this temporary
@@ -593,18 +463,14 @@ class DataFrame protected[sql](
*
* @group schema
*/
- override def registerTempTable(tableName: String): Unit = {
- sqlContext.registerRDDAsTable(this, tableName)
- }
+ override def registerTempTable(tableName: String): Unit
/**
* Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema.
* Files that are written out using this method can be read back in as a [[DataFrame]]
* using the `parquetFile` function in [[SQLContext]].
*/
- override def saveAsParquetFile(path: String): Unit = {
- sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
- }
+ override def saveAsParquetFile(path: String): Unit
/**
* :: Experimental ::
@@ -617,48 +483,26 @@ class DataFrame protected[sql](
* be the target of an `insertInto`.
*/
@Experimental
- override def saveAsTable(tableName: String): Unit = {
- sqlContext.executePlan(
- CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd
- }
+ override def saveAsTable(tableName: String): Unit
/**
* :: Experimental ::
* Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
*/
@Experimental
- override def insertInto(tableName: String, overwrite: Boolean): Unit = {
- sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
- Map.empty, logicalPlan, overwrite)).toRdd
- }
+ override def insertInto(tableName: String, overwrite: Boolean): Unit
/**
* Returns the content of the [[DataFrame]] as a RDD of JSON strings.
*/
- override def toJSON: RDD[String] = {
- val rowSchema = this.schema
- this.mapPartitions { iter =>
- val jsonFactory = new JsonFactory()
- iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
- }
- }
+ override def toJSON: RDD[String]
////////////////////////////////////////////////////////////////////////////
// for Python API
////////////////////////////////////////////////////////////////////////////
- /**
- * A helpful function for Py4j, convert a list of Column to an array
- */
- protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = {
- cols.toList.toArray
- }
/**
* Converts a JavaRDD to a PythonRDD.
*/
- protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
- }
+ protected[sql] def javaToPython: JavaRDD[Array[Byte]]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
new file mode 100644
index 0000000000..f8fcc25569
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -0,0 +1,331 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import java.util.{List => JList}
+
+import scala.language.implicitConversions
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
+import com.fasterxml.jackson.core.JsonFactory
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
+import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.util.Utils
+
+
+/**
+ * See [[DataFrame]] for documentation.
+ */
+private[sql] class DataFrameImpl protected[sql](
+ override val sqlContext: SQLContext,
+ val queryExecution: SQLContext#QueryExecution)
+ extends DataFrame {
+
+ def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
+ this(sqlContext, {
+ val qe = sqlContext.executePlan(logicalPlan)
+ qe.analyzed // This should force analysis and throw errors if there are any
+ qe
+ })
+ }
+
+ @transient protected[sql] override val logicalPlan: LogicalPlan = queryExecution.logical match {
+ // For various commands (like DDL) and queries with side effects, we force query optimization to
+ // happen right away to let these side effects take place eagerly.
+ case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case _ =>
+ queryExecution.logical
+ }
+
+ /**
+ * An implicit conversion function internal to this class for us to avoid doing
+ * "new DataFrameImpl(...)" everywhere.
+ */
+ @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
+ new DataFrameImpl(sqlContext, logicalPlan)
+ }
+
+ protected[sql] def resolve(colName: String): NamedExpression = {
+ queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
+ throw new RuntimeException(
+ s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+ }
+ }
+
+ protected[sql] def numericColumns: Seq[Expression] = {
+ schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
+ queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
+ }
+ }
+
+ override def toDataFrame(colName: String, colNames: String*): DataFrame = {
+ val newNames = colName +: colNames
+ require(schema.size == newNames.size,
+ "The number of columns doesn't match.\n" +
+ "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+ "New column names: " + newNames.mkString(", "))
+
+ val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+ apply(oldName).as(newName)
+ }
+ select(newCols :_*)
+ }
+
+ override def schema: StructType = queryExecution.analyzed.schema
+
+ override def dtypes: Array[(String, String)] = schema.fields.map { field =>
+ (field.name, field.dataType.toString)
+ }
+
+ override def columns: Array[String] = schema.fields.map(_.name)
+
+ override def printSchema(): Unit = println(schema.treeString)
+
+ override def join(right: DataFrame): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ }
+
+ override def join(right: DataFrame, joinExprs: Column): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr))
+ }
+
+ override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+ }
+
+ override def sort(sortCol: String, sortCols: String*): DataFrame = {
+ orderBy(apply(sortCol), sortCols.map(apply) :_*)
+ }
+
+ override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
+ val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ Sort(sortOrder, global = true, logicalPlan)
+ }
+
+ override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
+ sort(sortCol, sortCols :_*)
+ }
+
+ override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
+ sort(sortExpr, sortExprs :_*)
+ }
+
+ override def apply(colName: String): Column = colName match {
+ case "*" =>
+ Column(ResolvedStar(schema.fieldNames.map(resolve)))
+ case _ =>
+ val expr = resolve(colName)
+ Column(sqlContext, Project(Seq(expr), logicalPlan), expr)
+ }
+
+ override def apply(projection: Product): DataFrame = {
+ require(projection.productArity >= 1)
+ select(projection.productIterator.map {
+ case c: Column => c
+ case o: Any => Column(Literal(o))
+ }.toSeq :_*)
+ }
+
+ override def as(name: String): DataFrame = Subquery(name, logicalPlan)
+
+ override def select(cols: Column*): DataFrame = {
+ val exprs = cols.zipWithIndex.map {
+ case (Column(expr: NamedExpression), _) =>
+ expr
+ case (Column(expr: Expression), _) =>
+ Alias(expr, expr.toString)()
+ }
+ Project(exprs.toSeq, logicalPlan)
+ }
+
+ override def select(col: String, cols: String*): DataFrame = {
+ select((col +: cols).map(Column(_)) :_*)
+ }
+
+ override def filter(condition: Column): DataFrame = {
+ Filter(condition.expr, logicalPlan)
+ }
+
+ override def where(condition: Column): DataFrame = {
+ filter(condition)
+ }
+
+ override def apply(condition: Column): DataFrame = {
+ filter(condition)
+ }
+
+ override def groupBy(cols: Column*): GroupedDataFrame = {
+ new GroupedDataFrame(this, cols.map(_.expr))
+ }
+
+ override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
+ val colNames: Seq[String] = col1 +: cols
+ new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
+ }
+
+ override def agg(exprs: Map[String, String]): DataFrame = {
+ groupBy().agg(exprs)
+ }
+
+ override def agg(exprs: java.util.Map[String, String]): DataFrame = {
+ agg(exprs.toMap)
+ }
+
+ override def agg(expr: Column, exprs: Column*): DataFrame = {
+ groupBy().agg(expr, exprs :_*)
+ }
+
+ override def limit(n: Int): DataFrame = {
+ Limit(Literal(n), logicalPlan)
+ }
+
+ override def unionAll(other: DataFrame): DataFrame = {
+ Union(logicalPlan, other.logicalPlan)
+ }
+
+ override def intersect(other: DataFrame): DataFrame = {
+ Intersect(logicalPlan, other.logicalPlan)
+ }
+
+ override def except(other: DataFrame): DataFrame = {
+ Except(logicalPlan, other.logicalPlan)
+ }
+
+ override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+ Sample(fraction, withReplacement, seed, logicalPlan)
+ }
+
+ override def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ override def addColumn(colName: String, col: Column): DataFrame = {
+ select(Column("*"), col.as(colName))
+ }
+
+ override def head(n: Int): Array[Row] = limit(n).collect()
+
+ override def head(): Row = head(1).head
+
+ override def first(): Row = head()
+
+ override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
+
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+ rdd.mapPartitions(f)
+ }
+
+ override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+
+ override def take(n: Int): Array[Row] = head(n)
+
+ override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
+
+ override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
+
+ override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
+
+ override def repartition(numPartitions: Int): DataFrame = {
+ sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+ }
+
+ override def persist(): this.type = {
+ sqlContext.cacheManager.cacheQuery(this)
+ this
+ }
+
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+ this
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // I/O
+ /////////////////////////////////////////////////////////////////////////////
+
+ override def rdd: RDD[Row] = {
+ val schema = this.schema
+ queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
+ }
+
+ override def registerTempTable(tableName: String): Unit = {
+ sqlContext.registerRDDAsTable(this, tableName)
+ }
+
+ override def saveAsParquetFile(path: String): Unit = {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+
+ override def saveAsTable(tableName: String): Unit = {
+ sqlContext.executePlan(
+ CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd
+ }
+
+ override def insertInto(tableName: String, overwrite: Boolean): Unit = {
+ sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+ Map.empty, logicalPlan, overwrite)).toRdd
+ }
+
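+ // Serializes each Row to a JSON string, sharing one JsonFactory per partition.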
+ override def toJSON: RDD[String] = {
+ val rowSchema = this.schema
+ this.mapPartitions { iter =>
+ val jsonFactory = new JsonFactory()
+ iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
+ }
+ }
+
+ ////////////////////////////////////////////////////////////////////////////
+ // for Python API
+ ////////////////////////////////////////////////////////////////////////////
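+ // Converts each Row to an Array[Any] based on the schema's field types, then hands the result to SerDeUtil for the Python side.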
+ protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldTypes = schema.fields.map(_.dataType)
+ val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+ SerDeUtil.javaToPython(jrdd)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index 3499956023..b4279a32ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql
+import java.util.{List => JList}
+
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import scala.collection.JavaConversions._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -55,17 +58,17 @@ object Dsl {
}
}
- private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
+ private[this] implicit def toColumn(expr: Expression): Column = Column(expr)
/**
* Returns a [[Column]] based on the given column name.
*/
- def col(colName: String): Column = new Column(colName)
+ def col(colName: String): Column = Column(colName)
/**
* Returns a [[Column]] based on the given column name. Alias of [[col]].
*/
- def column(colName: String): Column = new Column(colName)
+ def column(colName: String): Column = Column(colName)
/**
* Creates a [[Column]] of literal value.
@@ -94,7 +97,7 @@ object Dsl {
case _ =>
throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
}
- new Column(literalExpr)
+ Column(literalExpr)
}
def sum(e: Column): Column = Sum(e.expr)
@@ -105,8 +108,7 @@ object Dsl {
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
- def approxCountDistinct(e: Column): Column =
- ApproxCountDistinct(e.expr)
+ def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
def approxCountDistinct(e: Column, rsd: Double): Column =
ApproxCountDistinct(e.expr, rsd)
@@ -121,6 +123,13 @@ object Dsl {
def sqrt(e: Column): Column = Sqrt(e.expr)
def abs(e: Column): Column = Abs(e.expr)
+ /**
+ * This is a private API for Python
+ * TODO: move this to a private package
+ */
+ def toColumns(cols: JList[Column]): Seq[Column] = {
+ cols.toList.toSeq
+ }
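+ // Note: toColumns relies on the JavaConversions import above to turn the incoming Java list into a Scala collection.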
// scalastyle:off
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 1c948cbbfe..d3acd41bbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.util.{List => JList}
+
import scala.language.implicitConversions
import scala.collection.JavaConversions._
@@ -28,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
/**
* A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
*/
-class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
+class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression])
extends GroupedDataFrameApi {
private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
@@ -36,8 +38,8 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
}
- new DataFrame(df.sqlContext,
- Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
+ DataFrame(
+ df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
}
private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
@@ -112,8 +114,7 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
}
-
- new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+ DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
new file mode 100644
index 0000000000..2f8c695d56
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.types.StructType
+
+
+private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column {
+
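+ // Builds a column from a bare name: "*" and "prefix.*" become unresolved stars, anything else an unresolved attribute.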
+ def this(name: String) = this(name match {
+ case "*" => UnresolvedStar(None)
+ case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
+ case _ => UnresolvedAttribute(name)
+ })
+
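+ // Every DataFrame operation below fails fast by delegating to err().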
+ private def err[T](): T = {
+ throw new UnsupportedOperationException("Cannot run this method on an IncomputableColumn")
+ }
+
+ override def isComputable: Boolean = false
+
+ override val sqlContext: SQLContext = null
+
+ override def queryExecution = err()
+
+ protected[sql] override def logicalPlan: LogicalPlan = err()
+
+ override def toDataFrame(colName: String, colNames: String*): DataFrame = err()
+
+ override def schema: StructType = err()
+
+ override def dtypes: Array[(String, String)] = err()
+
+ override def columns: Array[String] = err()
+
+ override def printSchema(): Unit = err()
+
+ override def join(right: DataFrame): DataFrame = err()
+
+ override def join(right: DataFrame, joinExprs: Column): DataFrame = err()
+
+ override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = err()
+
+ override def sort(sortCol: String, sortCols: String*): DataFrame = err()
+
+ override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+
+ override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()
+
+ override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+
+ override def apply(colName: String): Column = err()
+
+ override def apply(projection: Product): DataFrame = err()
+
+ override def select(cols: Column*): DataFrame = err()
+
+ override def select(col: String, cols: String*): DataFrame = err()
+
+ override def filter(condition: Column): DataFrame = err()
+
+ override def where(condition: Column): DataFrame = err()
+
+ override def apply(condition: Column): DataFrame = err()
+
+ override def groupBy(cols: Column*): GroupedDataFrame = err()
+
+ override def groupBy(col1: String, cols: String*): GroupedDataFrame = err()
+
+ override def agg(exprs: Map[String, String]): DataFrame = err()
+
+ override def agg(exprs: java.util.Map[String, String]): DataFrame = err()
+
+ override def agg(expr: Column, exprs: Column*): DataFrame = err()
+
+ override def limit(n: Int): DataFrame = err()
+
+ override def unionAll(other: DataFrame): DataFrame = err()
+
+ override def intersect(other: DataFrame): DataFrame = err()
+
+ override def except(other: DataFrame): DataFrame = err()
+
+ override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
+
+ override def sample(withReplacement: Boolean, fraction: Double): DataFrame = err()
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ override def addColumn(colName: String, col: Column): DataFrame = err()
+
+ override def head(n: Int): Array[Row] = err()
+
+ override def head(): Row = err()
+
+ override def first(): Row = err()
+
+ override def map[R: ClassTag](f: Row => R): RDD[R] = err()
+
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = err()
+
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = err()
+
+ override def foreach(f: Row => Unit): Unit = err()
+
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit = err()
+
+ override def take(n: Int): Array[Row] = err()
+
+ override def collect(): Array[Row] = err()
+
+ override def collectAsList(): java.util.List[Row] = err()
+
+ override def count(): Long = err()
+
+ override def repartition(numPartitions: Int): DataFrame = err()
+
+ override def persist(): this.type = err()
+
+ override def persist(newLevel: StorageLevel): this.type = err()
+
+ override def unpersist(blocking: Boolean): this.type = err()
+
+ override def rdd: RDD[Row] = err()
+
+ override def registerTempTable(tableName: String): Unit = err()
+
+ override def saveAsParquetFile(path: String): Unit = err()
+
+ override def saveAsTable(tableName: String): Unit = err()
+
+ override def insertInto(tableName: String, overwrite: Boolean): Unit = err()
+
+ override def toJSON: RDD[String] = err()
+
+ protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 84933dd944..d0bbb5f7a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -171,14 +171,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributeSeq = schema.toAttributes
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
- new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
+ DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self))
}
/**
* Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]].
*/
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
- new DataFrame(this, LogicalRelation(baseRelation))
+ DataFrame(this, LogicalRelation(baseRelation))
}
/**
@@ -216,7 +216,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
- new DataFrame(this, logicalPlan)
+ DataFrame(this, logicalPlan)
}
/**
@@ -243,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
) : Row
}
}
- new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
+ DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
}
/**
@@ -262,7 +262,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def parquetFile(path: String): DataFrame =
- new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
+ DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
/**
* Loads a JSON file (one object per line), returning the result as a [[DataFrame]].
@@ -365,7 +365,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
def sql(sqlText: String): DataFrame = {
if (conf.dialect == "sql") {
- new DataFrame(this, parseSql(sqlText))
+ DataFrame(this, parseSql(sqlText))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}")
}
@@ -373,7 +373,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/** Returns the specified table as a [[DataFrame]]. */
def table(tableName: String): DataFrame =
- new DataFrame(this, catalog.lookupRelation(Seq(tableName)))
+ DataFrame(this, catalog.lookupRelation(Seq(tableName)))
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
@@ -462,7 +462,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* access to the intermediate phases of query execution for developers.
*/
@DeveloperApi
- protected class QueryExecution(val logical: LogicalPlan) {
+ protected[sql] class QueryExecution(val logical: LogicalPlan) {
lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical))
lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
@@ -556,7 +556,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
iter.map { m => new GenericRow(m): Row}
}
- new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
+ DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 6fba76c521..e1c9a2be7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -138,7 +138,7 @@ case class CacheTableCommand(
override def run(sqlContext: SQLContext) = {
plan.foreach { logicalPlan =>
- sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName)
+ sqlContext.registerRDDAsTable(DataFrame(sqlContext, logicalPlan), tableName)
}
sqlContext.cacheTable(tableName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index b7c721f8c0..b1bbe0f89a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -231,7 +231,7 @@ private [sql] case class CreateTempTableUsing(
def run(sqlContext: SQLContext) = {
val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options)
sqlContext.registerRDDAsTable(
- new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
+ DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 906455dd40..4e1ec38bd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -41,7 +41,7 @@ object TestSQLContext
* construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
- new DataFrame(this, plan)
+ DataFrame(this, plan)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 2d464c2b53..fa4cdecbcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -27,6 +27,45 @@ class ColumnExpressionSuite extends QueryTest {
// TODO: Add test cases for bitwise operations.
+ test("computability check") {
+ def shouldBeComputable(c: Column): Unit = assert(c.isComputable === true)
+
+ def shouldNotBeComputable(c: Column): Unit = {
+ assert(c.isComputable === false)
+ intercept[UnsupportedOperationException] { c.head() }
+ }
+
+ shouldBeComputable(testData2("a"))
+ shouldBeComputable(testData2("b"))
+
+ shouldBeComputable(testData2("a") + testData2("b"))
+ shouldBeComputable(testData2("a") + testData2("b") + 1)
+
+ shouldBeComputable(-testData2("a"))
+ shouldBeComputable(!testData2("a"))
+
+ shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
+ shouldBeComputable(
+ testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
+ shouldBeComputable(
+ testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
+
+ // Literals and unresolved columns should not be computable.
+ shouldNotBeComputable(col("1"))
+ shouldNotBeComputable(col("1") + 2)
+ shouldNotBeComputable(lit(100))
+ shouldNotBeComputable(lit(100) + 10)
+ shouldNotBeComputable(-col("1"))
+ shouldNotBeComputable(!col("1"))
+
+ // Getting data from different frames should not be computable.
+ shouldNotBeComputable(testData2("a") + testData("key"))
+ shouldNotBeComputable(testData2("a") + 1 + testData("key"))
+
+ // Aggregate functions alone should not be computable.
+ shouldNotBeComputable(sum(testData2("a")))
+ }
+
test("star") {
checkAnswer(testData.select($"*"), testData.collect().toSeq)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index df343adc79..f6b65a81ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,6 +28,19 @@ import scala.language.postfixOps
class DataFrameSuite extends QueryTest {
import org.apache.spark.sql.TestData._
+ test("analysis error should be eagerly reported") {
+ intercept[Exception] { testData.select('nonExistentName) }
+ intercept[Exception] {
+ testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
+ }
+ intercept[Exception] {
+ testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
+ }
+ intercept[Exception] {
+ testData.groupBy($"abcd").agg(Map("key" -> "sum"))
+ }
+ }
+
test("table scan") {
checkAnswer(
testData,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d82c34316c..e18ba287e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -807,13 +807,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("throw errors for non-aggregate attributes with aggregation") {
def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
- val logicalPlan = sql(query).queryExecution.logical
-
if (isInvalidQuery) {
val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
assert(
e.getMessage.startsWith("Expression not in GROUP BY"),
- "Non-aggregate attribute(s) not detected\n" + logicalPlan)
+ "Non-aggregate attribute(s) not detected\n")
} else {
// Should not throw
sql(query).queryExecution.analyzed
@@ -821,7 +819,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
checkAggregation("SELECT key, COUNT(*) FROM testData")
- checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
+ checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", isInvalidQuery = false)
checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index e78145f4dd..ff91a0eb42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row}
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -51,8 +51,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
val query = rdd
- .select(output.map(e => new org.apache.spark.sql.Column(e)): _*)
- .where(new org.apache.spark.sql.Column(predicate))
+ .select(output.map(e => Column(e)): _*)
+ .where(Column(predicate))
val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect {
case plan: ParquetTableScan => plan.columnPruningPred
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index b746942cb1..5efc3b1e30 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -72,7 +72,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
if (conf.dialect == "sql") {
super.sql(substituted)
} else if (conf.dialect == "hiveql") {
- new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted)))
+ DataFrame(this,
+ ddlParser(sqlText, exceptionOnError = false).getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 83244ce1e3..fa997288a2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.hive
+import org.apache.spark.sql.catalyst.expressions.Row
+
import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
@@ -29,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.hive
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.sources.CreateTableUsing
@@ -56,14 +57,14 @@ private[hive] trait HiveStrategies {
@Experimental
object ParquetConversion extends Strategy {
implicit class LogicalPlanHacks(s: DataFrame) {
- def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan)
+ def lowerCase = DataFrame(s.sqlContext, s.logicalPlan)
def addPartitioningAttributes(attrs: Seq[Attribute]) = {
// Don't add the partitioning key if it's already present in the data.
if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) {
s
} else {
- new DataFrame(
+ DataFrame(
s.sqlContext,
s.logicalPlan transform {
case p: ParquetRelation => p.copy(partitioningAttributes = attrs)
@@ -96,13 +97,13 @@ private[hive] trait HiveStrategies {
// We are going to throw the predicates and projection back at the whole optimization
// sequence so let's unresolve all the attributes, allowing them to be rebound to the
// matching parquet attributes.
- val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform {
+ val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
}).reduceOption(And).getOrElse(Literal(true)))
val unresolvedProjection: Seq[Column] = projectList.map(_ transform {
case a: AttributeReference => UnresolvedAttribute(a.name)
- }).map(new Column(_))
+ }).map(Column(_))
try {
if (relation.hiveQlTable.isPartitioned) {