path: root/sql/core
author    Reynold Xin <rxin@databricks.com>    2015-02-19 12:09:44 -0800
committer Michael Armbrust <michael@databricks.com>    2015-02-19 12:09:44 -0800
commit    8ca3418e1b3e2687e75a08c185d17045a97279fb (patch)
tree      df3de8114cedc6d02a4e656734c865ce5c1e1cb7 /sql/core
parent    94cdb05ff7e6b8fc5b3a574202ba8bc8e5bbe689 (diff)
[SPARK-5904][SQL] DataFrame API fixes.
1. Column is no longer a DataFrame to simplify class hierarchy.
2. Don't use varargs on abstract methods (see Scala compiler bug SI-9013).

Author: Reynold Xin <rxin@databricks.com>

Closes #4686 from rxin/SPARK-5904 and squashes the following commits:

fd9b199 [Reynold Xin] Fixed Python tests.
df25cef [Reynold Xin] Non final.
5221530 [Reynold Xin] [SPARK-5904][SQL] DataFrame API fixes.
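An illustrative sketch (added by the editor, not part of the patch) of what the flattened hierarchy means for callers: a Column now simply wraps a Catalyst Expression, and comparison operators build new expressions via lit(), so columns from different DataFrames can be combined freely, e.g. as join conditions. The names `people` and `orders` are hypothetical.

import org.apache.spark.sql.{Column, DataFrame}

def joinAdults(people: DataFrame, orders: DataFrame): DataFrame = {
  // people("age") > 21 returns a plain Column wrapping GreaterThan(expr, lit(21).expr);
  // there is no longer a "computable" vs. "incomputable" distinction.
  val isAdult: Column = people("age") > 21

  // Columns from two different DataFrames combine into a join condition,
  // which previously fell back to an IncomputableColumn.
  val cond: Column = people("id") === orders("userId")

  people.filter(isAdult).join(orders, cond, "inner")
}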
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                 223
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala        33
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala              420
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala          483
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala              2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala     183
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala   44
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala           7
8 files changed, 407 insertions, 988 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8b6241c213..980754322e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,20 +22,15 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
import org.apache.spark.sql.types._
private[sql] object Column {
- def apply(colName: String): Column = new IncomputableColumn(colName)
+ def apply(colName: String): Column = new Column(colName)
- def apply(expr: Expression): Column = new IncomputableColumn(expr)
-
- def apply(sqlContext: SQLContext, plan: LogicalPlan, expr: Expression): Column = {
- new ComputableColumn(sqlContext, plan, expr)
- }
+ def apply(expr: Expression): Column = new Column(expr)
def unapply(col: Column): Option[Expression] = Some(col.expr)
}
@@ -51,68 +46,18 @@ private[sql] object Column {
* @groupname Ungrouped Support functions for DataFrames.
*/
@Experimental
-trait Column extends DataFrame {
-
- protected[sql] def expr: Expression
-
- /**
- * Returns true iff the [[Column]] is computable.
- */
- def isComputable: Boolean
-
- /** Removes the top project so we can get to the underlying plan. */
- private def stripProject(p: LogicalPlan): LogicalPlan = p match {
- case Project(_, child) => child
- case p => sys.error("Unexpected logical plan (expected Project): " + p)
- }
-
- private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
- val namedExpr = expr match {
- case named: NamedExpression => named
- case unnamed: Expression => Alias(unnamed, "col")()
- }
- val plan = Project(Seq(namedExpr), stripProject(baseCol.plan))
- Column(baseCol.sqlContext, plan, expr)
- }
+class Column(protected[sql] val expr: Expression) {
- /**
- * Construct a new column based on the expression and the other column value.
- *
- * There are two cases that can happen here:
- * If otherValue is a constant, it is first turned into a Column.
- * If otherValue is a Column, then:
- * - If this column and otherValue are both computable and come from the same logical plan,
- * then we can construct a ComputableColumn by applying a Project on top of the base plan.
- * - If this column is not computable, but otherValue is computable, then we can construct
- * a ComputableColumn based on otherValue's base plan.
- * - If this column is computable, but otherValue is not, then we can construct a
- * ComputableColumn based on this column's base plan.
- * - If neither columns are computable, then we create an IncomputableColumn.
- */
- private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
- // lit(otherValue) returns a Column always.
- (this, lit(otherValue)) match {
- case (left: ComputableColumn, right: ComputableColumn) =>
- if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
- computableCol(right, newExpr(right))
- } else {
- // We don't want to throw an exception here because "df1("a") === df2("b")" can be
- // a valid expression for join conditions, even though standalone they are not valid.
- Column(newExpr(right))
- }
- case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
- case (_, right: ComputableColumn) => computableCol(right, newExpr(right))
- case (_, right) => Column(newExpr(right))
- }
- }
+ def this(name: String) = this(name match {
+ case "*" => UnresolvedStar(None)
+ case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
+ case _ => UnresolvedAttribute(name)
+ })
/** Creates a column based on the given expression. */
- private def exprToColumn(newExpr: Expression, computable: Boolean = true): Column = {
- this match {
- case c: ComputableColumn if computable => computableCol(c, newExpr)
- case _ => Column(newExpr)
- }
- }
+ implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr)
+
+ override def toString: String = expr.prettyString
/**
* Unary minus, i.e. negate the expression.
@@ -127,7 +72,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def unary_- : Column = exprToColumn(UnaryMinus(expr))
+ def unary_- : Column = UnaryMinus(expr)
/**
* Inversion of boolean expression, i.e. NOT.
@@ -142,7 +87,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def unary_! : Column = exprToColumn(Not(expr))
+ def unary_! : Column = Not(expr)
/**
* Equality test.
@@ -157,9 +102,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def === (other: Any): Column = constructColumn(other) { o =>
- EqualTo(expr, o.expr)
- }
+ def === (other: Any): Column = EqualTo(expr, lit(other).expr)
/**
* Equality test.
@@ -190,9 +133,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def !== (other: Any): Column = constructColumn(other) { o =>
- Not(EqualTo(expr, o.expr))
- }
+ def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr))
/**
* Inequality test.
@@ -208,9 +149,7 @@ trait Column extends DataFrame {
*
* @group java_expr_ops
*/
- def notEqual(other: Any): Column = constructColumn(other) { o =>
- Not(EqualTo(expr, o.expr))
- }
+ def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr))
/**
* Greater than.
@@ -225,9 +164,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def > (other: Any): Column = constructColumn(other) { o =>
- GreaterThan(expr, o.expr)
- }
+ def > (other: Any): Column = GreaterThan(expr, lit(other).expr)
/**
* Greater than.
@@ -256,9 +193,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def < (other: Any): Column = constructColumn(other) { o =>
- LessThan(expr, o.expr)
- }
+ def < (other: Any): Column = LessThan(expr, lit(other).expr)
/**
* Less than.
@@ -286,9 +221,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def <= (other: Any): Column = constructColumn(other) { o =>
- LessThanOrEqual(expr, o.expr)
- }
+ def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr)
/**
* Less than or equal to.
@@ -316,9 +249,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def >= (other: Any): Column = constructColumn(other) { o =>
- GreaterThanOrEqual(expr, o.expr)
- }
+ def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr)
/**
* Greater than or equal to an expression.
@@ -339,9 +270,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def <=> (other: Any): Column = constructColumn(other) { o =>
- EqualNullSafe(expr, o.expr)
- }
+ def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr)
/**
* Equality test that is safe for null values.
@@ -355,14 +284,14 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def isNull: Column = exprToColumn(IsNull(expr))
+ def isNull: Column = IsNull(expr)
/**
* True if the current expression is NOT null.
*
* @group expr_ops
*/
- def isNotNull: Column = exprToColumn(IsNotNull(expr))
+ def isNotNull: Column = IsNotNull(expr)
/**
* Boolean OR.
@@ -376,9 +305,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def || (other: Any): Column = constructColumn(other) { o =>
- Or(expr, o.expr)
- }
+ def || (other: Any): Column = Or(expr, lit(other).expr)
/**
* Boolean OR.
@@ -406,9 +333,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def && (other: Any): Column = constructColumn(other) { o =>
- And(expr, o.expr)
- }
+ def && (other: Any): Column = And(expr, lit(other).expr)
/**
* Boolean AND.
@@ -436,9 +361,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def + (other: Any): Column = constructColumn(other) { o =>
- Add(expr, o.expr)
- }
+ def + (other: Any): Column = Add(expr, lit(other).expr)
/**
* Sum of this expression and another expression.
@@ -466,9 +389,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def - (other: Any): Column = constructColumn(other) { o =>
- Subtract(expr, o.expr)
- }
+ def - (other: Any): Column = Subtract(expr, lit(other).expr)
/**
* Subtraction. Subtract the other expression from this expression.
@@ -496,9 +417,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def * (other: Any): Column = constructColumn(other) { o =>
- Multiply(expr, o.expr)
- }
+ def * (other: Any): Column = Multiply(expr, lit(other).expr)
/**
* Multiplication of this expression and another expression.
@@ -526,9 +445,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def / (other: Any): Column = constructColumn(other) { o =>
- Divide(expr, o.expr)
- }
+ def / (other: Any): Column = Divide(expr, lit(other).expr)
/**
* Division this expression by another expression.
@@ -549,9 +466,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def % (other: Any): Column = constructColumn(other) { o =>
- Remainder(expr, o.expr)
- }
+ def % (other: Any): Column = Remainder(expr, lit(other).expr)
/**
* Modulo (a.k.a. remainder) expression.
@@ -567,37 +482,35 @@ trait Column extends DataFrame {
* @group expr_ops
*/
@scala.annotation.varargs
- def in(list: Column*): Column = {
- new IncomputableColumn(In(expr, list.map(_.expr)))
- }
+ def in(list: Column*): Column = In(expr, list.map(_.expr))
/**
* SQL like expression.
*
* @group expr_ops
*/
- def like(literal: String): Column = exprToColumn(Like(expr, lit(literal).expr))
+ def like(literal: String): Column = Like(expr, lit(literal).expr)
/**
* SQL RLIKE expression (LIKE with Regex).
*
* @group expr_ops
*/
- def rlike(literal: String): Column = exprToColumn(RLike(expr, lit(literal).expr))
+ def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
/**
* An expression that gets an item at position `ordinal` out of an array.
*
* @group expr_ops
*/
- def getItem(ordinal: Int): Column = exprToColumn(GetItem(expr, Literal(ordinal)))
+ def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
/**
* An expression that gets a field by name in a [[StructField]].
*
* @group expr_ops
*/
- def getField(fieldName: String): Column = exprToColumn(UnresolvedGetField(expr, fieldName))
+ def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
/**
* An expression that returns a substring.
@@ -606,8 +519,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def substr(startPos: Column, len: Column): Column =
- exprToColumn(Substring(expr, startPos.expr, len.expr), computable = false)
+ def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr)
/**
* An expression that returns a substring.
@@ -616,26 +528,21 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def substr(startPos: Int, len: Int): Column =
- exprToColumn(Substring(expr, lit(startPos).expr, lit(len).expr))
+ def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr)
/**
* Contains the other element.
*
* @group expr_ops
*/
- def contains(other: Any): Column = constructColumn(other) { o =>
- Contains(expr, o.expr)
- }
+ def contains(other: Any): Column = Contains(expr, lit(other).expr)
/**
* String starts with.
*
* @group expr_ops
*/
- def startsWith(other: Column): Column = constructColumn(other) { o =>
- StartsWith(expr, o.expr)
- }
+ def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr)
/**
* String starts with another string literal.
@@ -649,9 +556,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def endsWith(other: Column): Column = constructColumn(other) { o =>
- EndsWith(expr, o.expr)
- }
+ def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr)
/**
* String ends with another string literal.
@@ -669,7 +574,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- override def as(alias: String): Column = exprToColumn(Alias(expr, alias)())
+ def as(alias: String): Column = Alias(expr, alias)()
/**
* Gives the column an alias.
@@ -680,7 +585,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- override def as(alias: Symbol): Column = exprToColumn(Alias(expr, alias.name)())
+ def as(alias: Symbol): Column = Alias(expr, alias.name)()
/**
* Casts the column to a different data type.
@@ -695,7 +600,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def cast(to: DataType): Column = exprToColumn(Cast(expr, to))
+ def cast(to: DataType): Column = Cast(expr, to)
/**
* Casts the column to a different data type, using the canonical string representation
@@ -708,22 +613,20 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def cast(to: String): Column = exprToColumn(
- Cast(expr, to.toLowerCase match {
- case "string" => StringType
- case "boolean" => BooleanType
- case "byte" => ByteType
- case "short" => ShortType
- case "int" => IntegerType
- case "long" => LongType
- case "float" => FloatType
- case "double" => DoubleType
- case "decimal" => DecimalType.Unlimited
- case "date" => DateType
- case "timestamp" => TimestampType
- case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
- })
- )
+ def cast(to: String): Column = Cast(expr, to.toLowerCase match {
+ case "string" => StringType
+ case "boolean" => BooleanType
+ case "byte" => ByteType
+ case "short" => ShortType
+ case "int" => IntegerType
+ case "long" => LongType
+ case "float" => FloatType
+ case "double" => DoubleType
+ case "decimal" => DecimalType.Unlimited
+ case "date" => DateType
+ case "timestamp" => TimestampType
+ case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+ })
/**
* Returns an ordering used in sorting.
@@ -737,7 +640,7 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false)
+ def desc: Column = SortOrder(expr, Descending)
/**
* Returns an ordering used in sorting.
@@ -751,14 +654,14 @@ trait Column extends DataFrame {
*
* @group expr_ops
*/
- def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false)
+ def asc: Column = SortOrder(expr, Ascending)
/**
- * Prints the plans (logical and physical) to the console for debugging purpose.
+ * Prints the expression to the console for debugging purpose.
*
* @group df_ops
*/
- override def explain(extended: Boolean): Unit = {
+ def explain(extended: Boolean): Unit = {
if (extended) {
println(expr)
} else {
@@ -768,7 +671,7 @@ trait Column extends DataFrame {
}
-class ColumnName(name: String) extends IncomputableColumn(name) {
+class ColumnName(name: String) extends Column(name) {
/** Creates a new AttributeReference of type boolean */
def boolean: StructField = StructField(name, BooleanType)
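
A brief editor-added sketch (assuming the post-commit API shown above) of what the new single-argument Column constructor does with star syntax, and how ColumnName doubles as a schema-field builder; the column names used are hypothetical:

import org.apache.spark.sql.{Column, ColumnName}
import org.apache.spark.sql.types.StructField

// new Column(name) now parses star syntax itself:
val all      = new Column("*")        // becomes UnresolvedStar(None)
val tableAll = new Column("people.*") // becomes UnresolvedStar(Some("people"))
val ageCol   = new Column("age")      // becomes UnresolvedAttribute("age")

// ColumnName extends Column(name) and can also describe schema fields:
val field: StructField = new ColumnName("registered").boolean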
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala
deleted file mode 100644
index ac479b26a7..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import scala.language.implicitConversions
-
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
-
-private[sql] class ComputableColumn protected[sql](
- sqlContext: SQLContext,
- protected[sql] val plan: LogicalPlan,
- protected[sql] val expr: Expression)
- extends DataFrameImpl(sqlContext, plan) with Column {
-
- override def isComputable: Boolean = true
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 5007a5a34d..810f7c7747 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -17,26 +17,38 @@
package org.apache.spark.sql
+import java.io.CharArrayWriter
import java.sql.DriverManager
-
import scala.collection.JavaConversions._
+import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
+import com.fasterxml.jackson.core.JsonFactory
+
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
+
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
- new DataFrameImpl(sqlContext, logicalPlan)
+ new DataFrame(sqlContext, logicalPlan)
}
}
@@ -90,22 +102,100 @@ private[sql] object DataFrame {
*/
// TODO: Improve documentation.
@Experimental
-trait DataFrame extends RDDApi[Row] with Serializable {
+class DataFrame protected[sql](
+ @transient val sqlContext: SQLContext,
+ @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution)
+ extends RDDApi[Row] with Serializable {
+
+ /**
+ * A constructor that automatically analyzes the logical plan.
+ *
+ * This reports error eagerly as the [[DataFrame]] is constructed, unless
+ * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
+ */
+ def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
+ this(sqlContext, {
+ val qe = sqlContext.executePlan(logicalPlan)
+ if (sqlContext.conf.dataFrameEagerAnalysis) {
+ qe.analyzed // This should force analysis and throw errors if there are any
+ }
+ qe
+ })
+ }
+
+ @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
+ // For various commands (like DDL) and queries with side effects, we force query optimization to
+ // happen right away to let these side effects take place eagerly.
+ case _: Command |
+ _: InsertIntoTable |
+ _: CreateTableAsSelect[_] |
+ _: CreateTableUsingAsSelect |
+ _: WriteToFile =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case _ =>
+ queryExecution.logical
+ }
+
+ /**
+ * An implicit conversion function internal to this class for us to avoid doing
+ * "new DataFrame(...)" everywhere.
+ */
+ @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
+ new DataFrame(sqlContext, logicalPlan)
+ }
- val sqlContext: SQLContext
+ protected[sql] def resolve(colName: String): NamedExpression = {
+ queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
+ throw new AnalysisException(
+ s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+ }
+ }
- @DeveloperApi
- def queryExecution: SQLContext#QueryExecution
+ protected[sql] def numericColumns: Seq[Expression] = {
+ schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
+ queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
+ }
+ }
- protected[sql] def logicalPlan: LogicalPlan
+ /**
+ * Internal API for Python
+ */
+ private[sql] def showString(): String = {
+ val data = take(20)
+ val numCols = schema.fieldNames.length
- override def toString =
+ // For cells that are beyond 20 characters, replace it with the first 17 and "..."
+ val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ row.toSeq.map { cell =>
+ val str = if (cell == null) "null" else cell.toString
+ if (str.length > 20) str.substring(0, 17) + "..." else str
+ }: Seq[String]
+ }
+
+ // Compute the width of each column
+ val colWidths = Array.fill(numCols)(0)
+ for (row <- rows) {
+ for ((cell, i) <- row.zipWithIndex) {
+ colWidths(i) = math.max(colWidths(i), cell.length)
+ }
+ }
+
+ // Pad the cells
+ rows.map { row =>
+ row.zipWithIndex.map { case (cell, i) =>
+ String.format(s"%-${colWidths(i)}s", cell)
+ }.mkString(" ")
+ }.mkString("\n")
+ }
+
+ override def toString: String = {
try {
schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
} catch {
case NonFatal(e) =>
s"Invalid tree; ${e.getMessage}:\n$queryExecution"
}
+ }
/** Left here for backward compatibility. */
@deprecated("1.3.0", "use toDF")
@@ -130,19 +220,31 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group basic
*/
@scala.annotation.varargs
- def toDF(colNames: String*): DataFrame
+ def toDF(colNames: String*): DataFrame = {
+ require(schema.size == colNames.size,
+ "The number of columns doesn't match.\n" +
+ "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+ "New column names: " + colNames.mkString(", "))
+
+ val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
+ apply(oldName).as(newName)
+ }
+ select(newCols :_*)
+ }
/**
* Returns the schema of this [[DataFrame]].
* @group basic
*/
- def schema: StructType
+ def schema: StructType = queryExecution.analyzed.schema
/**
* Returns all column names and their data types as an array.
* @group basic
*/
- def dtypes: Array[(String, String)]
+ def dtypes: Array[(String, String)] = schema.fields.map { field =>
+ (field.name, field.dataType.toString)
+ }
/**
* Returns all column names as an array.
@@ -154,13 +256,19 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Prints the schema to the console in a nice tree format.
* @group basic
*/
- def printSchema(): Unit
+ def printSchema(): Unit = println(schema.treeString)
/**
* Prints the plans (logical and physical) to the console for debugging purpose.
* @group basic
*/
- def explain(extended: Boolean): Unit
+ def explain(extended: Boolean): Unit = {
+ ExplainCommand(
+ logicalPlan,
+ extended = extended).queryExecution.executedPlan.executeCollect().map {
+ r => println(r.getString(0))
+ }
+ }
/**
* Only prints the physical plan to the console for debugging purpose.
@@ -173,7 +281,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* (without any Spark executors).
* @group basic
*/
- def isLocal: Boolean
+ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
/**
* Displays the [[DataFrame]] in a tabular form. For example:
@@ -187,7 +295,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group basic
*/
- def show(): Unit
+ def show(): Unit = println(showString())
/**
* Cartesian join with another [[DataFrame]].
@@ -197,7 +305,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @param right Right side of the join operation.
* @group dfops
*/
- def join(right: DataFrame): DataFrame
+ def join(right: DataFrame): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ }
/**
* Inner join with another [[DataFrame]], using the given join expression.
@@ -209,7 +319,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group dfops
*/
- def join(right: DataFrame, joinExprs: Column): DataFrame
+ def join(right: DataFrame, joinExprs: Column): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, Some(joinExprs.expr))
+ }
/**
* Join with another [[DataFrame]], using the given join expression. The following performs
@@ -230,7 +342,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
* @group dfops
*/
- def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame
+ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+ }
/**
* Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
@@ -243,7 +357,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def sort(sortCol: String, sortCols: String*): DataFrame
+ def sort(sortCol: String, sortCols: String*): DataFrame = {
+ sort((sortCol +: sortCols).map(apply) :_*)
+ }
/**
* Returns a new [[DataFrame]] sorted by the given expressions. For example:
@@ -253,7 +369,17 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def sort(sortExprs: Column*): DataFrame
+ def sort(sortExprs: Column*): DataFrame = {
+ val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ Sort(sortOrder, global = true, logicalPlan)
+ }
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
@@ -261,7 +387,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def orderBy(sortCol: String, sortCols: String*): DataFrame
+ def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*)
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
@@ -269,7 +395,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def orderBy(sortExprs: Column*): DataFrame
+ def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*)
/**
* Selects column based on the column name and return it as a [[Column]].
@@ -281,19 +407,25 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Selects column based on the column name and return it as a [[Column]].
* @group dfops
*/
- def col(colName: String): Column
+ def col(colName: String): Column = colName match {
+ case "*" =>
+ Column(ResolvedStar(schema.fieldNames.map(resolve)))
+ case _ =>
+ val expr = resolve(colName)
+ Column(expr)
+ }
/**
* Returns a new [[DataFrame]] with an alias set.
* @group dfops
*/
- def as(alias: String): DataFrame
+ def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
/**
* (Scala-specific) Returns a new [[DataFrame]] with an alias set.
* @group dfops
*/
- def as(alias: Symbol): DataFrame
+ def as(alias: Symbol): DataFrame = as(alias.name)
/**
* Selects a set of expressions.
@@ -303,7 +435,13 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def select(cols: Column*): DataFrame
+ def select(cols: Column*): DataFrame = {
+ val namedExpressions = cols.map {
+ case Column(expr: NamedExpression) => expr
+ case Column(expr: Expression) => Alias(expr, expr.prettyString)()
+ }
+ Project(namedExpressions.toSeq, logicalPlan)
+ }
/**
* Selects a set of columns. This is a variant of `select` that can only select
@@ -317,7 +455,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def select(col: String, cols: String*): DataFrame
+ def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*)
/**
* Selects a set of SQL expressions. This is a variant of `select` that accepts
@@ -329,7 +467,11 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def selectExpr(exprs: String*): DataFrame
+ def selectExpr(exprs: String*): DataFrame = {
+ select(exprs.map { expr =>
+ Column(new SqlParser().parseExpression(expr))
+ }: _*)
+ }
/**
* Filters rows using the given condition.
@@ -341,7 +483,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group dfops
*/
- def filter(condition: Column): DataFrame
+ def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan)
/**
* Filters rows using the given SQL expression.
@@ -350,7 +492,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group dfops
*/
- def filter(conditionExpr: String): DataFrame
+ def filter(conditionExpr: String): DataFrame = {
+ filter(Column(new SqlParser().parseExpression(conditionExpr)))
+ }
/**
* Filters rows using the given condition. This is an alias for `filter`.
@@ -362,7 +506,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group dfops
*/
- def where(condition: Column): DataFrame
+ def where(condition: Column): DataFrame = filter(condition)
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -381,7 +525,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def groupBy(cols: Column*): GroupedData
+ def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -403,7 +547,10 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group dfops
*/
@scala.annotation.varargs
- def groupBy(col1: String, cols: String*): GroupedData
+ def groupBy(col1: String, cols: String*): GroupedData = {
+ val colNames: Seq[String] = col1 +: cols
+ new GroupedData(this, colNames.map(colName => resolve(colName)))
+ }
/**
* (Scala-specific) Compute aggregates by specifying a map from column name to
@@ -462,28 +609,28 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]].
* @group dfops
*/
- def limit(n: Int): DataFrame
+ def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan)
/**
* Returns a new [[DataFrame]] containing union of rows in this frame and another frame.
* This is equivalent to `UNION ALL` in SQL.
* @group dfops
*/
- def unionAll(other: DataFrame): DataFrame
+ def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
/**
* Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
* This is equivalent to `INTERSECT` in SQL.
* @group dfops
*/
- def intersect(other: DataFrame): DataFrame
+ def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
/**
* Returns a new [[DataFrame]] containing rows in this frame but not in another frame.
* This is equivalent to `EXCEPT` in SQL.
* @group dfops
*/
- def except(other: DataFrame): DataFrame
+ def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows.
@@ -493,7 +640,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @param seed Seed for sampling.
* @group dfops
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+ Sample(fraction, withReplacement, seed, logicalPlan)
+ }
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
@@ -527,8 +676,15 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group dfops
*/
- def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame
+ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributes = schema.toAttributes
+ val rowFunction =
+ f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
+ val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+ Generate(generator, join = true, outer = false, None, logicalPlan)
+ }
/**
* (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
@@ -540,10 +696,17 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* }}}
* @group dfops
*/
- def explode[A, B : TypeTag](
- inputColumn: String,
- outputColumn: String)(
- f: A => TraversableOnce[B]): DataFrame
+ def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B])
+ : DataFrame = {
+ val dataType = ScalaReflection.schemaFor[B].dataType
+ val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+ def rowFunction(row: Row) = {
+ f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
+ }
+ val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+
+ Generate(generator, join = true, outer = false, None, logicalPlan)
+ }
/////////////////////////////////////////////////////////////////////////////
@@ -551,110 +714,130 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Returns a new [[DataFrame]] by adding a column.
* @group dfops
*/
- def withColumn(colName: String, col: Column): DataFrame
+ def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName))
/**
* Returns a new [[DataFrame]] with a column renamed.
* @group dfops
*/
- def withColumnRenamed(existingName: String, newName: String): DataFrame
+ def withColumnRenamed(existingName: String, newName: String): DataFrame = {
+ val resolver = sqlContext.analyzer.resolver
+ val colNames = schema.map { field =>
+ val name = field.name
+ if (resolver(name, existingName)) Column(name).as(newName) else Column(name)
+ }
+ select(colNames :_*)
+ }
/**
* Returns the first `n` rows.
*/
- def head(n: Int): Array[Row]
+ def head(n: Int): Array[Row] = limit(n).collect()
/**
* Returns the first row.
*/
- def head(): Row
+ def head(): Row = head(1).head
/**
* Returns the first row. Alias for head().
*/
- override def first(): Row
+ override def first(): Row = head()
/**
* Returns a new RDD by applying a function to all rows of this DataFrame.
* @group rdd
*/
- override def map[R: ClassTag](f: Row => R): RDD[R]
+ override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
/**
* Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
* and then flattening the results.
* @group rdd
*/
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R]
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
/**
* Returns a new RDD by applying a function to each partition of this DataFrame.
* @group rdd
*/
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R]
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+ rdd.mapPartitions(f)
+ }
/**
* Applies a function `f` to all rows.
* @group rdd
*/
- override def foreach(f: Row => Unit): Unit
+ override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
/**
* Applies a function f to each partition of this [[DataFrame]].
* @group rdd
*/
- override def foreachPartition(f: Iterator[Row] => Unit): Unit
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
/**
* Returns the first `n` rows in the [[DataFrame]].
* @group action
*/
- override def take(n: Int): Array[Row]
+ override def take(n: Int): Array[Row] = head(n)
/**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
* @group action
*/
- override def collect(): Array[Row]
+ override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
* @group action
*/
- override def collectAsList(): java.util.List[Row]
+ override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
/**
* Returns the number of rows in the [[DataFrame]].
* @group action
*/
- override def count(): Long
+ override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
/**
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
* @group rdd
*/
- override def repartition(numPartitions: Int): DataFrame
+ override def repartition(numPartitions: Int): DataFrame = {
+ sqlContext.createDataFrame(rdd.repartition(numPartitions), schema)
+ }
/**
* Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
* @group dfops
*/
- override def distinct: DataFrame
+ override def distinct: DataFrame = Distinct(logicalPlan)
/**
* @group basic
*/
- override def persist(): this.type
+ override def persist(): this.type = {
+ sqlContext.cacheManager.cacheQuery(this)
+ this
+ }
/**
* @group basic
*/
- override def persist(newLevel: StorageLevel): this.type
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+ this
+ }
/**
* @group basic
*/
- override def unpersist(blocking: Boolean): this.type
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+ this
+ }
/////////////////////////////////////////////////////////////////////////////
// I/O
@@ -664,7 +847,11 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
* @group rdd
*/
- def rdd: RDD[Row]
+ def rdd: RDD[Row] = {
+ // use a local variable to make sure the map closure doesn't capture the whole DataFrame
+ val schema = this.schema
+ queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
+ }
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
@@ -684,7 +871,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
*
* @group basic
*/
- def registerTempTable(tableName: String): Unit
+ def registerTempTable(tableName: String): Unit = {
+ sqlContext.registerDataFrameAsTable(this, tableName)
+ }
/**
* Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema.
@@ -692,7 +881,13 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* using the `parquetFile` function in [[SQLContext]].
* @group output
*/
- def saveAsParquetFile(path: String): Unit
+ def saveAsParquetFile(path: String): Unit = {
+ if (sqlContext.conf.parquetUseDataSourceApi) {
+ save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path))
+ } else {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+ }
/**
* :: Experimental ::
@@ -747,9 +942,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group output
*/
@Experimental
- def saveAsTable(
- tableName: String,
- source: String): Unit = {
+ def saveAsTable(tableName: String, source: String): Unit = {
saveAsTable(tableName, source, SaveMode.ErrorIfExists)
}
@@ -765,10 +958,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group output
*/
@Experimental
- def saveAsTable(
- tableName: String,
- source: String,
- mode: SaveMode): Unit = {
+ def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = {
saveAsTable(tableName, source, mode, Map.empty[String, String])
}
@@ -809,7 +999,18 @@ trait DataFrame extends RDDApi[Row] with Serializable {
tableName: String,
source: String,
mode: SaveMode,
- options: Map[String, String]): Unit
+ options: Map[String, String]): Unit = {
+ val cmd =
+ CreateTableUsingAsSelect(
+ tableName,
+ source,
+ temporary = false,
+ mode,
+ options,
+ logicalPlan)
+
+ sqlContext.executePlan(cmd).toRdd
+ }
/**
* :: Experimental ::
@@ -882,7 +1083,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
def save(
source: String,
mode: SaveMode,
- options: Map[String, String]): Unit
+ options: Map[String, String]): Unit = {
+ ResolvedDataSource(sqlContext, source, mode, options, this)
+ }
/**
* :: Experimental ::
@@ -890,7 +1093,10 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* @group output
*/
@Experimental
- def insertInto(tableName: String, overwrite: Boolean): Unit
+ def insertInto(tableName: String, overwrite: Boolean): Unit = {
+ sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+ Map.empty, logicalPlan, overwrite)).toRdd
+ }
/**
* :: Experimental ::
@@ -905,7 +1111,31 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Returns the content of the [[DataFrame]] as a RDD of JSON strings.
* @group rdd
*/
- def toJSON: RDD[String]
+ def toJSON: RDD[String] = {
+ val rowSchema = this.schema
+ this.mapPartitions { iter =>
+ val writer = new CharArrayWriter()
+ // create the Generator without separator inserted between 2 records
+ val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+
+ new Iterator[String] {
+ override def hasNext = iter.hasNext
+ override def next(): String = {
+ JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
+ gen.flush()
+
+ val json = writer.toString
+ if (hasNext) {
+ writer.reset()
+ } else {
+ gen.close()
+ }
+
+ json
+ }
+ }
+ }
+ }
////////////////////////////////////////////////////////////////////////////
// JDBC Write Support
@@ -919,7 +1149,21 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* exists.
* @group output
*/
- def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit
+ def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
+ val conn = DriverManager.getConnection(url)
+ try {
+ if (allowExisting) {
+ val sql = s"DROP TABLE IF EXISTS $table"
+ conn.prepareStatement(sql).executeUpdate()
+ }
+ val schema = JDBCWriteDetails.schemaString(this, url)
+ val sql = s"CREATE TABLE $table ($schema)"
+ conn.prepareStatement(sql).executeUpdate()
+ } finally {
+ conn.close()
+ }
+ JDBCWriteDetails.saveTable(this, url, table)
+ }
/**
* Save this RDD to a JDBC database at `url` under the table name `table`.
@@ -933,8 +1177,18 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
* @group output
*/
- def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit
-
+ def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
+ if (overwrite) {
+ val conn = DriverManager.getConnection(url)
+ try {
+ val sql = s"TRUNCATE TABLE $table"
+ conn.prepareStatement(sql).executeUpdate()
+ } finally {
+ conn.close()
+ }
+ }
+ JDBCWriteDetails.saveTable(this, url, table)
+ }
////////////////////////////////////////////////////////////////////////////
// for Python API
@@ -943,5 +1197,9 @@ trait DataFrame extends RDDApi[Row] with Serializable {
/**
* Converts a JavaRDD to a PythonRDD.
*/
- protected[sql] def javaToPython: JavaRDD[Array[Byte]]
+ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldTypes = schema.fields.map(_.dataType)
+ val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+ SerDeUtil.javaToPython(jrdd)
+ }
}
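
As a further editor-added illustration (not part of the patch), a short sketch exercising a few of the DataFrame methods inlined above into the now-concrete class; `df` is a hypothetical DataFrame with two columns:

import org.apache.spark.sql.DataFrame

def report(df: DataFrame): Unit = {
  // toDF(...) checks that the number of new names matches the schema width,
  // then re-selects every column under an alias.
  val renamed = df.toDF("person", "years")

  // filter(String), like selectExpr, goes through SqlParser.parseExpression.
  val adults = renamed.filter("years > 21")

  // sort(...) wraps plain expressions in SortOrder(expr, Ascending);
  // desc yields a SortOrder expression directly.
  adults.sort(adults("years").desc).show()
}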
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
deleted file mode 100644
index 25bc9d9292..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ /dev/null
@@ -1,483 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import java.io.CharArrayWriter
-import java.sql.DriverManager
-
-import scala.language.implicitConversions
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
-import scala.collection.JavaConversions._
-
-import com.fasterxml.jackson.core.JsonFactory
-
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{ExplainCommand, LogicalRDD, EvaluatePython}
-import org.apache.spark.sql.jdbc.JDBCWriteDetails
-import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{NumericType, StructType}
-
-/**
- * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
- */
-private[sql] class DataFrameImpl protected[sql](
- @transient override val sqlContext: SQLContext,
- @transient val queryExecution: SQLContext#QueryExecution)
- extends DataFrame {
-
- /**
- * A constructor that automatically analyzes the logical plan.
- *
- * This reports error eagerly as the [[DataFrame]] is constructed, unless
- * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
- */
- def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
- this(sqlContext, {
- val qe = sqlContext.executePlan(logicalPlan)
- if (sqlContext.conf.dataFrameEagerAnalysis) {
- qe.analyzed // This should force analysis and throw errors if there are any
- }
- qe
- })
- }
-
- @transient protected[sql] override val logicalPlan: LogicalPlan = queryExecution.logical match {
- // For various commands (like DDL) and queries with side effects, we force query optimization to
- // happen right away to let these side effects take place eagerly.
- case _: Command |
- _: InsertIntoTable |
- _: CreateTableAsSelect[_] |
- _: CreateTableUsingAsSelect |
- _: WriteToFile =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
- case _ =>
- queryExecution.logical
- }
-
- /**
- * An implicit conversion function internal to this class for us to avoid doing
- * "new DataFrameImpl(...)" everywhere.
- */
- @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
- new DataFrameImpl(sqlContext, logicalPlan)
- }
-
- protected[sql] def resolve(colName: String): NamedExpression = {
- queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
- throw new AnalysisException(
- s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
- }
- }
-
- protected[sql] def numericColumns: Seq[Expression] = {
- schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
- queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
- }
- }
-
- override def toDF(colNames: String*): DataFrame = {
- require(schema.size == colNames.size,
- "The number of columns doesn't match.\n" +
- "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
- "New column names: " + colNames.mkString(", "))
-
- val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
- apply(oldName).as(newName)
- }
- select(newCols :_*)
- }
-
- override def schema: StructType = queryExecution.analyzed.schema
-
- override def dtypes: Array[(String, String)] = schema.fields.map { field =>
- (field.name, field.dataType.toString)
- }
-
- override def columns: Array[String] = schema.fields.map(_.name)
-
- override def printSchema(): Unit = println(schema.treeString)
-
- override def explain(extended: Boolean): Unit = {
- ExplainCommand(
- logicalPlan,
- extended = extended).queryExecution.executedPlan.executeCollect().map {
- r => println(r.getString(0))
- }
- }
-
- override def isLocal: Boolean = {
- logicalPlan.isInstanceOf[LocalRelation]
- }
-
- /**
- * Internal API for Python
- */
- private[sql] def showString(): String = {
- val data = take(20)
- val numCols = schema.fieldNames.length
-
- // For cells that are beyond 20 characters, replace it with the first 17 and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
- row.toSeq.map { cell =>
- val str = if (cell == null) "null" else cell.toString
- if (str.length > 20) str.substring(0, 17) + "..." else str
- } : Seq[String]
- }
-
- // Compute the width of each column
- val colWidths = Array.fill(numCols)(0)
- for (row <- rows) {
- for ((cell, i) <- row.zipWithIndex) {
- colWidths(i) = math.max(colWidths(i), cell.length)
- }
- }
-
- // Pad the cells
- rows.map { row =>
- row.zipWithIndex.map { case (cell, i) =>
- String.format(s"%-${colWidths(i)}s", cell)
- }.mkString(" ")
- }.mkString("\n")
- }
-
- override def show(): Unit = {
- println(showString())
- }
-
- override def join(right: DataFrame): DataFrame = {
- Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
- }
-
- override def join(right: DataFrame, joinExprs: Column): DataFrame = {
- Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr))
- }
-
- override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
- Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
- }
-
- override def sort(sortCol: String, sortCols: String*): DataFrame = {
- sort((sortCol +: sortCols).map(apply) :_*)
- }
-
- override def sort(sortExprs: Column*): DataFrame = {
- val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
- col.expr match {
- case expr: SortOrder =>
- expr
- case expr: Expression =>
- SortOrder(expr, Ascending)
- }
- }
- Sort(sortOrder, global = true, logicalPlan)
- }
-
- override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
- sort(sortCol, sortCols :_*)
- }
-
- override def orderBy(sortExprs: Column*): DataFrame = {
- sort(sortExprs :_*)
- }
-
- override def col(colName: String): Column = colName match {
- case "*" =>
- Column(ResolvedStar(schema.fieldNames.map(resolve)))
- case _ =>
- val expr = resolve(colName)
- Column(sqlContext, Project(Seq(expr), logicalPlan), expr)
- }
-
- override def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
-
- override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan)
-
- override def select(cols: Column*): DataFrame = {
- val namedExpressions = cols.map {
- case Column(expr: NamedExpression) => expr
- case Column(expr: Expression) => Alias(expr, expr.prettyString)()
- }
- Project(namedExpressions.toSeq, logicalPlan)
- }
-
- override def select(col: String, cols: String*): DataFrame = {
- select((col +: cols).map(Column(_)) :_*)
- }
-
- override def selectExpr(exprs: String*): DataFrame = {
- select(exprs.map { expr =>
- Column(new SqlParser().parseExpression(expr))
- }: _*)
- }
-
- override def withColumn(colName: String, col: Column): DataFrame = {
- select(Column("*"), col.as(colName))
- }
-
- override def withColumnRenamed(existingName: String, newName: String): DataFrame = {
- val resolver = sqlContext.analyzer.resolver
- val colNames = schema.map { field =>
- val name = field.name
- if (resolver(name, existingName)) Column(name).as(newName) else Column(name)
- }
- select(colNames :_*)
- }
-
- override def filter(condition: Column): DataFrame = {
- Filter(condition.expr, logicalPlan)
- }
-
- override def filter(conditionExpr: String): DataFrame = {
- filter(Column(new SqlParser().parseExpression(conditionExpr)))
- }
-
- override def where(condition: Column): DataFrame = {
- filter(condition)
- }
-
- override def groupBy(cols: Column*): GroupedData = {
- new GroupedData(this, cols.map(_.expr))
- }
-
- override def groupBy(col1: String, cols: String*): GroupedData = {
- val colNames: Seq[String] = col1 +: cols
- new GroupedData(this, colNames.map(colName => resolve(colName)))
- }
-
- override def limit(n: Int): DataFrame = {
- Limit(Literal(n), logicalPlan)
- }
-
- override def unionAll(other: DataFrame): DataFrame = {
- Union(logicalPlan, other.logicalPlan)
- }
-
- override def intersect(other: DataFrame): DataFrame = {
- Intersect(logicalPlan, other.logicalPlan)
- }
-
- override def except(other: DataFrame): DataFrame = {
- Except(logicalPlan, other.logicalPlan)
- }
-
- override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
- Sample(fraction, withReplacement, seed, logicalPlan)
- }
-
- override def explode[A <: Product : TypeTag]
- (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
- val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
- val attributes = schema.toAttributes
- val rowFunction =
- f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
- val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
-
- Generate(generator, join = true, outer = false, None, logicalPlan)
- }
-
- override def explode[A, B : TypeTag](
- inputColumn: String,
- outputColumn: String)(
- f: A => TraversableOnce[B]): DataFrame = {
- val dataType = ScalaReflection.schemaFor[B].dataType
- val attributes = AttributeReference(outputColumn, dataType)() :: Nil
- def rowFunction(row: Row) = {
- f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
- }
- val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
-
- Generate(generator, join = true, outer = false, None, logicalPlan)
-
- }
-
- /////////////////////////////////////////////////////////////////////////////
- // RDD API
- /////////////////////////////////////////////////////////////////////////////
-
- override def head(n: Int): Array[Row] = limit(n).collect()
-
- override def head(): Row = head(1).head
-
- override def first(): Row = head()
-
- override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
-
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
-
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
- rdd.mapPartitions(f)
- }
-
- override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
-
- override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
-
- override def take(n: Int): Array[Row] = head(n)
-
- override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
-
- override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
-
- override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
-
- override def repartition(numPartitions: Int): DataFrame = {
- sqlContext.createDataFrame(rdd.repartition(numPartitions), schema)
- }
-
- override def distinct: DataFrame = Distinct(logicalPlan)
-
- override def persist(): this.type = {
- sqlContext.cacheManager.cacheQuery(this)
- this
- }
-
- override def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheManager.cacheQuery(this, None, newLevel)
- this
- }
-
- override def unpersist(blocking: Boolean): this.type = {
- sqlContext.cacheManager.tryUncacheQuery(this, blocking)
- this
- }
-
- /////////////////////////////////////////////////////////////////////////////
- // I/O
- /////////////////////////////////////////////////////////////////////////////
-
- override def rdd: RDD[Row] = {
- // use a local variable to make sure the map closure doesn't capture the whole DataFrame
- val schema = this.schema
- queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
- }
-
- override def registerTempTable(tableName: String): Unit = {
- sqlContext.registerDataFrameAsTable(this, tableName)
- }
-
- override def saveAsParquetFile(path: String): Unit = {
- if (sqlContext.conf.parquetUseDataSourceApi) {
- save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path))
- } else {
- sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
- }
- }
-
- override def saveAsTable(
- tableName: String,
- source: String,
- mode: SaveMode,
- options: Map[String, String]): Unit = {
- val cmd =
- CreateTableUsingAsSelect(
- tableName,
- source,
- temporary = false,
- mode,
- options,
- logicalPlan)
-
- sqlContext.executePlan(cmd).toRdd
- }
-
- override def save(
- source: String,
- mode: SaveMode,
- options: Map[String, String]): Unit = {
- ResolvedDataSource(sqlContext, source, mode, options, this)
- }
-
- override def insertInto(tableName: String, overwrite: Boolean): Unit = {
- sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
- Map.empty, logicalPlan, overwrite)).toRdd
- }
-
- override def toJSON: RDD[String] = {
- val rowSchema = this.schema
- this.mapPartitions { iter =>
- val writer = new CharArrayWriter()
- // create the Generator without separator inserted between 2 records
- val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
-
- new Iterator[String] {
- override def hasNext = iter.hasNext
- override def next(): String = {
- JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
- gen.flush()
-
- val json = writer.toString
- if (hasNext) {
- writer.reset()
- } else {
- gen.close()
- }
-
- json
- }
- }
- }
- }
-
- def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
- val conn = DriverManager.getConnection(url)
- try {
- if (allowExisting) {
- val sql = s"DROP TABLE IF EXISTS $table"
- conn.prepareStatement(sql).executeUpdate()
- }
- val schema = JDBCWriteDetails.schemaString(this, url)
- val sql = s"CREATE TABLE $table ($schema)"
- conn.prepareStatement(sql).executeUpdate()
- } finally {
- conn.close()
- }
- JDBCWriteDetails.saveTable(this, url, table)
- }
-
- def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
- if (overwrite) {
- val conn = DriverManager.getConnection(url)
- try {
- val sql = s"TRUNCATE TABLE $table"
- conn.prepareStatement(sql).executeUpdate()
- } finally {
- conn.close()
- }
- }
- JDBCWriteDetails.saveTable(this, url, table)
- }
-
- ////////////////////////////////////////////////////////////////////////////
- // for Python API
- ////////////////////////////////////////////////////////////////////////////
- protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
- }
-}
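For orientation, a minimal usage sketch of the DataFrame surface that the deleted DataFrameImpl.scala implemented; the patch folds these method bodies back into DataFrame itself. The SparkContext setup, the implicits import, and the people/name/age identifiers below are illustrative assumptions, not part of the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical local setup; application name, master and data are made up for illustration.
val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val people = Seq(("Alice", 30), ("Bob", 17)).toDF("name", "age")

val adults = people
  .filter(people("age") >= 21)            // Filter(condition.expr, logicalPlan)
  .withColumnRenamed("name", "fullName")  // projection that renames only the matching column
  .limit(10)                              // Limit(Literal(n), logicalPlan)

println(adults.count())                   // implemented as groupBy().count() over the plan
adults.registerTempTable("adults")        // sqlContext.registerDataFrameAsTable(this, tableName)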
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 17158303b8..d001752659 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.NumericType
* A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
*/
@Experimental
-class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
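The one-line change above means GroupedData is now constructed from a plain DataFrame rather than the removed DataFrameImpl. A small aggregation sketch, reusing the hypothetical people frame from the previous note:

import org.apache.spark.sql.functions.{avg, count}

// groupBy hands the DataFrame itself to GroupedData, together with the grouping expressions.
val byAge   = people.groupBy("age")
val summary = byAge.agg(count("name"), avg("age"))
summary.show()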
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
deleted file mode 100644
index b48b682b36..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ /dev/null
@@ -1,183 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.types.StructType
-
-private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column {
-
- def this(name: String) = this(name match {
- case "*" => UnresolvedStar(None)
- case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
- case _ => UnresolvedAttribute(name)
- })
-
- private def err[T](): T = {
- throw new UnsupportedOperationException("Cannot run this method on an UncomputableColumn")
- }
-
- override def toString = expr.prettyString
-
- override def isComputable: Boolean = false
-
- override val sqlContext: SQLContext = null
-
- override def queryExecution = err()
-
- protected[sql] override def logicalPlan: LogicalPlan = err()
-
- override def toDF(colNames: String*): DataFrame = err()
-
- override def schema: StructType = err()
-
- override def dtypes: Array[(String, String)] = err()
-
- override def columns: Array[String] = err()
-
- override def printSchema(): Unit = err()
-
- override def show(): Unit = err()
-
- override def isLocal: Boolean = false
-
- override def join(right: DataFrame): DataFrame = err()
-
- override def join(right: DataFrame, joinExprs: Column): DataFrame = err()
-
- override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = err()
-
- override def sort(sortCol: String, sortCols: String*): DataFrame = err()
-
- override def sort(sortExprs: Column*): DataFrame = err()
-
- override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()
-
- override def orderBy(sortExprs: Column*): DataFrame = err()
-
- override def col(colName: String): Column = err()
-
- override def select(cols: Column*): DataFrame = err()
-
- override def select(col: String, cols: String*): DataFrame = err()
-
- override def selectExpr(exprs: String*): DataFrame = err()
-
- override def withColumn(colName: String, col: Column): DataFrame = err()
-
- override def withColumnRenamed(existingName: String, newName: String): DataFrame = err()
-
- override def filter(condition: Column): DataFrame = err()
-
- override def filter(conditionExpr: String): DataFrame = err()
-
- override def where(condition: Column): DataFrame = err()
-
- override def groupBy(cols: Column*): GroupedData = err()
-
- override def groupBy(col1: String, cols: String*): GroupedData = err()
-
- override def limit(n: Int): DataFrame = err()
-
- override def unionAll(other: DataFrame): DataFrame = err()
-
- override def intersect(other: DataFrame): DataFrame = err()
-
- override def except(other: DataFrame): DataFrame = err()
-
- override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
-
- override def explode[A <: Product : TypeTag]
- (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()
-
- override def explode[A, B : TypeTag](
- inputColumn: String,
- outputColumn: String)(
- f: A => TraversableOnce[B]): DataFrame = err()
-
- /////////////////////////////////////////////////////////////////////////////
-
- override def head(n: Int): Array[Row] = err()
-
- override def head(): Row = err()
-
- override def first(): Row = err()
-
- override def map[R: ClassTag](f: Row => R): RDD[R] = err()
-
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = err()
-
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = err()
-
- override def foreach(f: Row => Unit): Unit = err()
-
- override def foreachPartition(f: Iterator[Row] => Unit): Unit = err()
-
- override def take(n: Int): Array[Row] = err()
-
- override def collect(): Array[Row] = err()
-
- override def collectAsList(): java.util.List[Row] = err()
-
- override def count(): Long = err()
-
- override def repartition(numPartitions: Int): DataFrame = err()
-
- override def distinct: DataFrame = err()
-
- override def persist(): this.type = err()
-
- override def persist(newLevel: StorageLevel): this.type = err()
-
- override def unpersist(blocking: Boolean): this.type = err()
-
- override def rdd: RDD[Row] = err()
-
- override def registerTempTable(tableName: String): Unit = err()
-
- override def saveAsParquetFile(path: String): Unit = err()
-
- override def saveAsTable(
- tableName: String,
- source: String,
- mode: SaveMode,
- options: Map[String, String]): Unit = err()
-
- override def save(
- source: String,
- mode: SaveMode,
- options: Map[String, String]): Unit = err()
-
- override def insertInto(tableName: String, overwrite: Boolean): Unit = err()
-
- def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = err()
-
- def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = err()
-
- override def toJSON: RDD[String] = err()
-
- protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err()
-}
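With IncomputableColumn deleted, none of the err() stubs above remain reachable: a Column is just an expression holder, and evaluating it always goes through an explicit DataFrame, as the test changes below show. A hedged before/after sketch over a small throwaway frame:

val df = Seq((1, 2)).toDF("a", "b")

// Before this patch a column expression could be collected directly, but an
// "incomputable" one failed at runtime with UnsupportedOperationException:
//   (df("a") + df("b")).collect()

// After this patch the expression is always evaluated through a DataFrame:
val rows = df.select(df("a") + df("b")).collect()   // Array(Row(3))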
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a63d733ece..928b0deb61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -28,49 +28,10 @@ class ColumnExpressionSuite extends QueryTest {
// TODO: Add test cases for bitwise operations.
- test("computability check") {
- def shouldBeComputable(c: Column): Unit = assert(c.isComputable === true)
-
- def shouldNotBeComputable(c: Column): Unit = {
- assert(c.isComputable === false)
- intercept[UnsupportedOperationException] { c.head() }
- }
-
- shouldBeComputable(testData2("a"))
- shouldBeComputable(testData2("b"))
-
- shouldBeComputable(testData2("a") + testData2("b"))
- shouldBeComputable(testData2("a") + testData2("b") + 1)
-
- shouldBeComputable(-testData2("a"))
- shouldBeComputable(!testData2("a"))
-
- shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
- shouldNotBeComputable(
- testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
- shouldNotBeComputable(
- testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
-
- // Literals and unresolved columns should not be computable.
- shouldNotBeComputable(col("1"))
- shouldNotBeComputable(col("1") + 2)
- shouldNotBeComputable(lit(100))
- shouldNotBeComputable(lit(100) + 10)
- shouldNotBeComputable(-col("1"))
- shouldNotBeComputable(!col("1"))
-
- // Getting data from different frames should not be computable.
- shouldNotBeComputable(testData2("a") + testData("key"))
- shouldNotBeComputable(testData2("a") + 1 + testData("key"))
-
- // Aggregate functions alone should not be computable.
- shouldNotBeComputable(sum(testData2("a")))
- }
-
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
- checkAnswer(df("a") + df("b"), Seq(Row(3)))
- checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
+ checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
+ checkAnswer(df.select(df("a") + df("b").as("c")), Seq(Row(3)))
}
test("star") {
@@ -78,7 +39,6 @@ class ColumnExpressionSuite extends QueryTest {
}
test("star qualified by data frame object") {
- // This is not yet supported.
val df = testData.toDF
val goldAnswer = df.collect().toSeq
checkAnswer(df.select(df("*")), goldAnswer)
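The comment removed above ("This is not yet supported.") marked a limitation this patch lifts: a star qualified by the frame object now resolves like any other column reference. A short illustrative sketch; the extra lit(1) column is an assumption, not taken from the patch:

import org.apache.spark.sql.functions.lit

val df = Seq((1, "one"), (2, "two")).toDF("key", "value")
val everything = df.select(df("*"))                     // same rows as df.collect()
val widened    = df.select(df("*"), lit(1).as("flag"))  // star expansion plus an extra column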
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f31bc38922..6b9b3a8425 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -56,10 +56,7 @@ class DataFrameSuite extends QueryTest {
test("dataframe toString") {
assert(testData.toString === "[key: int, value: string]")
- assert(testData("key").toString === "[key: int]")
- }
-
- test("incomputable toString") {
+ assert(testData("key").toString === "key")
assert($"test".toString === "test")
}
@@ -431,7 +428,7 @@ class DataFrameSuite extends QueryTest {
test("apply on query results (SPARK-5462)") {
val df = testData.sqlContext.sql("select key from testData")
- checkAnswer(df("key"), testData.select('key).collect().toSeq)
+ checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
}
}
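The final hunk applies the same rule to the SPARK-5462 regression test: a column plucked off a query result is evaluated through select rather than collected on its own. The idiom, continuing the hypothetical setup from the first sketch:

// Assumes the "adults" temp table registered in the first sketch above.
val queried = sqlContext.sql("SELECT fullName FROM adults")
val names   = queried.select(queried("fullName")).collect()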