author    Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
committer Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
commit    e98dfe627c5d0201464cdd0f363f391ea84c389a (patch)
tree      794beea739eb04bf2e0926f9b0e19ffacb94ba08 /sql
parent    0ce4e430a81532dc317136f968f28742e087d840 (diff)
[SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
- The old implicit would convert RDDs directly to DataFrames, and that added too many methods.
- toDataFrame -> toDF
- Dsl -> functions
- implicits moved into SQLContext.implicits
- addColumn -> withColumn
- renameColumn -> withColumnRenamed

Python changes:
- toDataFrame -> toDF
- Dsl -> functions package
- addColumn -> withColumn
- renameColumn -> withColumnRenamed
- add toDF functions to RDD on SQLContext init
- add flatMap to DataFrame

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4556 from rxin/SPARK-5752 and squashes the following commits:

5ef9910 [Reynold Xin] More fix
61d3fca [Reynold Xin] Merge branch 'df5' of github.com:davies/spark into SPARK-5752
ff5832c [Reynold Xin] Fix python
749c675 [Reynold Xin] count(*) fixes.
5806df0 [Reynold Xin] Fix build break again.
d941f3d [Reynold Xin] Fixed explode compilation break.
fe1267a [Davies Liu] flatMap
c4afb8e [Reynold Xin] style
d9de47f [Davies Liu] add comment
b783994 [Davies Liu] add comment for toDF
e2154e5 [Davies Liu] schema() -> schema
3a1004f [Davies Liu] Dsl -> functions, toDF()
fb256af [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
0dd74eb [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
97dd47c [Davies Liu] fix mistake
6168f74 [Davies Liu] fix test
1fc0199 [Davies Liu] fix test
a075cd5 [Davies Liu] clean up, toPandas
663d314 [Davies Liu] add test for agg('*')
9e214d5 [Reynold Xin] count(*) fixes.
1ed7136 [Reynold Xin] Fix build break again.
921b2e3 [Reynold Xin] Fixed explode compilation break.
14698d4 [Davies Liu] flatMap
ba3e12d [Reynold Xin] style
d08c92d [Davies Liu] add comment
5c8b524 [Davies Liu] add comment for toDF
a4e5e66 [Davies Liu] schema() -> schema
d377fc9 [Davies Liu] Dsl -> functions, toDF()
6b3086c [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
807e8b1 [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
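For reference, a minimal sketch of the renamed Scala API described above (assuming a running SparkContext named `sc`; column names and data values are illustrative only):

{{{
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._          // formerly org.apache.spark.sql.Dsl

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._                    // implicits now live on the SQLContext instance

// RDDs are no longer implicitly DataFrames; call toDF() explicitly.
val df = sc.parallelize(Seq((1, "a"), (2, "b"))).toDF("id", "name")

val transformed = df
  .withColumn("idPlusOne", col("id") + 1)        // was addColumn
  .withColumnRenamed("name", "label")            // was renameColumn
transformed.collect()
}}}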
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala) | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala | 2
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala | 6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala | 17
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala | 6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala | 3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala | 10
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala | 6
37 files changed, 248 insertions, 182 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index f959a50564..a7cd4124e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -152,7 +152,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
override lazy val resolved = false
- override def newInstance = this
+ override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9d5d6e78bd..f6ecee1af8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql
-import scala.annotation.tailrec
import scala.language.implicitConversions
-import org.apache.spark.sql.Dsl.lit
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._
@@ -127,7 +126,7 @@ trait Column extends DataFrame {
* df.select( -df("amount") )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.select( negate(col("amount") );
* }}}
*/
@@ -140,7 +139,7 @@ trait Column extends DataFrame {
* df.filter( !df("isActive") )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( not(df.col("isActive")) );
* }}
*/
@@ -153,7 +152,7 @@ trait Column extends DataFrame {
* df.filter( df("colA") === df("colB") )
*
* // Java
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
@@ -168,7 +167,7 @@ trait Column extends DataFrame {
* df.filter( df("colA") === df("colB") )
*
* // Java
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
@@ -182,7 +181,7 @@ trait Column extends DataFrame {
* df.select( !(df("colA") === df("colB")) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").notEqual(col("colB")) );
* }}}
*/
@@ -198,7 +197,7 @@ trait Column extends DataFrame {
* df.select( !(df("colA") === df("colB")) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").notEqual(col("colB")) );
* }}}
*/
@@ -213,7 +212,7 @@ trait Column extends DataFrame {
* people.select( people("age") > 21 )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* people.select( people("age").gt(21) );
* }}}
*/
@@ -228,7 +227,7 @@ trait Column extends DataFrame {
* people.select( people("age") > lit(21) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* people.select( people("age").gt(21) );
* }}}
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4f8f19e2c1..e21e989f36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -48,7 +48,7 @@ private[sql] object DataFrame {
* }}}
*
* Once created, it can be manipulated using the various domain-specific-language (DSL) functions
- * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL.
+ * defined in: [[DataFrame]] (this class), [[Column]], [[functions]] for the DSL.
*
* To select a column from the data frame, use the apply method:
* {{{
@@ -94,27 +94,27 @@ trait DataFrame extends RDDApi[Row] with Serializable {
}
/** Left here for backward compatibility. */
- @deprecated("1.3.0", "use toDataFrame")
+ @deprecated("1.3.0", "use toDF")
def toSchemaRDD: DataFrame = this
/**
* Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
*/
// This is declared with parentheses to prevent the Scala compiler from treating
- // `rdd.toDataFrame("1")` as invoking this toDataFrame and then apply on the returned DataFrame.
- def toDataFrame(): DataFrame = this
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = this
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
- * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2
- * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name"
+ * rdd.toDF // this implicit conversion creates a DataFrame with column name _1 and _2
+ * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name"
* }}}
*/
@scala.annotation.varargs
- def toDataFrame(colNames: String*): DataFrame
+ def toDF(colNames: String*): DataFrame
/** Returns the schema of this [[DataFrame]]. */
def schema: StructType
@@ -132,7 +132,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
def explain(extended: Boolean): Unit
/** Only prints the physical plan to the console for debugging purpose. */
- def explain(): Unit = explain(false)
+ def explain(): Unit = explain(extended = false)
/**
* Returns true if the `collect` and `take` methods can be run locally
@@ -179,11 +179,11 @@ trait DataFrame extends RDDApi[Row] with Serializable {
*
* {{{
* // Scala:
- * import org.apache.spark.sql.dsl._
+ * import org.apache.spark.sql.functions._
* df1.join(df2, "outer", $"df1Key" === $"df2Key")
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df1.join(df2, "outer", col("df1Key") === col("df2Key"));
* }}}
*
@@ -483,12 +483,12 @@ trait DataFrame extends RDDApi[Row] with Serializable {
/**
* Returns a new [[DataFrame]] by adding a column.
*/
- def addColumn(colName: String, col: Column): DataFrame
+ def withColumn(colName: String, col: Column): DataFrame
/**
* Returns a new [[DataFrame]] with a column renamed.
*/
- def renameColumn(existingName: String, newName: String): DataFrame
+ def withColumnRenamed(existingName: String, newName: String): DataFrame
/**
* Returns the first `n` rows.
@@ -520,6 +520,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Returns a new RDD by applying a function to each partition of this DataFrame.
*/
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R]
+
/**
* Applies a function `f` to all rows.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
new file mode 100644
index 0000000000..a3187fe323
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
@@ -0,0 +1,30 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+/**
+ * A container for a [[DataFrame]], used for implicit conversions.
+ */
+private[sql] case class DataFrameHolder(df: DataFrame) {
+
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = df
+
+ def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*)
+}
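Because the implicit conversion now returns this holder, only `toDF` is reachable on an RDD, so building a DataFrame becomes an explicit step. A sketch, assuming the `sc`/`sqlContext` setup shown above:

{{{
import sqlContext.implicits._

val rdd = sc.parallelize(Seq((1, "a"), (2, "b")))

rdd.toDF()              // default column names _1, _2
rdd.toDF("id", "name")  // explicit column names

// rdd.select($"id")    // no longer compiles: the RDD is not implicitly a DataFrame
}}}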
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index bb5c6226a2..7b7efbe347 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -94,7 +94,7 @@ private[sql] class DataFrameImpl protected[sql](
}
}
- override def toDataFrame(colNames: String*): DataFrame = {
+ override def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
"Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
@@ -229,11 +229,11 @@ private[sql] class DataFrameImpl protected[sql](
}: _*)
}
- override def addColumn(colName: String, col: Column): DataFrame = {
+ override def withColumn(colName: String, col: Column): DataFrame = {
select(Column("*"), col.as(colName))
}
- override def renameColumn(existingName: String, newName: String): DataFrame = {
+ override def withColumnRenamed(existingName: String, newName: String): DataFrame = {
val colNames = schema.map { field =>
val name = field.name
if (name == existingName) Column(name).as(newName) else Column(name)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 3c20676355..0868013fe7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql
import scala.language.implicitConversions
import scala.collection.JavaConversions._
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
*/
class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
- private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
@@ -52,7 +52,12 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
case "max" => Max
case "min" => Min
case "sum" => Sum
- case "count" | "size" => Count
+ case "count" | "size" =>
+ // Turn count(*) into count(1)
+ (inputExpr: Expression) => inputExpr match {
+ case s: Star => Count(Literal(1))
+ case _ => Count(inputExpr)
+ }
}
}
@@ -115,17 +120,17 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
* Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
* class, the resulting [[DataFrame]] won't automatically include the grouping columns.
*
- * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+ * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
*
* // Scala:
- * import org.apache.spark.sql.dsl._
+ * import org.apache.spark.sql.functions._
* df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense")));
* }}}
*/
@@ -142,7 +147,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
* Count the number of rows for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
*/
- def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
+ def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
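With the `Star` case above, string-based aggregation over `*` is rewritten to `count(1)` instead of trying to resolve a column literally named `*`. A sketch using the `testData2` values from the new DataFrameSuite test, assuming the same `sc`/`sqlContext` setup:

{{{
import sqlContext.implicits._

val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)).toDF("a", "b")

testData2.groupBy("a").agg(Map("*" -> "count")).collect()
// Row(1, 2), Row(2, 2), Row(3, 2)
}}}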
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index cba3b77011..fc37cfa7a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -50,7 +50,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
protected[sql] override def logicalPlan: LogicalPlan = err()
- override def toDataFrame(colNames: String*): DataFrame = err()
+ override def toDF(colNames: String*): DataFrame = err()
override def schema: StructType = err()
@@ -86,9 +86,9 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def selectExpr(exprs: String*): DataFrame = err()
- override def addColumn(colName: String, col: Column): DataFrame = err()
+ override def withColumn(colName: String, col: Column): DataFrame = err()
- override def renameColumn(existingName: String, newName: String): DataFrame = err()
+ override def withColumnRenamed(existingName: String, newName: String): DataFrame = err()
override def filter(condition: Column): DataFrame = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 2165949d32..a1736d0277 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -183,14 +183,25 @@ class SQLContext(@transient val sparkContext: SparkContext)
object implicits extends Serializable {
// scalastyle:on
+ /** Converts $"col name" into an [[Column]]. */
+ implicit class StringToColumn(val sc: StringContext) {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args :_*))
+ }
+ }
+
+ /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
/** Creates a DataFrame from an RDD of case classes or tuples. */
- implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
- self.createDataFrame(rdd)
+ implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
+ DataFrameHolder(self.createDataFrame(rdd))
}
/** Creates a DataFrame from a local Seq of Product. */
- implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
- self.createDataFrame(data)
+ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
+ {
+ DataFrameHolder(self.createDataFrame(data))
}
// Do NOT add more implicit conversions. They are likely to break source compatibility by
@@ -198,7 +209,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
// because of [[DoubleRDDFunctions]].
/** Creates a single column DataFrame from an RDD[Int]. */
- implicit def intRddToDataFrame(data: RDD[Int]): DataFrame = {
+ implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
val dataType = IntegerType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
@@ -207,11 +218,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
row: Row
}
}
- self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/** Creates a single column DataFrame from an RDD[Long]. */
- implicit def longRddToDataFrame(data: RDD[Long]): DataFrame = {
+ implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
val dataType = LongType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
@@ -220,11 +231,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
row: Row
}
}
- self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/** Creates a single column DataFrame from an RDD[String]. */
- implicit def stringRddToDataFrame(data: RDD[String]): DataFrame = {
+ implicit def stringRddToDataFrame(data: RDD[String]): DataFrameHolder = {
val dataType = StringType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
@@ -233,7 +244,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
row: Row
}
}
- self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}
@@ -780,7 +791,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* indicating if a table is a temporary one or not).
*/
def tables(): DataFrame = {
- createDataFrame(catalog.getTables(None)).toDataFrame("tableName", "isTemporary")
+ createDataFrame(catalog.getTables(None)).toDF("tableName", "isTemporary")
}
/**
@@ -789,7 +800,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* indicating if a table is a temporary one or not).
*/
def tables(databaseName: String): DataFrame = {
- createDataFrame(catalog.getTables(Some(databaseName))).toDataFrame("tableName", "isTemporary")
+ createDataFrame(catalog.getTables(Some(databaseName))).toDF("tableName", "isTemporary")
}
/**
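The `$"..."` string interpolator and the `Symbol` conversion now come from `sqlContext.implicits` rather than `Dsl`, alongside the RDD-to-DataFrameHolder conversions. A sketch, same setup as above:

{{{
import sqlContext.implicits._

val people = sc.parallelize(Seq((0, "mike", 30), (1, "jim", 20))).toDF("id", "name", "age")
people.select($"name", 'age + 1).collect()

// Primitive RDDs also convert through DataFrameHolder:
sc.parallelize(1 to 10).toDF("intCol").collect()
}}}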
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index c60d407094..ee94a5fdbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType
/**
- * A user-defined function. To create one, use the `udf` functions in [[Dsl]].
+ * A user-defined function. To create one, use the `udf` functions in [[functions]].
* As an example:
* {{{
* // Defined a UDF that returns true or false based on some numeric score.
@@ -45,7 +45,7 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
}
/**
- * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]].
+ * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]].
* This is used by Python API.
*/
private[sql] case class UserDefinedPythonFunction(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7bc7683576..4a0ec0b72c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -21,6 +21,7 @@ import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -28,17 +29,9 @@ import org.apache.spark.sql.types._
/**
* Domain specific functions available for [[DataFrame]].
*/
-object Dsl {
-
- /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
- implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
-
- /** Converts $"col name" into an [[Column]]. */
- implicit class StringToColumn(val sc: StringContext) extends AnyVal {
- def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args :_*))
- }
- }
+// scalastyle:off
+object functions {
+// scalastyle:on
private[this] implicit def toColumn(expr: Expression): Column = Column(expr)
@@ -104,7 +97,11 @@ object Dsl {
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
/** Aggregate function: returns the number of items in a group. */
- def count(e: Column): Column = Count(e.expr)
+ def count(e: Column): Column = e.expr match {
+ // Turn count(*) into count(1)
+ case s: Star => Count(Literal(1))
+ case _ => Count(e.expr)
+ }
/** Aggregate function: returns the number of items in a group. */
def count(columnName: String): Column = count(Column(columnName))
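The same rewrite applies to the `count` function itself, so `count("*")` counts rows rather than looking for a column named `*`. A sketch mirroring the new DataFrameSuite test, reusing the `testData2` fixture shape from the GroupedData example:

{{{
import org.apache.spark.sql.functions._
import sqlContext.implicits._

val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)).toDF("a", "b")

testData2.groupBy("a").agg(col("a"), count("*")).collect()
// Row(1, 2), Row(2, 2), Row(3, 2)
}}}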
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 8d3e094e33..538d774eb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -90,7 +90,7 @@ trait ParquetTest {
(f: String => Unit): Unit = {
import sqlContext.implicits._
withTempPath { file =>
- sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath)
+ sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
index 639436368c..05233dc5ff 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
@@ -23,7 +23,7 @@ import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DataTypes;
-import static org.apache.spark.sql.Dsl.*;
+import static org.apache.spark.sql.functions.*;
/**
* This test doesn't actually run anything. It is here to check the API compatibility for Java.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1318750a4a..691dae0a05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -25,8 +25,9 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
@@ -34,8 +35,6 @@ case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
def rddIdOf(tableName: String): Int = {
val executedPlan = table(tableName).queryExecution.executedPlan
executedPlan.collect {
@@ -95,7 +94,7 @@ class CachedTableSuite extends QueryTest {
test("too big for memory") {
val data = "*" * 10000
- sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData")
+ sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData")
table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
assert(table("bigData").count() === 200000L)
table("bigData").unpersist(blocking = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index e3e6f652ed..a63d733ece 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
@@ -68,7 +68,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("collect on column produced by a binary operator") {
- val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
+ val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df("a") + df("b"), Seq(Row(3)))
checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
}
@@ -79,7 +79,7 @@ class ColumnExpressionSuite extends QueryTest {
test("star qualified by data frame object") {
// This is not yet supported.
- val df = testData.toDataFrame
+ val df = testData.toDF
val goldAnswer = df.collect().toSeq
checkAnswer(df.select(df("*")), goldAnswer)
@@ -156,13 +156,13 @@ class ColumnExpressionSuite extends QueryTest {
test("isNull") {
checkAnswer(
- nullStrings.toDataFrame.where($"s".isNull),
+ nullStrings.toDF.where($"s".isNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
}
test("isNotNull") {
checkAnswer(
- nullStrings.toDataFrame.where($"s".isNotNull),
+ nullStrings.toDF.where($"s".isNotNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index 8fa830dd93..2d2367d6e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -25,31 +25,31 @@ class DataFrameImplicitsSuite extends QueryTest {
test("RDD of tuples") {
checkAnswer(
- sc.parallelize(1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+ sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
(1 to 10).map(i => Row(i, i.toString)))
}
test("Seq of tuples") {
checkAnswer(
- (1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+ (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
(1 to 10).map(i => Row(i, i.toString)))
}
test("RDD[Int]") {
checkAnswer(
- sc.parallelize(1 to 10).toDataFrame("intCol"),
+ sc.parallelize(1 to 10).toDF("intCol"),
(1 to 10).map(i => Row(i)))
}
test("RDD[Long]") {
checkAnswer(
- sc.parallelize(1L to 10L).toDataFrame("longCol"),
+ sc.parallelize(1L to 10L).toDF("longCol"),
(1L to 10L).map(i => Row(i)))
}
test("RDD[String]") {
checkAnswer(
- sc.parallelize(1 to 10).map(_.toString).toDataFrame("stringCol"),
+ sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
(1 to 10).map(i => Row(i.toString)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 33b35f376b..f0cd43632e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.TestData._
import scala.language.postfixOps
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
@@ -99,7 +99,7 @@ class DataFrameSuite extends QueryTest {
}
test("simple explode") {
- val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")
+ val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words")
checkAnswer(
df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word),
@@ -108,7 +108,7 @@ class DataFrameSuite extends QueryTest {
}
test("explode") {
- val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
+ val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
val df2 =
df.explode('letters) {
case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
@@ -141,16 +141,31 @@ class DataFrameSuite extends QueryTest {
testData.select('key).collect().toSeq)
}
- test("agg") {
+ test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b")),
- Seq(Row(1,3), Row(2,3), Row(3,3))
+ Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
+ testData2.groupBy("a").agg(col("a"), count("*")),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+ )
+ checkAnswer(
+ testData2.groupBy("a").agg(Map("*" -> "count")),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+ )
+ checkAnswer(
+ testData2.groupBy("a").agg(Map("b" -> "sum")),
+ Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+ )
+ }
+
+ test("agg without groups") {
+ checkAnswer(
testData2.agg(sum('b)),
Row(9)
)
@@ -218,20 +233,20 @@ class DataFrameSuite extends QueryTest {
Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
- arrayData.orderBy('data.getItem(0).asc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+ arrayData.toDF.orderBy('data.getItem(0).asc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(0).desc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+ arrayData.toDF.orderBy('data.getItem(0).desc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(1).asc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+ arrayData.toDF.orderBy('data.getItem(1).asc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(1).desc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
+ arrayData.toDF.orderBy('data.getItem(1).desc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
@@ -240,11 +255,11 @@ class DataFrameSuite extends QueryTest {
testData.take(10).toSeq)
checkAnswer(
- arrayData.limit(1),
+ arrayData.toDF.limit(1),
arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
- mapData.limit(1),
+ mapData.toDF.limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
@@ -378,7 +393,7 @@ class DataFrameSuite extends QueryTest {
}
test("addColumn") {
- val df = testData.toDataFrame.addColumn("newCol", col("key") + 1)
+ val df = testData.toDF.withColumn("newCol", col("key") + 1)
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
@@ -388,8 +403,8 @@ class DataFrameSuite extends QueryTest {
}
test("renameColumn") {
- val df = testData.toDataFrame.addColumn("newCol", col("key") + 1)
- .renameColumn("value", "valueRenamed")
+ val df = testData.toDF.withColumn("newCol", col("key") + 1)
+ .withColumnRenamed("value", "valueRenamed")
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f0c939dbb1..fd73065c4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
class JoinSuite extends QueryTest with BeforeAndAfterEach {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 5fc35349e1..282b98a987 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -28,7 +28,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.test.TestSQLContext.implicits._
val df =
- sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value")
before {
df.registerTempTable("ListTablesSuiteTable")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a1c8cf58f2..97684f75e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
@@ -1034,10 +1034,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Supporting relational operator '<=>' in Spark SQL") {
val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil
val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
- rdd1.registerTempTable("nulldata1")
+ rdd1.toDF.registerTempTable("nulldata1")
val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil
val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
- rdd2.registerTempTable("nulldata2")
+ rdd2.toDF.registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
(1 to 2).map(i => Row(i)))
@@ -1046,7 +1046,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil
val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
- rdd.registerTempTable("distinctData")
+ rdd.toDF.registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 9378261982..9a48f8d063 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectData")
+ rdd.toDF.registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
@@ -93,7 +93,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with nulls") {
val data = NullReflectData(null, null, null, null, null, null, null)
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectNullData")
+ rdd.toDF.registerTempTable("reflectNullData")
assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
@@ -101,7 +101,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with Nones") {
val data = OptionalReflectData(None, None, None, None, None, None, None)
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectOptionalData")
+ rdd.toDF.registerTempTable("reflectOptionalData")
assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
@@ -109,7 +109,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
// Equality is broken for Arrays, so we test that separately.
test("query binary data") {
val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
- rdd.registerTempTable("reflectBinary")
+ rdd.toDF.registerTempTable("reflectBinary")
val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
@@ -128,7 +128,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None),
Nested(None, "abc")))
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectComplexData")
+ rdd.toDF.registerTempTable("reflectComplexData")
assert(sql("SELECT * FROM reflectComplexData").collect().head ===
new GenericRow(Array[Any](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 0ed437edd0..c511eb1469 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test._
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -29,11 +29,11 @@ case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toDataFrame
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
@@ -44,7 +44,7 @@ object TestData {
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toDataFrame
+ LargeAndSmallInts(3, 2) :: Nil).toDF
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
@@ -55,7 +55,7 @@ object TestData {
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toDataFrame
+ TestData2(3, 2) :: Nil, 2).toDF
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -67,7 +67,7 @@ object TestData {
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toDataFrame
+ DecimalData(3, 2) :: Nil).toDF
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
@@ -77,14 +77,14 @@ object TestData {
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toDataFrame
+ BinaryData("123".getBytes(), 4) :: Nil).toDF
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toDataFrame
+ TestData3(2, Some(2)) :: Nil).toDF
testData3.registerTempTable("testData3")
val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
@@ -97,7 +97,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toDataFrame
+ UpperCaseData(6, "F") :: Nil).toDF
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -106,7 +106,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toDataFrame
+ LowerCaseData(4, "d") :: Nil).toDF
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -114,7 +114,7 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) ::
ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
- arrayData.registerTempTable("arrayData")
+ arrayData.toDF.registerTempTable("arrayData")
case class MapData(data: scala.collection.Map[Int, String])
val mapData =
@@ -124,18 +124,18 @@ object TestData {
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
- mapData.registerTempTable("mapData")
+ mapData.toDF.registerTempTable("mapData")
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
- repeatedData.registerTempTable("repeatedData")
+ repeatedData.toDF.registerTempTable("repeatedData")
val nullableRepeatedData =
TestSQLContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
- nullableRepeatedData.registerTempTable("nullableRepeatedData")
+ nullableRepeatedData.toDF.registerTempTable("nullableRepeatedData")
case class NullInts(a: Integer)
val nullInts =
@@ -144,7 +144,7 @@ object TestData {
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
- )
+ ).toDF
nullInts.registerTempTable("nullInts")
val allNulls =
@@ -152,7 +152,7 @@ object TestData {
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
- NullInts(null) :: Nil)
+ NullInts(null) :: Nil).toDF
allNulls.registerTempTable("allNulls")
case class NullStrings(n: Int, s: String)
@@ -160,11 +160,11 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
- NullStrings(3, null) :: Nil).toDataFrame
+ NullStrings(3, null) :: Nil).toDF
nullStrings.registerTempTable("nullStrings")
case class TableName(tableName: String)
- TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName")
+ TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).toDF.registerTempTable("tableName")
val unparsedStrings =
TestSQLContext.sparkContext.parallelize(
@@ -177,22 +177,22 @@ object TestData {
val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i =>
TimestampField(new Timestamp(i))
})
- timestamps.registerTempTable("timestamps")
+ timestamps.toDF.registerTempTable("timestamps")
case class IntField(i: Int)
// An RDD with 4 elements and 8 partitions
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
- withEmptyParts.registerTempTable("withEmptyParts")
+ withEmptyParts.toDF.registerTempTable("withEmptyParts")
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
val person = TestSQLContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
- Person(1, "jim", 20) :: Nil)
+ Person(1, "jim", 20) :: Nil).toDF
person.registerTempTable("person")
val salary = TestSQLContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
- Salary(1, 1000.0) :: Nil)
+ Salary(1, 1000.0) :: Nil).toDF
salary.registerTempTable("salary")
case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
@@ -200,6 +200,6 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
- :: Nil).toDataFrame
+ :: Nil).toDF
complexData.registerTempTable("complexData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 95923f9aad..be105c6e83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql
-import org.apache.spark.sql.Dsl.StringToColumn
import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
+import TestSQLContext.implicits._
case class FunctionResult(f1: String, f2: String)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 3c1657cd5f..5f21d990e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -66,7 +66,7 @@ class UserDefinedTypeSuite extends QueryTest {
val points = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
- val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
+ val pointsRDD = sparkContext.parallelize(points).toDF()
test("register user type: MyDenseVector for MyLabeledPoint") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 86b1b5fda1..38b0f666ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.columnar
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -28,8 +29,6 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
test("simple columnar query") {
val plan = executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -39,7 +38,8 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
- sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst")
+ sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ .toDF().registerTempTable("sizeTst")
cacheTable("sizeTst")
assert(
table("sizeTst").queryExecution.logical.statistics.sizeInBytes >
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 55a9f735b3..e57bb06e72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -21,13 +21,12 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
@@ -35,7 +34,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
- }, 5)
+ }, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c3210733f1..523be56df6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
import org.apache.spark.sql.{SQLConf, execution}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index b5f13f8bd5..c94e44bd7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -822,7 +823,7 @@ class JsonSuite extends QueryTest {
val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
- val df2 = df1.toDataFrame
+ val df2 = df1.toDF
val result = df2.toJSON.collect()
assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
@@ -843,7 +844,7 @@ class JsonSuite extends QueryTest {
val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
- val df4 = df3.toDataFrame
+ val df4 = df3.toDF
val result2 = df4.toJSON.collect()
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
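On an existing DataFrame the rename is purely cosmetic: the old toDataFrame becomes toDF, which with no arguments returns the frame unchanged and with arguments relabels the columns positionally. A brief sketch under the same local-context assumption, with made-up data:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object ToDFRenameSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "toDF-rename-sketch")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df1 = sc.parallelize(Seq((1, "A1"), (2, "B2"))).toDF("f1", "f2")

    // Formerly df1.toDataFrame: with no arguments this returns the frame as-is.
    val df2 = df1.toDF
    // The variant with arguments renames the columns by position.
    val relabeled = df1.toDF("id", "label")

    df2.printSchema()
    relabeled.printSchema()

    sc.stop()
  }
}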
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index c8ebbbc7d2..c306330818 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -33,11 +33,12 @@ import parquet.schema.{MessageType, MessageTypeParser}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@@ -64,6 +65,7 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuite extends QueryTest with ParquetTest {
+
val sqlContext = TestSQLContext
/**
@@ -99,12 +101,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
}
test(s"$prefix: fixed-length decimals") {
- import org.apache.spark.sql.test.TestSQLContext.implicits._
def makeDecimalRDD(decimal: DecimalType): DataFrame =
sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
+ .toDF
// Parquet doesn't allow column names with spaces, have to add an alias here
.select($"_1" cast decimal as "dec")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 89b18c3439..9fcb04ca23 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -37,7 +37,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val testData = TestHive.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString)))
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
before {
// Since every time we are doing tests for DDL statements,
@@ -56,7 +56,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
// Make sure the table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq.map(Row.fromTuple)
+ testData.collect().toSeq
)
// Add more data.
@@ -65,7 +65,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq
+ testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq
)
// Now overwrite.
@@ -74,7 +74,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
// Make sure the registered table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq.map(Row.fromTuple)
+ testData.collect().toSeq
)
}
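The checkAnswer edits above follow from the same rename: testData is now a DataFrame rather than an RDD of case classes, so collect() already yields Rows and the old Row.fromTuple mapping is unnecessary. A hedged sketch of that difference, using a hypothetical KV case class:

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SQLContext}

// Hypothetical case class mirroring the shape of TestData in the suites.
case class KV(key: Int, value: String)

object CollectRowsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "collect-rows-sketch")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val rdd = sc.parallelize((1 to 3).map(i => KV(i, i.toString)))

    // Old shape: collect() on the RDD returns case-class instances, which is
    // why the original assertions mapped them through Row.fromTuple.
    val asCaseClasses: Array[KV] = rdd.collect()

    // New shape: collect() on the DataFrame returns Rows directly.
    val asRows: Array[Row] = rdd.toDF().collect()

    println(asCaseClasses.toSeq)
    println(asRows.toSeq)

    sc.stop()
  }
}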
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
index 068aa03330..321b784a3f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -29,7 +29,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val df =
- sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value")
override def beforeAll(): Unit = {
// The catalog in HiveContext is a case insensitive one.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 2916724f66..addf887ab9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -28,17 +28,14 @@ import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql._
import org.apache.spark.util.Utils
import org.apache.spark.sql.types._
-
-/* Implicits */
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
/**
* Tests for persisting tables created though the data sources API into the metastore.
*/
class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
-
override def afterEach(): Unit = {
reset()
if (tempPath.exists()) Utils.deleteRecursively(tempPath)
@@ -154,7 +151,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
test("check change without refresh") {
val tempDir = File.createTempFile("sparksql", "json")
tempDir.delete()
- sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql(
s"""
@@ -170,7 +168,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
- sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
// Schema is cached so the new column does not show. The updated values in existing columns
// will show.
@@ -190,7 +189,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
test("drop, change, recreate") {
val tempDir = File.createTempFile("sparksql", "json")
tempDir.delete()
- sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql(
s"""
@@ -206,7 +206,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
- sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b", "c") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql("DROP TABLE jsonTable")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 405b200d05..d01dbf80ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -567,7 +567,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(2, "str2") :: Nil)
- testData.registerTempTable("REGisteredTABle")
+ testData.toDF.registerTempTable("REGisteredTABle")
assertResult(Array(Row(2, "str2"))) {
sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
@@ -592,7 +592,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") {
val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3))
.zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)}
- TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test")
+ TestHive.sparkContext.parallelize(fixture).toDF.registerTempTable("having_test")
val results =
sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3")
.collect()
@@ -740,7 +740,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(1, "str2") :: Nil)
- testData.registerTempTable("test_describe_commands2")
+ testData.toDF.registerTempTable("test_describe_commands2")
assertResult(
Array(
@@ -900,8 +900,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") {
- sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs")
- sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles")
+ sparkContext.makeRDD(Seq.empty[LogEntry]).toDF.registerTempTable("rawLogs")
+ sparkContext.makeRDD(Seq.empty[LogFile]).toDF.registerTempTable("logFiles")
sql(
"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 029c36aa89..6fc4cc1426 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
test("case insensitivity with scala reflection") {
// Test resolution with Scala Reflection
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("caseSensitivityTest")
+ .toDF.registerTempTable("caseSensitivityTest")
val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"),
@@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest {
ignore("case insensitivity with scala reflection joins") {
// Test resolution with Scala Reflection
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("caseSensitivityTest")
+ .toDF.registerTempTable("caseSensitivityTest")
sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect()
}
test("nested repeated resolution") {
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("nestedRepeatedTest")
+ .toDF.registerTempTable("nestedRepeatedTest")
assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 8fb5e050a2..ab53c6309e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.Row
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.util.Utils
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 1e99003d3e..245161d2eb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -111,7 +111,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFIntegerToString") {
val testData = TestHive.sparkContext.parallelize(
- IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
+ IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF
testData.registerTempTable("integerTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
@@ -127,7 +127,7 @@ class HiveUdfSuite extends QueryTest {
val testData = TestHive.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
- ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil)
+ ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF
testData.registerTempTable("listListIntTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
@@ -142,7 +142,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFListString") {
val testData = TestHive.sparkContext.parallelize(
ListStringCaseClass(Seq("a", "b", "c")) ::
- ListStringCaseClass(Seq("d", "e")) :: Nil)
+ ListStringCaseClass(Seq("d", "e")) :: Nil).toDF
testData.registerTempTable("listStringTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
@@ -156,7 +156,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFStringString") {
val testData = TestHive.sparkContext.parallelize(
- StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil)
+ StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF
testData.registerTempTable("stringTable")
sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
@@ -173,7 +173,7 @@ class HiveUdfSuite extends QueryTest {
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
- Nil)
+ Nil).toDF
testData.registerTempTable("TwoListTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 9a6e8650a0..9788259383 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -34,9 +35,6 @@ case class Nested3(f3: Int)
*/
class SQLQuerySuite extends QueryTest {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
- val sqlCtx = TestHive
-
test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),
@@ -176,7 +174,8 @@ class SQLQuerySuite extends QueryTest {
}
test("double nested data") {
- sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
+ sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil)
+ .toDF().registerTempTable("nested")
checkAnswer(
sql("SELECT f1.f2.f3 FROM nested"),
Row(1))
@@ -199,7 +198,7 @@ class SQLQuerySuite extends QueryTest {
}
test("SPARK-4825 save join to table") {
- val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
sql("CREATE TABLE test1 (key INT, value STRING)")
testData.insertInto("test1")
sql("CREATE TABLE test2 (key INT, value STRING)")
@@ -279,7 +278,7 @@ class SQLQuerySuite extends QueryTest {
val rowRdd = sparkContext.parallelize(row :: Nil)
- sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable")
+ TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index a7479a5b95..e246cbb6d7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.hive.execution.HiveTableScan
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@@ -152,7 +154,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
var normalTableDir: File = null
var partitionedTableDirWithKey: File = null
- import org.apache.spark.sql.hive.test.TestHive.implicits._
override def beforeAll(): Unit = {
partitionedTableDir = File.createTempFile("parquettests", "sparksql")
@@ -167,12 +168,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
val partDir = new File(partitionedTableDir, s"p=$p")
sparkContext.makeRDD(1 to 10)
.map(i => ParquetData(i, s"part-$p"))
+ .toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
sparkContext
.makeRDD(1 to 10)
.map(i => ParquetData(i, s"part-1"))
+ .toDF()
.saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath)
partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql")
@@ -183,6 +186,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
val partDir = new File(partitionedTableDirWithKey, s"p=$p")
sparkContext.makeRDD(1 to 10)
.map(i => ParquetDataWithKey(p, i, s"part-$p"))
+ .toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
}
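saveAsParquetFile is likewise a DataFrame method, hence the explicit toDF() inserted before each save above. A short sketch of that pattern, not part of the patch, with a hypothetical ParquetRecord case class and a throwaway temp directory:

import java.io.File

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Hypothetical stand-in for the ParquetData case class used by the suites.
case class ParquetRecord(intField: Int, stringField: String)

object SaveParquetSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "parquet-sketch")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Same temp-dir idiom as the suites: create a file, delete it, reuse the path.
    val tempDir = File.createTempFile("parquetsketch", "data")
    tempDir.delete()

    // saveAsParquetFile lives on DataFrame, so the mapped RDD needs toDF() first.
    sc.makeRDD(1 to 10)
      .map(i => ParquetRecord(i, s"part-$i"))
      .toDF()
      .saveAsParquetFile(tempDir.getCanonicalPath)

    sc.stop()
  }
}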