author     Reynold Xin <rxin@databricks.com>  2015-02-11 18:32:48 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-02-11 18:32:48 -0800
commit     d931b01dcaaf009dcf68dcfe83428bd7f9e857cc (patch)
tree       bbffeaa04ba6edd6e89ab68edf320ce30e1c11ad /sql/core
parent     fa6bdc6e819f9338248b952ec578bcd791ddbf6d (diff)
[SQL] Two DataFrame fixes.
- Removed DataFrame.apply for projection & filtering since they are extremely confusing.
- Added implicits for RDD[Int], RDD[Long], and RDD[String].

Author: Reynold Xin <rxin@databricks.com>

Closes #4543 from rxin/df-cleanup and squashes the following commits:

81ec915 [Reynold Xin] [SQL] More DataFrame fixes.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala            24
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala        4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala               54
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala  55
5 files changed, 119 insertions(+), 57 deletions(-)
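
Before the diff, a minimal sketch of how the affected API surface looks after this commit, pieced together from the commit message, the removed scaladoc, and the new test suite below. The `people` data and column names are illustrative only, and the example assumes the TestSQLContext helpers used by the tests:

import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc}
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Tuple (and case class) RDDs still convert implicitly; note the parentheses on toDataFrame.
val people = sc.parallelize(Seq(("Alice", 30), ("Bob", 12))).toDataFrame("name", "age")

// Projection and filtering now go through select/where; the DataFrame.apply
// overloads for Product projections and Column conditions are removed.
people.select("name", "age")
people.where(people.col("age") > 15)

// New single-column implicits added by this commit.
sc.parallelize(1 to 10).toDataFrame("intCol")
sc.parallelize(Seq("a", "b", "c")).toDataFrame("strCol")
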
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 17900c5ee3..327cf87f30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -85,6 +85,14 @@ trait DataFrame extends RDDApi[Row] {
protected[sql] def logicalPlan: LogicalPlan
+ override def toString =
+ try {
+ schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
+ } catch {
+ case NonFatal(e) =>
+ s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+ }
+
/** Left here for backward compatibility. */
@deprecated("1.3.0", "use toDataFrame")
def toSchemaRDD: DataFrame = this
@@ -92,13 +100,9 @@ trait DataFrame extends RDDApi[Row] {
/**
* Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
*/
- def toDataFrame: DataFrame = this
-
- override def toString =
- try schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") catch {
- case NonFatal(e) =>
- s"Invalid tree; ${e.getMessage}:\n$queryExecution"
- }
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDataFrame("1")` as invoking this toDataFrame and then apply on the returned DataFrame.
+ def toDataFrame(): DataFrame = this
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
@@ -235,16 +239,6 @@ trait DataFrame extends RDDApi[Row] {
def col(colName: String): Column
/**
- * Selects a set of expressions, wrapped in a Product.
- * {{{
- * // The following two are equivalent:
- * df.apply(($"colA", $"colB" + 1))
- * df.select($"colA", $"colB" + 1)
- * }}}
- */
- def apply(projection: Product): DataFrame
-
- /**
* Returns a new [[DataFrame]] with an alias set.
*/
def as(alias: String): DataFrame
@@ -318,17 +312,6 @@ trait DataFrame extends RDDApi[Row] {
def where(condition: Column): DataFrame
/**
- * Filters rows using the given condition. This is a shorthand meant for Scala.
- * {{{
- * // The following are equivalent:
- * peopleDf.filter($"age" > 15)
- * peopleDf.where($"age" > 15)
- * peopleDf($"age" > 15)
- * }}}
- */
- def apply(condition: Column): DataFrame
-
- /**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
* See [[GroupedData]] for all the available aggregate functions.
*
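
A hedged illustration of the parenthesized toDataFrame() above, continuing the earlier sketch (it assumes DataFrame keeps its apply(colName: String) column accessor, i.e. the df("col") form):

// With a parameterless `def toDataFrame: DataFrame`, Scala would parse
//   rdd.toDataFrame("1")
// as
//   rdd.toDataFrame.apply("1")
// i.e. convert the RDD and then look up a column named "1", instead of
// hitting the toDataFrame(colNames: String*) overload that renames columns.
// Declaring it as `def toDataFrame(): DataFrame` avoids that parse.
val df     = sc.parallelize(Seq((1, "a"))).toDataFrame("num", "str") // rename overload
val sameDf = df.toDataFrame()                                        // identity, explicit parens
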
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 41da4424ae..3863df5318 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -49,8 +49,10 @@ private[sql] class DataFrameImpl protected[sql](
extends DataFrame {
/**
- * A constructor that automatically analyzes the logical plan. This reports error eagerly
- * as the [[DataFrame]] is constructed.
+ * A constructor that automatically analyzes the logical plan.
+ *
+ * This reports error eagerly as the [[DataFrame]] is constructed, unless
+ * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
*/
def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
this(sqlContext, {
@@ -158,7 +160,7 @@ private[sql] class DataFrameImpl protected[sql](
}
override def show(): Unit = {
- println(showString)
+ println(showString())
}
override def join(right: DataFrame): DataFrame = {
@@ -205,14 +207,6 @@ private[sql] class DataFrameImpl protected[sql](
Column(sqlContext, Project(Seq(expr), logicalPlan), expr)
}
- override def apply(projection: Product): DataFrame = {
- require(projection.productArity >= 1)
- select(projection.productIterator.map {
- case c: Column => c
- case o: Any => Column(Literal(o))
- }.toSeq :_*)
- }
-
override def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan)
@@ -259,10 +253,6 @@ private[sql] class DataFrameImpl protected[sql](
filter(condition)
}
- override def apply(condition: Column): DataFrame = {
- filter(condition)
- }
-
override def groupBy(cols: Column*): GroupedData = {
new GroupedData(this, cols.map(_.expr))
}
@@ -323,7 +313,7 @@ private[sql] class DataFrameImpl protected[sql](
override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
override def repartition(numPartitions: Int): DataFrame = {
- sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+ sqlContext.createDataFrame(rdd.repartition(numPartitions), schema)
}
override def distinct: DataFrame = Distinct(logicalPlan)
@@ -401,7 +391,7 @@ private[sql] class DataFrameImpl protected[sql](
val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
new Iterator[String] {
- override def hasNext() = iter.hasNext
+ override def hasNext = iter.hasNext
override def next(): String = {
JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
gen.flush()
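
The reworded constructor comment earlier in this file has a concrete reading: with eager analysis on (the default), plan errors surface when the DataFrame is built rather than when an action runs. A rough sketch, continuing the earlier example (exact exception type elided; the behavior is deferred if SQLConf.dataFrameEagerAnalysis is turned off):

// select() builds a new DataFrameImpl through the eager constructor, so a bad
// column name fails here, at construction time, not later at collect()/count().
val broken = people.select("noSuchColumn")
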
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 494e49c131..4f9d92d976 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -80,8 +80,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def col(colName: String): Column = err()
- override def apply(projection: Product): DataFrame = err()
-
override def select(cols: Column*): DataFrame = err()
override def select(col: String, cols: String*): DataFrame = err()
@@ -98,8 +96,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def where(condition: Column): DataFrame = err()
- override def apply(condition: Column): DataFrame = err()
-
override def groupBy(cols: Column*): GroupedData = err()
override def groupBy(col1: String, cols: String*): GroupedData = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index fd121ce056..ca5e62f295 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -180,21 +180,59 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
object implicits {
// scalastyle:on
- /**
- * Creates a DataFrame from an RDD of case classes.
- *
- * @group userf
- */
+
+ /** Creates a DataFrame from an RDD of case classes or tuples. */
implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
self.createDataFrame(rdd)
}
- /**
- * Creates a DataFrame from a local Seq of Product.
- */
+ /** Creates a DataFrame from a local Seq of Product. */
implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
self.createDataFrame(data)
}
+
+ // Do NOT add more implicit conversions. They are likely to break source compatibility by
+ // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
+ // because of [[DoubleRDDFunctions]].
+
+ /** Creates a single column DataFrame from an RDD[Int]. */
+ implicit def intRddToDataFrame(data: RDD[Int]): DataFrame = {
+ val dataType = IntegerType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setInt(0, v)
+ row: Row
+ }
+ }
+ self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ }
+
+ /** Creates a single column DataFrame from an RDD[Long]. */
+ implicit def longRddToDataFrame(data: RDD[Long]): DataFrame = {
+ val dataType = LongType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setLong(0, v)
+ row: Row
+ }
+ }
+ self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ }
+
+ /** Creates a single column DataFrame from an RDD[String]. */
+ implicit def stringRddToDataFrame(data: RDD[String]): DataFrame = {
+ val dataType = StringType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setString(0, v)
+ row: Row
+ }
+ }
+ self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ }
}
/**
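
On the "Do NOT add more implicit conversions" note above: element types that deliberately get no implicit of their own, such as RDD[Double], can still be converted through the existing Product conversion with a little ceremony. A sketch only (the column name is arbitrary):

// Wrapping each value in Tuple1 routes the RDD through rddToDataFrame,
// sidestepping any clash with DoubleRDDFunctions on RDD[Double] itself.
val doubleDf = sc.parallelize(Seq(1.5, 2.5, 3.5)).map(Tuple1.apply).toDataFrame("value")
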
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
new file mode 100644
index 0000000000..8fa830dd93
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -0,0 +1,55 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc}
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+
+class DataFrameImplicitsSuite extends QueryTest {
+
+ test("RDD of tuples") {
+ checkAnswer(
+ sc.parallelize(1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+ (1 to 10).map(i => Row(i, i.toString)))
+ }
+
+ test("Seq of tuples") {
+ checkAnswer(
+ (1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+ (1 to 10).map(i => Row(i, i.toString)))
+ }
+
+ test("RDD[Int]") {
+ checkAnswer(
+ sc.parallelize(1 to 10).toDataFrame("intCol"),
+ (1 to 10).map(i => Row(i)))
+ }
+
+ test("RDD[Long]") {
+ checkAnswer(
+ sc.parallelize(1L to 10L).toDataFrame("longCol"),
+ (1L to 10L).map(i => Row(i)))
+ }
+
+ test("RDD[String]") {
+ checkAnswer(
+ sc.parallelize(1 to 10).map(_.toString).toDataFrame("stringCol"),
+ (1 to 10).map(i => Row(i.toString)))
+ }
+}