author     Reynold Xin <rxin@databricks.com>  2015-02-13 23:03:22 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-02-13 23:03:22 -0800
commit     e98dfe627c5d0201464cdd0f363f391ea84c389a (patch)
tree       794beea739eb04bf2e0926f9b0e19ffacb94ba08
parent     0ce4e430a81532dc317136f968f28742e087d840 (diff)
[SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
- The old implicit would convert RDDs directly to DataFrames, and that added too many methods.
- toDataFrame -> toDF
- Dsl -> functions
- implicits moved into SQLContext.implicits
- addColumn -> withColumn
- renameColumn -> withColumnRenamed

Python changes:
- toDataFrame -> toDF
- Dsl -> functions package
- addColumn -> withColumn
- renameColumn -> withColumnRenamed
- add toDF functions to RDD on SQLContext init
- add flatMap to DataFrame

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4556 from rxin/SPARK-5752 and squashes the following commits:

5ef9910 [Reynold Xin] More fix
61d3fca [Reynold Xin] Merge branch 'df5' of github.com:davies/spark into SPARK-5752
ff5832c [Reynold Xin] Fix python
749c675 [Reynold Xin] count(*) fixes.
5806df0 [Reynold Xin] Fix build break again.
d941f3d [Reynold Xin] Fixed explode compilation break.
fe1267a [Davies Liu] flatMap
c4afb8e [Reynold Xin] style
d9de47f [Davies Liu] add comment
b783994 [Davies Liu] add comment for toDF
e2154e5 [Davies Liu] schema() -> schema
3a1004f [Davies Liu] Dsl -> functions, toDF()
fb256af [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
0dd74eb [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
97dd47c [Davies Liu] fix mistake
6168f74 [Davies Liu] fix test
1fc0199 [Davies Liu] fix test
a075cd5 [Davies Liu] clean up, toPandas
663d314 [Davies Liu] add test for agg('*')
9e214d5 [Reynold Xin] count(*) fixes.
1ed7136 [Reynold Xin] Fix build break again.
921b2e3 [Reynold Xin] Fixed explode compilation break.
14698d4 [Davies Liu] flatMap
ba3e12d [Reynold Xin] style
d08c92d [Davies Liu] add comment
5c8b524 [Davies Liu] add comment for toDF
a4e5e66 [Davies Liu] schema() -> schema
d377fc9 [Davies Liu] Dsl -> functions, toDF()
6b3086c [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
807e8b1 [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
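To make the renames concrete, the following is a minimal Scala sketch (not part of the commit) of what user code looks like after this change. It uses only APIs that appear in the diff below (sqlContext.implicits, toDF, org.apache.spark.sql.functions._, withColumn, withColumnRenamed); the Record case class and column names mirror the RDDRelation example, and MigrationSketch is an illustrative wrapper.

// Sketch of post-SPARK-5752 user code, assuming a Spark 1.3-era API.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._   // previously: import org.apache.spark.sql.Dsl._

case class Record(key: Int, value: String)

object MigrationSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MigrationSketch"))
    val sqlContext = new SQLContext(sc)
    // The RDD-to-DataFrame conversion now lives in SQLContext.implicits
    // and must be invoked explicitly via toDF.
    import sqlContext.implicits._

    val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF

    // addColumn -> withColumn, renameColumn -> withColumnRenamed
    val extended = df
      .withColumn("keyPlusOne", col("key") + 1)
      .withColumnRenamed("value", "label")

    extended.where($"key" === 1).select($"key", $"keyPlusOne").collect().foreach(println)
    sc.stop()
  }
}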
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala  6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala  6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala  8
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala  10
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala  16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala  33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala  4
-rw-r--r--  python/docs/pyspark.sql.rst  8
-rw-r--r--  python/pyspark/mllib/tests.py  2
-rw-r--r--  python/pyspark/sql/__init__.py  3
-rw-r--r--  python/pyspark/sql/context.py  34
-rw-r--r--  python/pyspark/sql/dataframe.py  221
-rw-r--r--  python/pyspark/sql/functions.py  170
-rw-r--r--  python/pyspark/sql/tests.py  38
-rwxr-xr-x  python/run-tests  3
-rw-r--r--  repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala  2
-rw-r--r--  repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala  2
-rw-r--r--  repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala  21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala  25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala  30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala  19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala)  21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala  2
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala  46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala  6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala  8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala  2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala  17
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala  12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala  6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala  3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala  10
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala  11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala  6
70 files changed, 596 insertions, 456 deletions
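The diffstat lists a new DataFrameHolder.scala alongside the SQLContext changes, which is the mechanism the commit message describes: instead of an implicit that turns an RDD straight into a DataFrame (leaking every DataFrame method onto RDDs), the implicit produces a small holder that exposes only an explicit toDF call. The following standalone Scala sketch illustrates that general pattern only; it is not Spark's actual code, and MyDataFrame, MyDataFrameHolder, toMyDF, and Implicits are made-up names for illustration.

// Sketch of the "holder" pattern behind an opt-in toDF-style conversion.
class MyDataFrame(val rows: Seq[Map[String, Any]]) {
  def select(col: String): Seq[Any] = rows.map(_(col))
}

// The wrapper exposes exactly one method, so nothing else leaks onto the source type.
class MyDataFrameHolder(rows: Seq[Map[String, Any]]) {
  def toMyDF: MyDataFrame = new MyDataFrame(rows)
}

object Implicits {
  import scala.language.implicitConversions
  implicit def seqToHolder(rows: Seq[Map[String, Any]]): MyDataFrameHolder =
    new MyDataFrameHolder(rows)
}

object HolderPatternDemo extends App {
  import Implicits._
  val data = Seq(Map("key" -> 1, "value" -> "val_1"), Map("key" -> 2, "value" -> "val_2"))
  // The conversion is still implicit, but the user opts in with .toMyDF;
  // plain Seq operations are unaffected.
  val df = data.toMyDF
  println(df.select("value"))   // List(val_1, val_2)
}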
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index a2893f78e0..f0241943ef 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -90,7 +90,7 @@ object CrossValidatorExample {
crossval.setNumFolds(2) // Use 3+ in practice
// Run cross-validation, and choose the best set of parameters.
- val cvModel = crossval.fit(training)
+ val cvModel = crossval.fit(training.toDF)
// Prepare test documents, which are unlabeled.
val test = sc.parallelize(Seq(
@@ -100,7 +100,7 @@ object CrossValidatorExample {
Document(7L, "apache hadoop")))
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- cvModel.transform(test)
+ cvModel.transform(test.toDF)
.select("id", "text", "probability", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index aed4423893..54aadd2288 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -58,7 +58,7 @@ object DeveloperApiExample {
lr.setMaxIter(10)
// Learn a LogisticRegression model. This uses the parameters stored in lr.
- val model = lr.fit(training)
+ val model = lr.fit(training.toDF)
// Prepare test data.
val test = sc.parallelize(Seq(
@@ -67,7 +67,7 @@ object DeveloperApiExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
// Make predictions on test data.
- val sumPredictions: Double = model.transform(test)
+ val sumPredictions: Double = model.transform(test.toDF)
.select("features", "label", "prediction")
.collect()
.map { case Row(features: Vector, label: Double, prediction: Double) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
index 836ea2e012..adaf796dc1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
@@ -137,9 +137,9 @@ object MovieLensALS {
.setRegParam(params.regParam)
.setNumBlocks(params.numBlocks)
- val model = als.fit(training)
+ val model = als.fit(training.toDF)
- val predictions = model.transform(test).cache()
+ val predictions = model.transform(test.toDF).cache()
// Evaluate the model.
// TODO: Create an evaluator to compute RMSE.
@@ -158,7 +158,7 @@ object MovieLensALS {
// Inspect false positives.
predictions.registerTempTable("prediction")
- sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
+ sc.textFile(params.movies).map(Movie.parseMovie).toDF.registerTempTable("movie")
sqlContext.sql(
"""
|SELECT userId, prediction.movieId, title, rating, prediction
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 80c9f5ff57..c5bb5515b1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -58,7 +58,7 @@ object SimpleParamsExample {
.setRegParam(0.01)
// Learn a LogisticRegression model. This uses the parameters stored in lr.
- val model1 = lr.fit(training)
+ val model1 = lr.fit(training.toDF)
// Since model1 is a Model (i.e., a Transformer produced by an Estimator),
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
@@ -77,7 +77,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
- val model2 = lr.fit(training, paramMapCombined)
+ val model2 = lr.fit(training.toDF, paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
// Prepare test data.
@@ -90,7 +90,7 @@ object SimpleParamsExample {
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
- model2.transform(test)
+ model2.transform(test.toDF)
.select("features", "label", "myProbability", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index 968cb29212..8b47f88e48 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -69,7 +69,7 @@ object SimpleTextClassificationPipeline {
.setStages(Array(tokenizer, hashingTF, lr))
// Fit the pipeline to training documents.
- val model = pipeline.fit(training)
+ val model = pipeline.fit(training.toDF)
// Prepare test documents, which are unlabeled.
val test = sc.parallelize(Seq(
@@ -79,7 +79,7 @@ object SimpleTextClassificationPipeline {
Document(7L, "apache hadoop")))
// Make predictions on test documents.
- model.transform(test)
+ model.transform(test.toDF)
.select("id", "text", "probability", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
index 89b6255991..c98c68a02f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -81,18 +81,18 @@ object DatasetExample {
println(s"Loaded ${origData.count()} instances from file: ${params.input}")
// Convert input data to DataFrame explicitly.
- val df: DataFrame = origData.toDataFrame
+ val df: DataFrame = origData.toDF
println(s"Inferred schema:\n${df.schema.prettyJson}")
println(s"Converted to DataFrame with ${df.count()} records")
- // Select columns, using implicit conversion to DataFrames.
- val labelsDf: DataFrame = origData.select("label")
+ // Select columns
+ val labelsDf: DataFrame = df.select("label")
val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
val numLabels = labels.count()
val meanLabel = labels.fold(0.0)(_ + _) / numLabels
println(s"Selected label column with average value $meanLabel")
- val featuresDf: DataFrame = origData.select("features")
+ val featuresDf: DataFrame = df.select("features")
val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index 1eac3c8d03..79d3d5a24c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -19,7 +19,7 @@ package org.apache.spark.examples.sql
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
// One method for defining the schema of an RDD is to make a case class with the desired column
// names and types.
@@ -34,10 +34,10 @@ object RDDRelation {
// Importing the SQL context gives access to all the SQL functions and implicit conversions.
import sqlContext.implicits._
- val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
+ val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF
// Any RDD containing case classes can be registered as a table. The schema of the table is
// automatically inferred using scala reflection.
- rdd.registerTempTable("records")
+ df.registerTempTable("records")
// Once tables have been registered, you can run SQL queries over them.
println("Result of SELECT *:")
@@ -55,10 +55,10 @@ object RDDRelation {
rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println)
// Queries can also be written using a LINQ-like Scala DSL.
- rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
+ df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
// Write out an RDD as a parquet file.
- rdd.saveAsParquetFile("pair.parquet")
+ df.saveAsParquetFile("pair.parquet")
// Read in parquet file. Parquet files are self-describing so the schmema is preserved.
val parquetFile = sqlContext.parquetFile("pair.parquet")
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index 15754cdfcc..7128deba54 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -68,7 +68,7 @@ object HiveFromSpark {
// You can also register RDDs as temporary tables within a HiveContext.
val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
- rdd.registerTempTable("records")
+ rdd.toDF.registerTempTable("records")
// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 2ec2ccdb8c..9a5848684b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
@@ -100,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
- dataset.select($"*", callUDF(
- this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
+ dataset.withColumn(map(outputCol),
+ callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 124ab30f27..c5fc89f935 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -182,24 +182,22 @@ private[ml] object ClassificationModel {
if (map(model.rawPredictionCol) != "") {
// output raw prediction
val features2raw: FeaturesType => Vector = model.predictRaw
- tmpData = tmpData.select($"*",
- callUDF(features2raw, new VectorUDT,
- col(map(model.featuresCol))).as(map(model.rawPredictionCol)))
+ tmpData = tmpData.withColumn(map(model.rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col(map(model.featuresCol))))
numColsOutput += 1
if (map(model.predictionCol) != "") {
val raw2pred: Vector => Double = (rawPred) => {
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
}
- tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
- col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+ tmpData = tmpData.withColumn(map(model.predictionCol),
+ callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol))))
numColsOutput += 1
}
} else if (map(model.predictionCol) != "") {
// output prediction
val features2pred: FeaturesType => Double = model.predict
- tmpData = tmpData.select($"*",
- callUDF(features2pred, DoubleType,
- col(map(model.featuresCol))).as(map(model.predictionCol)))
+ tmpData = tmpData.withColumn(map(model.predictionCol),
+ callUDF(features2pred, DoubleType, col(map(model.featuresCol))))
numColsOutput += 1
}
(numColsOutput, tmpData)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a9a5af5f0f..21f61d80dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
@@ -130,44 +130,39 @@ class LogisticRegressionModel private[ml] (
var numColsOutput = 0
if (map(rawPredictionCol) != "") {
val features2raw: Vector => Vector = (features) => predictRaw(features)
- tmpData = tmpData.select($"*",
- callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+ tmpData = tmpData.withColumn(map(rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col(map(featuresCol))))
numColsOutput += 1
}
if (map(probabilityCol) != "") {
if (map(rawPredictionCol) != "") {
- val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
+ val raw2prob = udf { (rawPreds: Vector) =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
- Vectors.dense(1.0 - prob1, prob1)
+ Vectors.dense(1.0 - prob1, prob1): Vector
}
- tmpData = tmpData.select($"*",
- callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+ tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol))))
} else {
- val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
- tmpData = tmpData.select($"*",
- callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
+ tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol))))
}
numColsOutput += 1
}
if (map(predictionCol) != "") {
val t = map(threshold)
if (map(probabilityCol) != "") {
- val predict: Vector => Double = { probs: Vector =>
+ val predict = udf { probs: Vector =>
if (probs(1) > t) 1.0 else 0.0
}
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol))))
} else if (map(rawPredictionCol) != "") {
- val predict: Vector => Double = { rawPreds: Vector =>
+ val predict = udf { rawPreds: Vector =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
if (prob1 > t) 1.0 else 0.0
}
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol))))
} else {
- val predict: Vector => Double = (features: Vector) => this.predict(features)
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ val predict = udf { features: Vector => this.predict(features) }
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol))))
}
numColsOutput += 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 38518785dc..bd8caac855 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -122,8 +122,8 @@ private[spark] abstract class ProbabilisticClassificationModel[
val features2probs: FeaturesType => Vector = (features) => {
tmpModel.predictProbabilities(features)
}
- outputData.select($"*",
- callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ outputData.withColumn(map(probabilityCol),
+ callUDF(features2probs, new VectorUDT, col(map(featuresCol))))
} else {
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 7623ec59ae..ddbd648d64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
/**
@@ -88,7 +88,7 @@ class StandardScalerModel private[ml] (
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
- dataset.select($"*", scale(col(map(inputCol))).as(map(outputCol)))
+ dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index e416c1eb58..7daeff980f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -216,7 +216,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
val pred: FeaturesType => Double = (features) => {
tmpModel.predict(features)
}
- dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol))))
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index aac487745f..8d70e4347c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -36,7 +36,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -170,8 +170,8 @@ class ALSModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext.implicits._
val map = this.paramMap ++ paramMap
- val users = userFactors.toDataFrame("id", "features")
- val items = itemFactors.toDataFrame("id", "features")
+ val users = userFactors.toDF("id", "features")
+ val items = itemFactors.toDF("id", "features")
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f9142bc226..dd7a9469d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -102,7 +102,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
// Create Parquet data.
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF
dataRDD.saveAsParquetFile(dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 1d118963b4..0a358f2e4f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel {
// Create Parquet data.
val data = Data(weights, intercept, threshold)
- sc.parallelize(Seq(data), 1).saveAsParquetFile(Loader.dataPath(path))
+ sc.parallelize(Seq(data), 1).toDF.saveAsParquetFile(Loader.dataPath(path))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index a3a3b5d418..c399496568 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -187,8 +187,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
- model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
- model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+ model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path))
+ model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path))
}
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index f75de6f637..7b27aaa322 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -58,7 +58,7 @@ private[regression] object GLMRegressionModel {
// Create Parquet data.
val data = Data(weights, intercept)
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 373192a20c..5dac62b0c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -197,7 +197,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
val nodes = model.topNode.subtreeIterator.toSeq
val dataRDD: DataFrame = sc.parallelize(nodes)
.map(NodeData.apply(0, _))
- .toDataFrame
+ .toDF
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index dbd69dca60..e507f247cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -289,7 +289,7 @@ private[tree] object TreeEnsembleModel {
// Create Parquet data.
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
- }.toDataFrame
+ }.toDF
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index cb7d57de35..b118a8dcf1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -358,8 +358,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
val alpha = als.getAlpha
- val model = als.fit(training)
- val predictions = model.transform(test)
+ val model = als.fit(training.toDF)
+ val predictions = model.transform(test.toDF)
.select("rating", "prediction")
.map { case Row(rating: Float, prediction: Float) =>
(rating.toDouble, prediction.toDouble)
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index 80c6f02a9d..e03379e521 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -16,3 +16,11 @@ pyspark.sql.types module
:members:
:undoc-members:
:show-inheritance:
+
+
+pyspark.sql.functions module
+------------------------
+.. automodule:: pyspark.sql.functions
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 49e5c9d58e..06207a076e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -335,7 +335,7 @@ class VectorUDTTests(PySparkTestCase):
sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
srdd = sqlCtx.inferSchema(rdd)
- schema = srdd.schema()
+ schema = srdd.schema
field = [f for f in schema.fields if f.name == "features"][0]
self.assertEqual(field.dataType, self.udt)
vectors = srdd.map(lambda p: p.features).collect()
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 0a5ba00393..b9ffd6945e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -34,9 +34,8 @@ public classes of Spark SQL:
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.types import Row
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
- 'Dsl',
]
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 082f1b691b..7683c1b4df 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -38,6 +38,25 @@ except ImportError:
__all__ = ["SQLContext", "HiveContext"]
+def _monkey_patch_RDD(sqlCtx):
+ def toDF(self, schema=None, sampleRatio=None):
+ """
+ Convert current :class:`RDD` into a :class:`DataFrame`
+
+ This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)`
+
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd.toDF().collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ return sqlCtx.createDataFrame(self, schema, sampleRatio)
+
+ RDD.toDF = toDF
+
+
class SQLContext(object):
"""Main entry point for Spark SQL functionality.
@@ -49,15 +68,20 @@ class SQLContext(object):
def __init__(self, sparkContext, sqlContext=None):
"""Create a new SQLContext.
+ It will add a method called `toDF` to :class:`RDD`, which could be
+ used to convert an RDD into a DataFrame, it's a shorthand for
+ :func:`SQLContext.createDataFrame`.
+
:param sparkContext: The SparkContext to wrap.
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new
SQLContext in the JVM, instead we make all calls to this object.
>>> from datetime import datetime
+ >>> sqlCtx = SQLContext(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.createDataFrame(allTypes)
+ >>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
@@ -70,6 +94,7 @@ class SQLContext(object):
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
self._scala_SQLContext = sqlContext
+ _monkey_patch_RDD(self)
@property
def _ssql_ctx(self):
@@ -442,7 +467,7 @@ class SQLContext(object):
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema)
>>> sqlCtx.registerRDDAsTable(df3, "table2")
>>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
@@ -495,7 +520,7 @@ class SQLContext(object):
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema)
>>> sqlCtx.registerRDDAsTable(df3, "table2")
>>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
@@ -800,7 +825,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
- globs['df'] = sqlCtx.createDataFrame(rdd)
+ _monkey_patch_RDD(sqlCtx)
+ globs['df'] = rdd.toDF()
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b6f052ee44..1438fe5285 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -21,21 +21,19 @@ import warnings
import random
import os
from tempfile import NamedTemporaryFile
-from itertools import imap
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _prepare_for_python_RDD
-from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
- UTF8Deserializer
+from pyspark.rdd import RDD
+from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
-__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]
class DataFrame(object):
@@ -76,6 +74,7 @@ class DataFrame(object):
self.sql_ctx = sql_ctx
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
+ self._schema = None # initialized lazily
@property
def rdd(self):
@@ -86,7 +85,7 @@ class DataFrame(object):
if not hasattr(self, '_lazy_rdd'):
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
- schema = self.schema()
+ schema = self.schema
def applySchema(it):
cls = _create_cls(schema)
@@ -216,14 +215,17 @@ class DataFrame(object):
self._sc._gateway._gateway_client)
self._jdf.save(source, jmode, joptions)
+ @property
def schema(self):
"""Returns the schema of this DataFrame (represented by
a L{StructType}).
- >>> df.schema()
+ >>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
"""
- return _parse_datatype_json_string(self._jdf.schema().json())
+ if self._schema is None:
+ self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ return self._schema
def printSchema(self):
"""Prints out the schema in the tree format.
@@ -284,7 +286,7 @@ class DataFrame(object):
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
- cls = _create_cls(self.schema())
+ cls = _create_cls(self.schema)
return [cls(r) for r in rs]
def limit(self, num):
@@ -310,14 +312,26 @@ class DataFrame(object):
return self.limit(num).collect()
def map(self, f):
- """ Return a new RDD by applying a function to each Row, it's a
- shorthand for df.rdd.map()
+ """ Return a new RDD by applying a function to each Row
+
+ It's a shorthand for df.rdd.map()
>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
"""
return self.rdd.map(f)
+ def flatMap(self, f):
+ """ Return a new RDD by first applying a function to all elements of this,
+ and then flattening the results.
+
+ It's a shorthand for df.rdd.flatMap()
+
+ >>> df.flatMap(lambda p: p.name).collect()
+ [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
+ """
+ return self.rdd.flatMap(f)
+
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition.
@@ -378,21 +392,6 @@ class DataFrame(object):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
- # def takeSample(self, withReplacement, num, seed=None):
- # """Return a fixed-size sampled subset of this DataFrame.
- #
- # >>> df = sqlCtx.inferSchema(rdd)
- # >>> df.takeSample(False, 2, 97)
- # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
- # """
- # seed = seed if seed is not None else random.randint(0, sys.maxint)
- # with SCCallSiteSync(self.context) as css:
- # bytesInJava = self._jdf \
- # .takeSampleToPython(withReplacement, num, long(seed)) \
- # .iterator()
- # cls = _create_cls(self.schema())
- # return map(cls, self._collect_iterator_through_file(bytesInJava))
-
@property
def dtypes(self):
"""Return all column names and their data types as a list.
@@ -400,7 +399,7 @@ class DataFrame(object):
>>> df.dtypes
[('age', 'int'), ('name', 'string')]
"""
- return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields]
+ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
def columns(self):
@@ -409,7 +408,7 @@ class DataFrame(object):
>>> df.columns
[u'age', u'name']
"""
- return [f.name for f in self.schema().fields]
+ return [f.name for f in self.schema.fields]
def join(self, other, joinExprs=None, joinType=None):
"""
@@ -586,8 +585,8 @@ class DataFrame(object):
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age#0)=5)]
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.min(df.age)).collect()
+ >>> from pyspark.sql import functions as F
+ >>> df.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=2)]
"""
return self.groupBy().agg(*exprs)
@@ -616,18 +615,18 @@ class DataFrame(object):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
- def addColumn(self, colName, col):
+ def withColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
- >>> df.addColumn('age2', df.age + 2).collect()
+ >>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.alias(colName))
- def renameColumn(self, existing, new):
+ def withColumnRenamed(self, existing, new):
""" Rename an existing column to a new name
- >>> df.renameColumn('age', 'age2').collect()
+ >>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
"""
cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
@@ -635,11 +634,11 @@ class DataFrame(object):
for c in self.columns]
return self.select(*cols)
- def to_pandas(self):
+ def toPandas(self):
"""
Collect all the rows and return a `pandas.DataFrame`.
- >>> df.to_pandas() # doctest: +SKIP
+ >>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
@@ -687,10 +686,11 @@ class GroupedData(object):
name to aggregate methods.
>>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"age": "max"}).collect()
- [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
- >>> from pyspark.sql import Dsl
- >>> gdf.agg(Dsl.min(df.age)).collect()
+ >>> gdf.agg({"*": "count"}).collect()
+ [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+
+ >>> from pyspark.sql import functions as F
+ >>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
"""
assert exprs, "exprs should not be empty"
@@ -742,12 +742,12 @@ class GroupedData(object):
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.lit(literal)
+ return sc._jvm.functions.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.col(name)
+ return sc._jvm.functions.col(name)
def _to_java_column(col):
@@ -767,9 +767,9 @@ def _unary_op(name, doc="unary operator"):
return _
-def _dsl_op(name, doc=''):
+def _func_op(name, doc=''):
def _(self):
- jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+ jc = getattr(self._sc._jvm.functions, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -818,7 +818,7 @@ class Column(DataFrame):
super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
- __neg__ = _dsl_op("negate")
+ __neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
@@ -842,7 +842,7 @@ class Column(DataFrame):
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
- __invert__ = _dsl_op('not')
+ __invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
@@ -920,11 +920,11 @@ class Column(DataFrame):
else:
return 'Column<%s>' % self._jdf.toString()
- def to_pandas(self):
+ def toPandas(self):
"""
Return a pandas.Series from the column
- >>> df.age.to_pandas() # doctest: +SKIP
+ >>> df.age.toPandas() # doctest: +SKIP
0 2
1 5
dtype: int64
@@ -934,123 +934,6 @@ class Column(DataFrame):
return pd.Series(data)
-def _aggregate_func(name, doc=""):
- """ Create a function for aggregator by name"""
- def _(col):
- sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
- return Column(jc)
- _.__name__ = name
- _.__doc__ = doc
- return staticmethod(_)
-
-
-class UserDefinedFunction(object):
- def __init__(self, func, returnType):
- self.func = func
- self.returnType = returnType
- self._broadcast = None
- self._judf = self._create_judf()
-
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: imap(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
- pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
- judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
- includes, sc.pythonExec, broadcast_vars,
- sc._javaAccumulator, jdt)
- return judf
-
- def __del__(self):
- if self._broadcast is not None:
- self._broadcast.unpersist()
- self._broadcast = None
-
- def __call__(self, *cols):
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
-
-class Dsl(object):
- """
- A collections of builtin aggregators
- """
- DSLS = {
- 'lit': 'Creates a :class:`Column` of literal value.',
- 'col': 'Returns a :class:`Column` based on the given column name.',
- 'column': 'Returns a :class:`Column` based on the given column name.',
- 'upper': 'Converts a string expression to upper case.',
- 'lower': 'Converts a string expression to upper case.',
- 'sqrt': 'Computes the square root of the specified float value.',
- 'abs': 'Computes the absolutle value.',
-
- 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
- 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
- 'first': 'Aggregate function: returns the first value in a group.',
- 'last': 'Aggregate function: returns the last value in a group.',
- 'count': 'Aggregate function: returns the number of items in a group.',
- 'sum': 'Aggregate function: returns the sum of all values in the expression.',
- 'avg': 'Aggregate function: returns the average of the values in a group.',
- 'mean': 'Aggregate function: returns the average of the values in a group.',
- 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
- }
-
- for _name, _doc in DSLS.items():
- locals()[_name] = _aggregate_func(_name, _doc)
- del _name, _doc
-
- @staticmethod
- def countDistinct(col, *cols):
- """ Return a new Column for distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
- [Row(c=2)]
-
- >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
- @staticmethod
- def approxCountDistinct(col, rsd=None):
- """ Return a new Column for approxiate distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- if rsd is None:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
- else:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
- return Column(jc)
-
- @staticmethod
- def udf(f, returnType=StringType()):
- """Create a user defined function (UDF)
-
- >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
- >>> df.select(slen(df.name).alias('slen')).collect()
- [Row(slen=5), Row(slen=3)]
- """
- return UserDefinedFunction(f, returnType)
-
-
def _test():
import doctest
from pyspark.context import SparkContext
@@ -1059,11 +942,9 @@ def _test():
globs = pyspark.sql.dataframe.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
- rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
- globs['df'] = sqlCtx.inferSchema(rdd2)
- globs['df2'] = sqlCtx.inferSchema(rdd3)
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
new file mode 100644
index 0000000000..39aa550eeb
--- /dev/null
+++ b/python/pyspark/sql/functions.py
@@ -0,0 +1,170 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collections of builtin functions
+"""
+
+from itertools import imap
+
+from py4j.java_collections import ListConverter
+
+from pyspark import SparkContext
+from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql.types import StringType
+from pyspark.sql.dataframe import Column, _to_java_column
+
+
+__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
+
+
+def _create_function(name, doc=""):
+ """ Create a function for aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return _
+
+
+_functions = {
+ 'lit': 'Creates a :class:`Column` of literal value.',
+ 'col': 'Returns a :class:`Column` based on the given column name.',
+ 'column': 'Returns a :class:`Column` based on the given column name.',
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to upper case.',
+ 'sqrt': 'Computes the square root of the specified float value.',
+ 'abs': 'Computes the absolutle value.',
+
+ 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+ 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+ 'first': 'Aggregate function: returns the first value in a group.',
+ 'last': 'Aggregate function: returns the last value in a group.',
+ 'count': 'Aggregate function: returns the number of items in a group.',
+ 'sum': 'Aggregate function: returns the sum of all values in the expression.',
+ 'avg': 'Aggregate function: returns the average of the values in a group.',
+ 'mean': 'Aggregate function: returns the average of the values in a group.',
+ 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+}
+
+
+for _name, _doc in _functions.items():
+ globals()[_name] = _create_function(_name, _doc)
+del _name, _doc
+__all__ += _functions.keys()
+
+
+def countDistinct(col, *cols):
+ """ Return a new Column for distinct count of `col` or `cols`
+
+ >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
+ [Row(c=2)]
+
+ >>> df.agg(countDistinct("age", "name").alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
+ jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+def approxCountDistinct(col, rsd=None):
+ """ Return a new Column for approximate distinct count of `col`
+
+ >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
+
+
+class UserDefinedFunction(object):
+ """
+ User defined function in Python
+ """
+ def __init__(self, func, returnType):
+ self.func = func
+ self.returnType = returnType
+ self._broadcast = None
+ self._judf = self._create_judf()
+
+ def _create_judf(self):
+ f = self.func # put it in closure `func`
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ sc = SparkContext._active_spark_context
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(self.returnType.json())
+ judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+ includes, sc.pythonExec, broadcast_vars,
+ sc._javaAccumulator, jdt)
+ return judf
+
+ def __del__(self):
+ if self._broadcast is not None:
+ self._broadcast.unpersist()
+ self._broadcast = None
+
+ def __call__(self, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+def udf(f, returnType=StringType()):
+ """Create a user defined function (UDF)
+
+ >>> slen = udf(lambda s: len(s), IntegerType())
+ >>> df.select(slen(df.name).alias('slen')).collect()
+ [Row(slen=5), Row(slen=3)]
+ """
+ return UserDefinedFunction(f, returnType)
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext
+ import pyspark.sql.dataframe
+ globs = pyspark.sql.dataframe.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.dataframe, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 43e5c3a1b0..aa80bca346 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -96,7 +96,7 @@ class SQLTests(ReusedPySparkTestCase):
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.createDataFrame(rdd)
+ cls.df = rdd.toDF()
@classmethod
def tearDownClass(cls):
@@ -138,7 +138,7 @@ class SQLTests(ReusedPySparkTestCase):
df = self.sqlCtx.jsonRDD(rdd)
df.count()
df.collect()
- df.schema()
+ df.schema
# cache and checkpoint
self.assertFalse(df.is_cached)
@@ -155,11 +155,11 @@ class SQLTests(ReusedPySparkTestCase):
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema())
+ df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.createDataFrame(rdd, df.schema())
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
@@ -195,7 +195,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(1, result.head()[0])
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
- self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual(df.schema, df2.schema)
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
df2.registerTempTable("test2")
@@ -204,8 +204,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.sc.parallelize(d).toDF()
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -213,8 +212,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.sc.parallelize([row]).toDF()
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
@@ -223,9 +221,8 @@ class SQLTests(ReusedPySparkTestCase):
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.createDataFrame(rdd)
- schema = df.schema()
+ df = self.sc.parallelize([row]).toDF()
+ schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
df.registerTempTable("labeled_point")
@@ -238,15 +235,14 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.createDataFrame(rdd, schema)
+ df = rdd.toDF(schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.createDataFrame(rdd)
+ df0 = self.sc.parallelize([row]).toDF()
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
@@ -280,10 +276,11 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
- from pyspark.sql import Dsl
- self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
- self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
- self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ from pyspark.sql import functions
+ self.assertEqual((0, u'99'),
+ tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
def test_save_and_load(self):
df = self.df
@@ -339,8 +336,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
- rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.inferSchema(rdd)
+ cls.df = cls.sc.parallelize(cls.testData).toDF()
@classmethod
def tearDownClass(cls):
diff --git a/python/run-tests b/python/run-tests
index 077ad60d76..a2c2f37a54 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -35,7 +35,7 @@ rm -rf metastore warehouse
function run_test() {
echo "Running test: $1" | tee -a $LOG_FILE
- SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 >> $LOG_FILE 2>&1
+ SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1
FAILED=$((PIPESTATUS[0]||$FAILED))
@@ -67,6 +67,7 @@ function run_sql_tests() {
run_test "pyspark/sql/types.py"
run_test "pyspark/sql/context.py"
run_test "pyspark/sql/dataframe.py"
+ run_test "pyspark/sql/functions.py"
run_test "pyspark/sql/tests.py"
}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index 0cf2de6d39..05faef8786 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -137,7 +137,7 @@ private[repl] trait SparkILoopInit {
command("import org.apache.spark.SparkContext._")
command("import sqlContext.implicits._")
command("import sqlContext.sql")
- command("import org.apache.spark.sql.Dsl._")
+ command("import org.apache.spark.sql.functions._")
}
}
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 201f2672d5..529914a2b6 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -262,7 +262,7 @@ class ReplSuite extends FunSuite {
|val sqlContext = new org.apache.spark.sql.SQLContext(sc)
|import sqlContext.implicits._
|case class TestCaseClass(value: Int)
- |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect()
+ |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF.collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 1bd2a69914..7a5e94da5c 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -77,7 +77,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
command("import org.apache.spark.SparkContext._")
command("import sqlContext.implicits._")
command("import sqlContext.sql")
- command("import org.apache.spark.sql.Dsl._")
+ command("import org.apache.spark.sql.functions._")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index f959a50564..a7cd4124e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -152,7 +152,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
override lazy val resolved = false
- override def newInstance = this
+ override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9d5d6e78bd..f6ecee1af8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql
-import scala.annotation.tailrec
import scala.language.implicitConversions
-import org.apache.spark.sql.Dsl.lit
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._
@@ -127,7 +126,7 @@ trait Column extends DataFrame {
* df.select( -df("amount") )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.select( negate(col("amount") );
* }}}
*/
@@ -140,7 +139,7 @@ trait Column extends DataFrame {
* df.filter( !df("isActive") )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( not(df.col("isActive")) );
* }}
*/
@@ -153,7 +152,7 @@ trait Column extends DataFrame {
* df.filter( df("colA") === df("colB") )
*
* // Java
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
@@ -168,7 +167,7 @@ trait Column extends DataFrame {
* df.filter( df("colA") === df("colB") )
*
* // Java
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").equalTo(col("colB")) );
* }}}
*/
@@ -182,7 +181,7 @@ trait Column extends DataFrame {
* df.select( !(df("colA") === df("colB")) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").notEqual(col("colB")) );
* }}}
*/
@@ -198,7 +197,7 @@ trait Column extends DataFrame {
* df.select( !(df("colA") === df("colB")) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").notEqual(col("colB")) );
* }}}
*/
@@ -213,7 +212,7 @@ trait Column extends DataFrame {
* people.select( people("age") > 21 )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* people.select( people("age").gt(21) );
* }}}
*/
@@ -228,7 +227,7 @@ trait Column extends DataFrame {
* people.select( people("age") > lit(21) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* people.select( people("age").gt(21) );
* }}}
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4f8f19e2c1..e21e989f36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -48,7 +48,7 @@ private[sql] object DataFrame {
* }}}
*
* Once created, it can be manipulated using the various domain-specific-language (DSL) functions
- * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL.
+ * defined in: [[DataFrame]] (this class), [[Column]], [[functions]] for the DSL.
*
* To select a column from the data frame, use the apply method:
* {{{
@@ -94,27 +94,27 @@ trait DataFrame extends RDDApi[Row] with Serializable {
}
/** Left here for backward compatibility. */
- @deprecated("1.3.0", "use toDataFrame")
+ @deprecated("1.3.0", "use toDF")
def toSchemaRDD: DataFrame = this
/**
* Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
*/
// This is declared with parentheses to prevent the Scala compiler from treating
- // `rdd.toDataFrame("1")` as invoking this toDataFrame and then apply on the returned DataFrame.
- def toDataFrame(): DataFrame = this
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = this
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
- * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2
- * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name"
+ * rdd.toDF // this implicit conversion creates a DataFrame with column name _1 and _2
+ * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name"
* }}}
*/
@scala.annotation.varargs
- def toDataFrame(colNames: String*): DataFrame
+ def toDF(colNames: String*): DataFrame
/** Returns the schema of this [[DataFrame]]. */
def schema: StructType
@@ -132,7 +132,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
def explain(extended: Boolean): Unit
/** Only prints the physical plan to the console for debugging purpose. */
- def explain(): Unit = explain(false)
+ def explain(): Unit = explain(extended = false)
/**
* Returns true if the `collect` and `take` methods can be run locally
@@ -179,11 +179,11 @@ trait DataFrame extends RDDApi[Row] with Serializable {
*
* {{{
* // Scala:
- * import org.apache.spark.sql.dsl._
+ * import org.apache.spark.sql.functions._
* df1.join(df2, "outer", $"df1Key" === $"df2Key")
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df1.join(df2, "outer", col("df1Key") === col("df2Key"));
* }}}
*
@@ -483,12 +483,12 @@ trait DataFrame extends RDDApi[Row] with Serializable {
/**
* Returns a new [[DataFrame]] by adding a column.
*/
- def addColumn(colName: String, col: Column): DataFrame
+ def withColumn(colName: String, col: Column): DataFrame
/**
* Returns a new [[DataFrame]] with a column renamed.
*/
- def renameColumn(existingName: String, newName: String): DataFrame
+ def withColumnRenamed(existingName: String, newName: String): DataFrame
/**
* Returns the first `n` rows.
@@ -520,6 +520,7 @@ trait DataFrame extends RDDApi[Row] with Serializable {
* Returns a new RDD by applying a function to each partition of this DataFrame.
*/
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R]
+
/**
* Applies a function `f` to all rows.
*/
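For readers tracking the renames in DataFrame.scala, a minimal usage sketch follows. It is not part of the patch; it assumes a SparkContext `sc`, a `sqlContext` with the new implicits in scope, and a made-up case class `Person`:

    case class Person(name: String, age: Int)
    import sqlContext.implicits._

    val people = sc.parallelize(Seq(Person("Alice", 2), Person("Bob", 5))).toDF()
    people.toDF("n", "a")                            // toDataFrame(colNames) renamed to toDF(colNames)
    people.withColumn("ageNext", people("age") + 1)  // addColumn renamed to withColumn
    people.withColumnRenamed("age", "years")         // renameColumn renamed to withColumnRenamed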
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
new file mode 100644
index 0000000000..a3187fe323
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
@@ -0,0 +1,30 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+/**
+ * A container for a [[DataFrame]], used for implicit conversions.
+ */
+private[sql] case class DataFrameHolder(df: DataFrame) {
+
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = df
+
+ def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*)
+}
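The holder exists so that the implicit conversion exposes only `toDF` on an RDD instead of the full DataFrame API. A rough sketch of what this means at a call site, assuming `sc` and `sqlContext` are in scope (for example in the Spark shell):

    import sqlContext.implicits._

    val rdd = sc.parallelize(Seq((1, "a"), (2, "b")))
    rdd.toDF()              // rddToDataFrameHolder(rdd).toDF()
    rdd.toDF("id", "name")  // column names are forwarded to DataFrame.toDF(colNames: _*)
    // rdd.select($"id")    // no longer compiles: DataFrame methods are not implicitly
    //                      // added to RDDs any more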
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index bb5c6226a2..7b7efbe347 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -94,7 +94,7 @@ private[sql] class DataFrameImpl protected[sql](
}
}
- override def toDataFrame(colNames: String*): DataFrame = {
+ override def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
"Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
@@ -229,11 +229,11 @@ private[sql] class DataFrameImpl protected[sql](
}: _*)
}
- override def addColumn(colName: String, col: Column): DataFrame = {
+ override def withColumn(colName: String, col: Column): DataFrame = {
select(Column("*"), col.as(colName))
}
- override def renameColumn(existingName: String, newName: String): DataFrame = {
+ override def withColumnRenamed(existingName: String, newName: String): DataFrame = {
val colNames = schema.map { field =>
val name = field.name
if (name == existingName) Column(name).as(newName) else Column(name)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 3c20676355..0868013fe7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql
import scala.language.implicitConversions
import scala.collection.JavaConversions._
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
*/
class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
- private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
@@ -52,7 +52,12 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
case "max" => Max
case "min" => Min
case "sum" => Sum
- case "count" | "size" => Count
+ case "count" | "size" =>
+ // Turn count(*) into count(1)
+ (inputExpr: Expression) => inputExpr match {
+ case s: Star => Count(Literal(1))
+ case _ => Count(inputExpr)
+ }
}
}
@@ -115,17 +120,17 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
* Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
* class, the resulting [[DataFrame]] won't automatically include the grouping columns.
*
- * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+ * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
*
* // Scala:
- * import org.apache.spark.sql.dsl._
+ * import org.apache.spark.sql.functions._
* df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense")));
* }}}
*/
@@ -142,7 +147,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
* Count the number of rows for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
*/
- def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
+ def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
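The count(*)-to-count(1) rewrite added above applies both to the string-based aggregate path in GroupedData and to `functions.count`. Roughly, assuming a DataFrame `df` with a column "a" (illustration only, not part of the patch):

    import org.apache.spark.sql.functions._

    df.groupBy("a").agg(col("a"), count("*"))  // Star argument becomes Count(Literal(1))
    df.groupBy("a").agg(Map("*" -> "count"))   // same rewrite in the string-based path above
    df.groupBy("a").count()                    // convenience method, also Count(Literal(1))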
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index cba3b77011..fc37cfa7a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -50,7 +50,7 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
protected[sql] override def logicalPlan: LogicalPlan = err()
- override def toDataFrame(colNames: String*): DataFrame = err()
+ override def toDF(colNames: String*): DataFrame = err()
override def schema: StructType = err()
@@ -86,9 +86,9 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def selectExpr(exprs: String*): DataFrame = err()
- override def addColumn(colName: String, col: Column): DataFrame = err()
+ override def withColumn(colName: String, col: Column): DataFrame = err()
- override def renameColumn(existingName: String, newName: String): DataFrame = err()
+ override def withColumnRenamed(existingName: String, newName: String): DataFrame = err()
override def filter(condition: Column): DataFrame = err()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 2165949d32..a1736d0277 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -183,14 +183,25 @@ class SQLContext(@transient val sparkContext: SparkContext)
object implicits extends Serializable {
// scalastyle:on
+ /** Converts $"col name" into a [[Column]]. */

+ implicit class StringToColumn(val sc: StringContext) {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args :_*))
+ }
+ }
+
+ /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
/** Creates a DataFrame from an RDD of case classes or tuples. */
- implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
- self.createDataFrame(rdd)
+ implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
+ DataFrameHolder(self.createDataFrame(rdd))
}
/** Creates a DataFrame from a local Seq of Product. */
- implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
- self.createDataFrame(data)
+ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
+ {
+ DataFrameHolder(self.createDataFrame(data))
}
// Do NOT add more implicit conversions. They are likely to break source compatibility by
@@ -198,7 +209,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
// because of [[DoubleRDDFunctions]].
/** Creates a single column DataFrame from an RDD[Int]. */
- implicit def intRddToDataFrame(data: RDD[Int]): DataFrame = {
+ implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
val dataType = IntegerType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
@@ -207,11 +218,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
row: Row
}
}
- self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/** Creates a single column DataFrame from an RDD[Long]. */
- implicit def longRddToDataFrame(data: RDD[Long]): DataFrame = {
+ implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
val dataType = LongType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
@@ -220,11 +231,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
row: Row
}
}
- self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/** Creates a single column DataFrame from an RDD[String]. */
- implicit def stringRddToDataFrame(data: RDD[String]): DataFrame = {
+ implicit def stringRddToDataFrame(data: RDD[String]): DataFrameHolder = {
val dataType = StringType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
@@ -233,7 +244,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
row: Row
}
}
- self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}
@@ -780,7 +791,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* indicating if a table is a temporary one or not).
*/
def tables(): DataFrame = {
- createDataFrame(catalog.getTables(None)).toDataFrame("tableName", "isTemporary")
+ createDataFrame(catalog.getTables(None)).toDF("tableName", "isTemporary")
}
/**
@@ -789,7 +800,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* indicating if a table is a temporary one or not).
*/
def tables(databaseName: String): DataFrame = {
- createDataFrame(catalog.getTables(Some(databaseName))).toDataFrame("tableName", "isTemporary")
+ createDataFrame(catalog.getTables(Some(databaseName))).toDF("tableName", "isTemporary")
}
/**
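With this change the `$"..."` interpolator and the Symbol conversion come from `sqlContext.implicits` rather than the old Dsl object. A minimal sketch, assuming a `sqlContext` and a DataFrame `df` with an "age" column (not part of the patch):

    import sqlContext.implicits._

    df.filter($"age" > 21)  // StringToColumn: $"age" builds a ColumnName
    df.select('age)         // symbolToColumn: the Symbol 'age becomes a Column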
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index c60d407094..ee94a5fdbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType
/**
- * A user-defined function. To create one, use the `udf` functions in [[Dsl]].
+ * A user-defined function. To create one, use the `udf` functions in [[functions]].
* As an example:
* {{{
* // Defined a UDF that returns true or false based on some numeric score.
@@ -45,7 +45,7 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
}
/**
- * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]].
+ * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]].
* This is used by Python API.
*/
private[sql] case class UserDefinedPythonFunction(
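The `udf` helpers referenced in this doc now live in `functions`. A small sketch of the Scala side, assuming a DataFrame `df` with a string column "name"; the predicate and column names are made up for illustration:

    import org.apache.spark.sql.functions._

    val isLong = udf((s: String) => s.length > 3)    // returns a UserDefinedFunction
    df.select(isLong(df("name")).as("hasLongName"))  // applying it yields a Column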
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7bc7683576..4a0ec0b72c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -21,6 +21,7 @@ import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -28,17 +29,9 @@ import org.apache.spark.sql.types._
/**
* Domain specific functions available for [[DataFrame]].
*/
-object Dsl {
-
- /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
- implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
-
- /** Converts $"col name" into an [[Column]]. */
- implicit class StringToColumn(val sc: StringContext) extends AnyVal {
- def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args :_*))
- }
- }
+// scalastyle:off
+object functions {
+// scalastyle:on
private[this] implicit def toColumn(expr: Expression): Column = Column(expr)
@@ -104,7 +97,11 @@ object Dsl {
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
/** Aggregate function: returns the number of items in a group. */
- def count(e: Column): Column = Count(e.expr)
+ def count(e: Column): Column = e.expr match {
+ // Turn count(*) into count(1)
+ case s: Star => Count(Literal(1))
+ case _ => Count(e.expr)
+ }
/** Aggregate function: returns the number of items in a group. */
def count(columnName: String): Column = count(Column(columnName))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 8d3e094e33..538d774eb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -90,7 +90,7 @@ trait ParquetTest {
(f: String => Unit): Unit = {
import sqlContext.implicits._
withTempPath { file =>
- sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath)
+ sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
index 639436368c..05233dc5ff 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
@@ -23,7 +23,7 @@ import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.types.DataTypes;
-import static org.apache.spark.sql.Dsl.*;
+import static org.apache.spark.sql.functions.*;
/**
* This test doesn't actually run anything. It is here to check the API compatibility for Java.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1318750a4a..691dae0a05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -25,8 +25,9 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
@@ -34,8 +35,6 @@ case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
def rddIdOf(tableName: String): Int = {
val executedPlan = table(tableName).queryExecution.executedPlan
executedPlan.collect {
@@ -95,7 +94,7 @@ class CachedTableSuite extends QueryTest {
test("too big for memory") {
val data = "*" * 10000
- sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData")
+ sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData")
table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
assert(table("bigData").count() === 200000L)
table("bigData").unpersist(blocking = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index e3e6f652ed..a63d733ece 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
@@ -68,7 +68,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("collect on column produced by a binary operator") {
- val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
+ val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df("a") + df("b"), Seq(Row(3)))
checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
}
@@ -79,7 +79,7 @@ class ColumnExpressionSuite extends QueryTest {
test("star qualified by data frame object") {
// This is not yet supported.
- val df = testData.toDataFrame
+ val df = testData.toDF
val goldAnswer = df.collect().toSeq
checkAnswer(df.select(df("*")), goldAnswer)
@@ -156,13 +156,13 @@ class ColumnExpressionSuite extends QueryTest {
test("isNull") {
checkAnswer(
- nullStrings.toDataFrame.where($"s".isNull),
+ nullStrings.toDF.where($"s".isNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
}
test("isNotNull") {
checkAnswer(
- nullStrings.toDataFrame.where($"s".isNotNull),
+ nullStrings.toDF.where($"s".isNotNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index 8fa830dd93..2d2367d6e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -25,31 +25,31 @@ class DataFrameImplicitsSuite extends QueryTest {
test("RDD of tuples") {
checkAnswer(
- sc.parallelize(1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+ sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
(1 to 10).map(i => Row(i, i.toString)))
}
test("Seq of tuples") {
checkAnswer(
- (1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+ (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
(1 to 10).map(i => Row(i, i.toString)))
}
test("RDD[Int]") {
checkAnswer(
- sc.parallelize(1 to 10).toDataFrame("intCol"),
+ sc.parallelize(1 to 10).toDF("intCol"),
(1 to 10).map(i => Row(i)))
}
test("RDD[Long]") {
checkAnswer(
- sc.parallelize(1L to 10L).toDataFrame("longCol"),
+ sc.parallelize(1L to 10L).toDF("longCol"),
(1L to 10L).map(i => Row(i)))
}
test("RDD[String]") {
checkAnswer(
- sc.parallelize(1 to 10).map(_.toString).toDataFrame("stringCol"),
+ sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
(1 to 10).map(i => Row(i.toString)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 33b35f376b..f0cd43632e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.TestData._
import scala.language.postfixOps
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
@@ -99,7 +99,7 @@ class DataFrameSuite extends QueryTest {
}
test("simple explode") {
- val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")
+ val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words")
checkAnswer(
df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word),
@@ -108,7 +108,7 @@ class DataFrameSuite extends QueryTest {
}
test("explode") {
- val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
+ val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
val df2 =
df.explode('letters) {
case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
@@ -141,16 +141,31 @@ class DataFrameSuite extends QueryTest {
testData.select('key).collect().toSeq)
}
- test("agg") {
+ test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b")),
- Seq(Row(1,3), Row(2,3), Row(3,3))
+ Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
+ testData2.groupBy("a").agg(col("a"), count("*")),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+ )
+ checkAnswer(
+ testData2.groupBy("a").agg(Map("*" -> "count")),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+ )
+ checkAnswer(
+ testData2.groupBy("a").agg(Map("b" -> "sum")),
+ Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+ )
+ }
+
+ test("agg without groups") {
+ checkAnswer(
testData2.agg(sum('b)),
Row(9)
)
@@ -218,20 +233,20 @@ class DataFrameSuite extends QueryTest {
Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
- arrayData.orderBy('data.getItem(0).asc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+ arrayData.toDF.orderBy('data.getItem(0).asc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(0).desc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+ arrayData.toDF.orderBy('data.getItem(0).desc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(1).asc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+ arrayData.toDF.orderBy('data.getItem(1).asc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(1).desc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
+ arrayData.toDF.orderBy('data.getItem(1).desc),
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
@@ -240,11 +255,11 @@ class DataFrameSuite extends QueryTest {
testData.take(10).toSeq)
checkAnswer(
- arrayData.limit(1),
+ arrayData.toDF.limit(1),
arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
- mapData.limit(1),
+ mapData.toDF.limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
@@ -378,7 +393,7 @@ class DataFrameSuite extends QueryTest {
}
test("addColumn") {
- val df = testData.toDataFrame.addColumn("newCol", col("key") + 1)
+ val df = testData.toDF.withColumn("newCol", col("key") + 1)
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
@@ -388,8 +403,8 @@ class DataFrameSuite extends QueryTest {
}
test("renameColumn") {
- val df = testData.toDataFrame.addColumn("newCol", col("key") + 1)
- .renameColumn("value", "valueRenamed")
+ val df = testData.toDF.withColumn("newCol", col("key") + 1)
+ .withColumnRenamed("value", "valueRenamed")
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f0c939dbb1..fd73065c4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
class JoinSuite extends QueryTest with BeforeAndAfterEach {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 5fc35349e1..282b98a987 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -28,7 +28,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.test.TestSQLContext.implicits._
val df =
- sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value")
before {
df.registerTempTable("ListTablesSuiteTable")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a1c8cf58f2..97684f75e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
@@ -1034,10 +1034,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Supporting relational operator '<=>' in Spark SQL") {
val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil
val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
- rdd1.registerTempTable("nulldata1")
+ rdd1.toDF.registerTempTable("nulldata1")
val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil
val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
- rdd2.registerTempTable("nulldata2")
+ rdd2.toDF.registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
(1 to 2).map(i => Row(i)))
@@ -1046,7 +1046,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil
val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
- rdd.registerTempTable("distinctData")
+ rdd.toDF.registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 9378261982..9a48f8d063 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectData")
+ rdd.toDF.registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
@@ -93,7 +93,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with nulls") {
val data = NullReflectData(null, null, null, null, null, null, null)
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectNullData")
+ rdd.toDF.registerTempTable("reflectNullData")
assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
@@ -101,7 +101,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with Nones") {
val data = OptionalReflectData(None, None, None, None, None, None, None)
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectOptionalData")
+ rdd.toDF.registerTempTable("reflectOptionalData")
assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
@@ -109,7 +109,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
// Equality is broken for Arrays, so we test that separately.
test("query binary data") {
val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
- rdd.registerTempTable("reflectBinary")
+ rdd.toDF.registerTempTable("reflectBinary")
val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
@@ -128,7 +128,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None),
Nested(None, "abc")))
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectComplexData")
+ rdd.toDF.registerTempTable("reflectComplexData")
assert(sql("SELECT * FROM reflectComplexData").collect().head ===
new GenericRow(Array[Any](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 0ed437edd0..c511eb1469 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test._
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -29,11 +29,11 @@ case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toDataFrame
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
@@ -44,7 +44,7 @@ object TestData {
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toDataFrame
+ LargeAndSmallInts(3, 2) :: Nil).toDF
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
@@ -55,7 +55,7 @@ object TestData {
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toDataFrame
+ TestData2(3, 2) :: Nil, 2).toDF
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -67,7 +67,7 @@ object TestData {
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toDataFrame
+ DecimalData(3, 2) :: Nil).toDF
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
@@ -77,14 +77,14 @@ object TestData {
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toDataFrame
+ BinaryData("123".getBytes(), 4) :: Nil).toDF
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toDataFrame
+ TestData3(2, Some(2)) :: Nil).toDF
testData3.registerTempTable("testData3")
val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
@@ -97,7 +97,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toDataFrame
+ UpperCaseData(6, "F") :: Nil).toDF
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -106,7 +106,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toDataFrame
+ LowerCaseData(4, "d") :: Nil).toDF
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -114,7 +114,7 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) ::
ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
- arrayData.registerTempTable("arrayData")
+ arrayData.toDF.registerTempTable("arrayData")
case class MapData(data: scala.collection.Map[Int, String])
val mapData =
@@ -124,18 +124,18 @@ object TestData {
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
- mapData.registerTempTable("mapData")
+ mapData.toDF.registerTempTable("mapData")
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
- repeatedData.registerTempTable("repeatedData")
+ repeatedData.toDF.registerTempTable("repeatedData")
val nullableRepeatedData =
TestSQLContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
- nullableRepeatedData.registerTempTable("nullableRepeatedData")
+ nullableRepeatedData.toDF.registerTempTable("nullableRepeatedData")
case class NullInts(a: Integer)
val nullInts =
@@ -144,7 +144,7 @@ object TestData {
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
- )
+ ).toDF
nullInts.registerTempTable("nullInts")
val allNulls =
@@ -152,7 +152,7 @@ object TestData {
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
- NullInts(null) :: Nil)
+ NullInts(null) :: Nil).toDF
allNulls.registerTempTable("allNulls")
case class NullStrings(n: Int, s: String)
@@ -160,11 +160,11 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
- NullStrings(3, null) :: Nil).toDataFrame
+ NullStrings(3, null) :: Nil).toDF
nullStrings.registerTempTable("nullStrings")
case class TableName(tableName: String)
- TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName")
+ TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).toDF.registerTempTable("tableName")
val unparsedStrings =
TestSQLContext.sparkContext.parallelize(
@@ -177,22 +177,22 @@ object TestData {
val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i =>
TimestampField(new Timestamp(i))
})
- timestamps.registerTempTable("timestamps")
+ timestamps.toDF.registerTempTable("timestamps")
case class IntField(i: Int)
// An RDD with 4 elements and 8 partitions
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
- withEmptyParts.registerTempTable("withEmptyParts")
+ withEmptyParts.toDF.registerTempTable("withEmptyParts")
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
val person = TestSQLContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
- Person(1, "jim", 20) :: Nil)
+ Person(1, "jim", 20) :: Nil).toDF
person.registerTempTable("person")
val salary = TestSQLContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
- Salary(1, 1000.0) :: Nil)
+ Salary(1, 1000.0) :: Nil).toDF
salary.registerTempTable("salary")
case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
@@ -200,6 +200,6 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
- :: Nil).toDataFrame
+ :: Nil).toDF
complexData.registerTempTable("complexData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 95923f9aad..be105c6e83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql
-import org.apache.spark.sql.Dsl.StringToColumn
import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
+import TestSQLContext.implicits._
case class FunctionResult(f1: String, f2: String)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 3c1657cd5f..5f21d990e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -66,7 +66,7 @@ class UserDefinedTypeSuite extends QueryTest {
val points = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
- val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
+ val pointsRDD = sparkContext.parallelize(points).toDF()
test("register user type: MyDenseVector for MyLabeledPoint") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 86b1b5fda1..38b0f666ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.columnar
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -28,8 +29,6 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
test("simple columnar query") {
val plan = executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -39,7 +38,8 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
- sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst")
+ sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ .toDF().registerTempTable("sizeTst")
cacheTable("sizeTst")
assert(
table("sizeTst").queryExecution.logical.statistics.sizeInBytes >
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 55a9f735b3..e57bb06e72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -21,13 +21,12 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
@@ -35,7 +34,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
- }, 5)
+ }, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c3210733f1..523be56df6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
import org.apache.spark.sql.{SQLConf, execution}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index b5f13f8bd5..c94e44bd7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -822,7 +823,7 @@ class JsonSuite extends QueryTest {
val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
- val df2 = df1.toDataFrame
+ val df2 = df1.toDF
val result = df2.toJSON.collect()
assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
@@ -843,7 +844,7 @@ class JsonSuite extends QueryTest {
val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
- val df4 = df3.toDataFrame
+ val df4 = df3.toDF
val result2 = df4.toJSON.collect()
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
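Around the df1/df3 changes, toDataFrame on an existing DataFrame is simply spelled toDF now. A sketch of the round-trip being exercised, with a placeholder schema and rows rather than the suite's actual fixtures:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.test.TestSQLContext._            // createDataFrame, sparkContext

    val schema = StructType(
      StructField("f1", IntegerType, nullable = false) ::
      StructField("f2", StringType, nullable = true) :: Nil)
    val rowRDD = sparkContext.parallelize(Row(1, "A1") :: Row(2, "B2") :: Nil)

    val df1 = createDataFrame(rowRDD, schema)
    df1.registerTempTable("applySchema1")
    val df2 = df1.toDF                                            // renamed from toDataFrame
    val json: Array[String] = df2.toJSON.collect()                // one JSON document per row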
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index c8ebbbc7d2..c306330818 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -33,11 +33,12 @@ import parquet.schema.{MessageType, MessageTypeParser}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@@ -64,6 +65,7 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuite extends QueryTest with ParquetTest {
+
val sqlContext = TestSQLContext
/**
@@ -99,12 +101,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
}
test(s"$prefix: fixed-length decimals") {
- import org.apache.spark.sql.test.TestSQLContext.implicits._
def makeDecimalRDD(decimal: DecimalType): DataFrame =
sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
+ .toDF
// Parquet doesn't allow column names with spaces, have to add an alias here
.select($"_1" cast decimal as "dec")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 89b18c3439..9fcb04ca23 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -37,7 +37,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val testData = TestHive.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString)))
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
before {
// Since every time we are doing tests for DDL statements,
@@ -56,7 +56,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
// Make sure the table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq.map(Row.fromTuple)
+ testData.collect().toSeq
)
// Add more data.
@@ -65,7 +65,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq
+ testData.toDF.collect().toSeq ++ testData.toDF.collect().toSeq
)
// Now overwrite.
@@ -74,7 +74,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
// Make sure the registered table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq.map(Row.fromTuple)
+ testData.collect().toSeq
)
}
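Note why the Row.fromTuple mapping disappears in the hunks above: testData is now a DataFrame, so collect() already yields Rows that checkAnswer can compare directly. A hedged before/after sketch:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.hive.test.TestHive._
    import org.apache.spark.sql.hive.test.TestHive.implicits._

    case class TestData(key: Int, value: String)                 // stand-in for the suite's fixture class

    val testData = sparkContext
      .parallelize((1 to 100).map(i => TestData(i, i.toString)))
      .toDF()

    // DataFrame.collect() returns Array[Row], so the expected answer needs no
    // Row.fromTuple conversion before being handed to checkAnswer.
    val expected: Seq[Row] = testData.collect().toSeq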
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
index 068aa03330..321b784a3f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -29,7 +29,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val df =
- sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDataFrame("key", "value")
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value")
override def beforeAll(): Unit = {
// The catalog in HiveContext is a case insensitive one.
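The ListTablesSuite change also shows the column-naming overload: toDF("key", "value") replaces toDataFrame("key", "value"). A minimal sketch with made-up data:

    import org.apache.spark.sql.hive.test.TestHive._
    import org.apache.spark.sql.hive.test.TestHive.implicits._

    // Tuple columns get default names (_1, _2); the varargs overload renames them on conversion.
    val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")
    df.registerTempTable("ListTablesSuiteTable")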
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 2916724f66..addf887ab9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -28,17 +28,14 @@ import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql._
import org.apache.spark.util.Utils
import org.apache.spark.sql.types._
-
-/* Implicits */
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
/**
* Tests for persisting tables created through the data sources API into the metastore.
*/
class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
-
override def afterEach(): Unit = {
reset()
if (tempPath.exists()) Utils.deleteRecursively(tempPath)
@@ -154,7 +151,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
test("check change without refresh") {
val tempDir = File.createTempFile("sparksql", "json")
tempDir.delete()
- sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql(
s"""
@@ -170,7 +168,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
- sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
// Schema is cached so the new column does not show. The updated values in existing columns
// will show.
@@ -190,7 +189,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
test("drop, change, recreate") {
val tempDir = File.createTempFile("sparksql", "json")
tempDir.delete()
- sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql(
s"""
@@ -206,7 +206,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
- sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b", "c") :: Nil).toDF
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql("DROP TABLE jsonTable")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 405b200d05..d01dbf80ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -567,7 +567,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(2, "str2") :: Nil)
- testData.registerTempTable("REGisteredTABle")
+ testData.toDF.registerTempTable("REGisteredTABle")
assertResult(Array(Row(2, "str2"))) {
sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
@@ -592,7 +592,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") {
val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3))
.zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)}
- TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test")
+ TestHive.sparkContext.parallelize(fixture).toDF.registerTempTable("having_test")
val results =
sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3")
.collect()
@@ -740,7 +740,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(1, "str2") :: Nil)
- testData.registerTempTable("test_describe_commands2")
+ testData.toDF.registerTempTable("test_describe_commands2")
assertResult(
Array(
@@ -900,8 +900,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") {
- sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs")
- sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles")
+ sparkContext.makeRDD(Seq.empty[LogEntry]).toDF.registerTempTable("rawLogs")
+ sparkContext.makeRDD(Seq.empty[LogFile]).toDF.registerTempTable("logFiles")
sql(
"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index 029c36aa89..6fc4cc1426 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
test("case insensitivity with scala reflection") {
// Test resolution with Scala Reflection
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("caseSensitivityTest")
+ .toDF.registerTempTable("caseSensitivityTest")
val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"),
@@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest {
ignore("case insensitivity with scala reflection joins") {
// Test resolution with Scala Reflection
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("caseSensitivityTest")
+ .toDF.registerTempTable("caseSensitivityTest")
sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect()
}
test("nested repeated resolution") {
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("nestedRepeatedTest")
+ .toDF.registerTempTable("nestedRepeatedTest")
assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 8fb5e050a2..ab53c6309e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.Row
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.util.Utils
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 1e99003d3e..245161d2eb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -111,7 +111,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFIntegerToString") {
val testData = TestHive.sparkContext.parallelize(
- IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
+ IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF
testData.registerTempTable("integerTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
@@ -127,7 +127,7 @@ class HiveUdfSuite extends QueryTest {
val testData = TestHive.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
- ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil)
+ ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF
testData.registerTempTable("listListIntTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
@@ -142,7 +142,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFListString") {
val testData = TestHive.sparkContext.parallelize(
ListStringCaseClass(Seq("a", "b", "c")) ::
- ListStringCaseClass(Seq("d", "e")) :: Nil)
+ ListStringCaseClass(Seq("d", "e")) :: Nil).toDF
testData.registerTempTable("listStringTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
@@ -156,7 +156,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFStringString") {
val testData = TestHive.sparkContext.parallelize(
- StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil)
+ StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF
testData.registerTempTable("stringTable")
sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
@@ -173,7 +173,7 @@ class HiveUdfSuite extends QueryTest {
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
- Nil)
+ Nil).toDF
testData.registerTempTable("TwoListTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 9a6e8650a0..9788259383 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -34,9 +35,6 @@ case class Nested3(f3: Int)
*/
class SQLQuerySuite extends QueryTest {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
- val sqlCtx = TestHive
-
test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),
@@ -176,7 +174,8 @@ class SQLQuerySuite extends QueryTest {
}
test("double nested data") {
- sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
+ sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil)
+ .toDF().registerTempTable("nested")
checkAnswer(
sql("SELECT f1.f2.f3 FROM nested"),
Row(1))
@@ -199,7 +198,7 @@ class SQLQuerySuite extends QueryTest {
}
test("SPARK-4825 save join to table") {
- val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
sql("CREATE TABLE test1 (key INT, value STRING)")
testData.insertInto("test1")
sql("CREATE TABLE test2 (key INT, value STRING)")
@@ -279,7 +278,7 @@ class SQLQuerySuite extends QueryTest {
val rowRdd = sparkContext.parallelize(row :: Nil)
- sqlCtx.createDataFrame(rowRdd, schema).registerTempTable("testTable")
+ TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index a7479a5b95..e246cbb6d7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.hive.execution.HiveTableScan
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@@ -152,7 +154,6 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
var normalTableDir: File = null
var partitionedTableDirWithKey: File = null
- import org.apache.spark.sql.hive.test.TestHive.implicits._
override def beforeAll(): Unit = {
partitionedTableDir = File.createTempFile("parquettests", "sparksql")
@@ -167,12 +168,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
val partDir = new File(partitionedTableDir, s"p=$p")
sparkContext.makeRDD(1 to 10)
.map(i => ParquetData(i, s"part-$p"))
+ .toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
sparkContext
.makeRDD(1 to 10)
.map(i => ParquetData(i, s"part-1"))
+ .toDF()
.saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath)
partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql")
@@ -183,6 +186,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
val partDir = new File(partitionedTableDirWithKey, s"p=$p")
sparkContext.makeRDD(1 to 10)
.map(i => ParquetDataWithKey(p, i, s"part-$p"))
+ .toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
}
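Finally, the Parquet partitioning fixtures: writing an RDD of case-class rows to Parquet also goes through toDF() now. A sketch under the same assumptions, with placeholder paths:

    import java.io.File
    import org.apache.spark.sql.hive.test.TestHive._
    import org.apache.spark.sql.hive.test.TestHive.implicits._

    case class ParquetData(intField: Int, stringField: String)

    val partitionedTableDir = File.createTempFile("parquettests", "sparksql")
    partitionedTableDir.delete()
    partitionedTableDir.mkdir()

    // One Parquet directory per partition value, each written from a DataFrame.
    (1 to 10).foreach { p =>
      val partDir = new File(partitionedTableDir, s"p=$p")
      sparkContext.makeRDD(1 to 10)
        .map(i => ParquetData(i, s"part-$p"))
        .toDF()
        .saveAsParquetFile(partDir.getCanonicalPath)
    }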