aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala35
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala13
4 files changed, 39 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 9ff06ac362..16979c9ed4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -180,7 +180,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
val sqlContext = new SQLContext(sc)
- import sqlContext.implicits.createDataFrame
+ import sqlContext.implicits._
val metadata = (thisClassName, thisFormatVersion, model.rank)
val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b0e95908ee..9d5d6e78bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -66,27 +66,44 @@ trait Column extends DataFrame {
*/
def isComputable: Boolean
+ /** Removes the top project so we can get to the underlying plan. */
+ private def stripProject(p: LogicalPlan): LogicalPlan = p match {
+ case Project(_, child) => child
+ case p => sys.error("Unexpected logical plan (expected Project): " + p)
+ }
+
private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
- val plan = Project(Seq(expr match {
+ val namedExpr = expr match {
case named: NamedExpression => named
case unnamed: Expression => Alias(unnamed, "col")()
- }), baseCol.plan)
+ }
+ val plan = Project(Seq(namedExpr), stripProject(baseCol.plan))
Column(baseCol.sqlContext, plan, expr)
}
+ /**
+ * Construct a new column based on the expression and the other column value.
+ *
+ * There are two cases that can happen here:
+ * If otherValue is a constant, it is first turned into a Column.
+ * If otherValue is a Column, then:
+ * - If this column and otherValue are both computable and come from the same logical plan,
+ * then we can construct a ComputableColumn by applying a Project on top of the base plan.
+ * - If this column is not computable, but otherValue is computable, then we can construct
+ * a ComputableColumn based on otherValue's base plan.
+ * - If this column is computable, but otherValue is not, then we can construct a
+ * ComputableColumn based on this column's base plan.
+ * - If neither columns are computable, then we create an IncomputableColumn.
+ */
private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
- // Removes all the top level projection and subquery so we can get to the underlying plan.
- @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
- case Project(_, child) => stripProject(child)
- case Subquery(_, child) => stripProject(child)
- case _ => p
- }
-
+ // lit(otherValue) returns a Column always.
(this, lit(otherValue)) match {
case (left: ComputableColumn, right: ComputableColumn) =>
if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
computableCol(right, newExpr(right))
} else {
+ // We don't want to throw an exception here because "df1("a") === df2("b")" can be
+ // a valid expression for join conditions, even though standalone they are not valid.
Column(newExpr(right))
}
case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 523911d108..05ac1623d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -183,14 +183,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @group userf
*/
- implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+ implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
self.createDataFrame(rdd)
}
/**
* Creates a DataFrame from a local Seq of Product.
*/
- implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
self.createDataFrame(data)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1d71039872..e3e6f652ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
@@ -44,10 +45,10 @@ class ColumnExpressionSuite extends QueryTest {
shouldBeComputable(-testData2("a"))
shouldBeComputable(!testData2("a"))
- shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
- shouldBeComputable(
+ shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
+ shouldNotBeComputable(
testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
- shouldBeComputable(
+ shouldNotBeComputable(
testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
// Literals and unresolved columns should not be computable.
@@ -66,6 +67,12 @@ class ColumnExpressionSuite extends QueryTest {
shouldNotBeComputable(sum(testData2("a")))
}
+ test("collect on column produced by a binary operator") {
+ val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
+ checkAnswer(df("a") + df("b"), Seq(Row(3)))
+ checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
+ }
+
test("star") {
checkAnswer(testData.select($"*"), testData.collect().toSeq)
}