From 7e24249af1e2f896328ef0402fa47db78cb6f9ec Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 10 Feb 2015 19:50:44 -0800 Subject: [SQL][DataFrame] Fix column computability bug. Do not recursively strip out projects. Only strip the first level project. ```scala df("colA") + df("colB").as("colC") ``` Previously, the above would construct an invalid plan. Author: Reynold Xin Closes #4519 from rxin/computability and squashes the following commits: 87ff763 [Reynold Xin] Code review feedback. 015c4fc [Reynold Xin] [SQL][DataFrame] Fix column computability. --- .../main/scala/org/apache/spark/sql/Column.scala | 35 ++++++++++++++++------ .../scala/org/apache/spark/sql/SQLContext.scala | 4 +-- .../apache/spark/sql/ColumnExpressionSuite.scala | 13 ++++++-- 3 files changed, 38 insertions(+), 14 deletions(-) (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b0e95908ee..9d5d6e78bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -66,27 +66,44 @@ trait Column extends DataFrame { */ def isComputable: Boolean + /** Removes the top project so we can get to the underlying plan. */ + private def stripProject(p: LogicalPlan): LogicalPlan = p match { + case Project(_, child) => child + case p => sys.error("Unexpected logical plan (expected Project): " + p) + } + private def computableCol(baseCol: ComputableColumn, expr: Expression) = { - val plan = Project(Seq(expr match { + val namedExpr = expr match { case named: NamedExpression => named case unnamed: Expression => Alias(unnamed, "col")() - }), baseCol.plan) + } + val plan = Project(Seq(namedExpr), stripProject(baseCol.plan)) Column(baseCol.sqlContext, plan, expr) } + /** + * Construct a new column based on the expression and the other column value. + * + * There are two cases that can happen here: + * If otherValue is a constant, it is first turned into a Column. + * If otherValue is a Column, then: + * - If this column and otherValue are both computable and come from the same logical plan, + * then we can construct a ComputableColumn by applying a Project on top of the base plan. + * - If this column is not computable, but otherValue is computable, then we can construct + * a ComputableColumn based on otherValue's base plan. + * - If this column is computable, but otherValue is not, then we can construct a + * ComputableColumn based on this column's base plan. + * - If neither columns are computable, then we create an IncomputableColumn. + */ private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = { - // Removes all the top level projection and subquery so we can get to the underlying plan. - @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match { - case Project(_, child) => stripProject(child) - case Subquery(_, child) => stripProject(child) - case _ => p - } - + // lit(otherValue) returns a Column always. (this, lit(otherValue)) match { case (left: ComputableColumn, right: ComputableColumn) => if (stripProject(left.plan).sameResult(stripProject(right.plan))) { computableCol(right, newExpr(right)) } else { + // We don't want to throw an exception here because "df1("a") === df2("b")" can be + // a valid expression for join conditions, even though standalone they are not valid. Column(newExpr(right)) } case (left: ComputableColumn, right) => computableCol(left, newExpr(right)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 523911d108..05ac1623d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -183,14 +183,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { + implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { self.createDataFrame(rdd) } /** * Creates a DataFrame from a local Seq of Product. */ - implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { self.createDataFrame(data) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1d71039872..e3e6f652ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} @@ -44,10 +45,10 @@ class ColumnExpressionSuite extends QueryTest { shouldBeComputable(-testData2("a")) shouldBeComputable(!testData2("a")) - shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b")) - shouldBeComputable( + shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b")) + shouldNotBeComputable( testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d")) - shouldBeComputable( + shouldNotBeComputable( testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b")) // Literals and unresolved columns should not be computable. @@ -66,6 +67,12 @@ class ColumnExpressionSuite extends QueryTest { shouldNotBeComputable(sum(testData2("a"))) } + test("collect on column produced by a binary operator") { + val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c") + checkAnswer(df("a") + df("b"), Seq(Row(3))) + checkAnswer(df("a") + df("b").as("c"), Seq(Row(3))) + } + test("star") { checkAnswer(testData.select($"*"), testData.collect().toSeq) } -- cgit v1.2.3