authorCheng Lian <lian@databricks.com>2016-03-10 17:00:17 -0800
committerYin Huai <yhuai@databricks.com>2016-03-10 17:00:17 -0800
commit1d542785b9949e7f92025e6754973a779cc37c52 (patch)
treeceda7492e40c9d9a9231a5011c91e30bf0b1f390
parent27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff)
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?

This PR unifies DataFrame and Dataset by migrating the existing DataFrame operations to Dataset and making `DataFrame` a type alias of `Dataset[Row]`. Most Scala code changes are source compatible, but the Java API is broken because Java knows nothing about Scala type aliases (mostly, `DataFrame` must be replaced with `Dataset<Row>`).

There are several noticeable API changes related to methods that return arrays:

1. `collect`/`take`

   - Old APIs in class `DataFrame`:

     ```scala
     def collect(): Array[Row]
     def take(n: Int): Array[Row]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def collect(): Array[T]
     def take(n: Int): Array[T]

     def collectRows(): Array[Row]
     def takeRows(n: Int): Array[Row]
     ```

   The two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays. For example, `Dataset.collect(): Array[T]` actually returns `Object` instead of `T[]` on the Java side. Normally, Java users may fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid a performance regression in ML-related code (but maybe I'm wrong and they are not necessary here).

1. `randomSplit`

   - Old APIs in class `DataFrame`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
     def randomSplit(weights: Array[Double]): Array[DataFrame]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
     def randomSplit(weights: Array[Double]): Array[Dataset[T]]
     ```

   This has the same problem as above, but it hasn't been addressed for the Java API yet. We can probably add `randomSplitAsList` to fix it.

1. `groupBy`

   Some of the original `DataFrame.groupBy` methods have signatures that conflict with the original `Dataset.groupBy` methods. To distinguish the two, the typed `Dataset.groupBy` methods are renamed to `groupByKey`.

Other noticeable changes:

1. Datasets always do eager analysis now.

   We used to support disabling eager DataFrame analysis to help report the partially analyzed, malformed logical plan on analysis failure. However, Dataset encoders require eager analysis during Dataset construction. To preserve the error-reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed in by `QueryExecution.assertAnalyzed`.

## How was this patch tested?

Existing tests do the work.

## TODO

- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>

Closes #11443 from liancheng/ds-to-df.
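To make the Java-facing impact concrete, here is a minimal, self-contained sketch of what Java code looks like after this migration. It is illustrative only and not part of this patch: the class name, the `people.json` input path, and the split weights are assumptions chosen for the example.

```java
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class DatasetMigrationSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("DatasetMigrationSketch");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    // DataFrame is now a Scala type alias for Dataset[Row], so Java code
    // declares Dataset<Row> explicitly.
    Dataset<Row> df = sqlContext.read().json("examples/src/main/resources/people.json");

    // Java cannot receive the generic Array[T] returned by collect()/take(),
    // so use the Row-specialized or List-based variants instead.
    Row[] rows = df.collectRows();
    List<Row> firstFive = df.takeAsList(5);

    // randomSplit now returns Dataset<Row>[] rather than DataFrame[].
    Dataset<Row>[] splits = df.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> training = splits[0];
    Dataset<Row> test = splits[1];

    System.out.println(rows.length + " rows collected, " + firstFive.size() + " rows taken.");
    System.out.println(training.count() + " training rows, " + test.count() + " test rows.");

    jsc.stop();
  }
}
```

The diffs below apply exactly this kind of mechanical change to the Java examples and test suites: `DataFrame` declarations become `Dataset<Row>`, and `collect()` calls in Java code become `collectRows()`.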
-rwxr-xr-xdev/run-tests.py9
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java3
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java14
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java11
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java13
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java13
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java12
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java9
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java13
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java14
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java3
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java5
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java5
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java9
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java5
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java9
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java11
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java15
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java11
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java12
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java14
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java14
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java3
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java11
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java11
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java9
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java12
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java11
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java6
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java9
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java7
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java20
-rw-r--r--examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java5
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java5
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java20
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java9
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java10
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java5
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java9
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java12
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java9
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java8
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java12
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java9
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java6
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java9
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java8
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java5
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java5
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java7
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java5
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java5
-rw-r--r--python/pyspark/mllib/common.py2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala532
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala794
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/package.scala1
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java12
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java60
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java14
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala28
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala38
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala58
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala138
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala24
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala11
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java8
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java10
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala1
116 files changed, 1069 insertions, 1444 deletions
diff --git a/dev/run-tests.py b/dev/run-tests.py
index aa6af564be..6e45113134 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -561,10 +561,11 @@ def main():
# spark build
build_apache_spark(build_tool, hadoop_version)
- # backwards compatibility checks
- if build_tool == "sbt":
- # Note: compatibility tests only supported in sbt for now
- detect_binary_inop_with_mima()
+ # TODO Temporarily disable MiMA check for DF-to-DS migration prototyping
+ # # backwards compatibility checks
+ # if build_tool == "sbt":
+ # # Note: compatiblity tests only supported in sbt for now
+ # detect_binary_inop_with_mima()
# run the test suites
run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
index 69a174562f..39053109da 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
@@ -27,6 +27,7 @@ import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.mllib.linalg.*;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -52,7 +53,7 @@ public class JavaAFTSurvivalRegressionExample {
new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- DataFrame training = jsql.createDataFrame(data, schema);
+ Dataset<Row> training = jsql.createDataFrame(data, schema);
double[] quantileProbabilities = new double[]{0.3, 0.6};
AFTSurvivalRegression aft = new AFTSurvivalRegression()
.setQuantileProbabilities(quantileProbabilities)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
index 90d2ac2b13..9754ba5268 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
@@ -19,6 +19,8 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -93,10 +95,10 @@ public class JavaALSExample {
return Rating.parseRating(str);
}
});
- DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
- DataFrame[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
- DataFrame training = splits[0];
- DataFrame test = splits[1];
+ Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
+ Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
+ Dataset<Row> training = splits[0];
+ Dataset<Row> test = splits[1];
// Build the recommendation model using ALS on the training data
ALS als = new ALS()
@@ -108,8 +110,8 @@ public class JavaALSExample {
ALSModel model = als.fit(training);
// Evaluate the model by computing the RMSE on the test data
- DataFrame rawPredictions = model.transform(test);
- DataFrame predictions = rawPredictions
+ Dataset<Row> rawPredictions = model.transform(test);
+ Dataset<Row> predictions = rawPredictions
.withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
.withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
index 1eda1f694f..84eef1fb8a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
@@ -19,6 +19,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -51,18 +52,18 @@ public class JavaBinarizerExample {
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame continuousDataFrame = jsql.createDataFrame(jrdd, schema);
+ Dataset<Row> continuousDataFrame = jsql.createDataFrame(jrdd, schema);
Binarizer binarizer = new Binarizer()
.setInputCol("feature")
.setOutputCol("binarized_feature")
.setThreshold(0.5);
- DataFrame binarizedDataFrame = binarizer.transform(continuousDataFrame);
- DataFrame binarizedFeatures = binarizedDataFrame.select("binarized_feature");
- for (Row r : binarizedFeatures.collect()) {
+ Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);
+ Dataset<Row> binarizedFeatures = binarizedDataFrame.select("binarized_feature");
+ for (Row r : binarizedFeatures.collectRows()) {
Double binarized_value = r.getDouble(0);
System.out.println(binarized_value);
}
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
index e124c1cf18..1d1a518bbc 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
@@ -30,7 +30,7 @@ import org.apache.spark.ml.clustering.BisectingKMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
@@ -62,7 +62,7 @@ public class JavaBisectingKMeansExample {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
BisectingKMeans bkm = new BisectingKMeans().setK(2);
BisectingKMeansModel model = bkm.fit(dataset);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
index 8ad369cc93..68ffa702ea 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.Bucketizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -53,7 +53,7 @@ public class JavaBucketizerExample {
StructType schema = new StructType(new StructField[]{
new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame dataFrame = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataFrame = jsql.createDataFrame(data, schema);
Bucketizer bucketizer = new Bucketizer()
.setInputCol("features")
@@ -61,7 +61,7 @@ public class JavaBucketizerExample {
.setSplits(splits);
// Transform original data into its bucket index.
- DataFrame bucketedData = bucketizer.transform(dataFrame);
+ Dataset<Row> bucketedData = bucketizer.transform(dataFrame);
bucketedData.show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
index ede05d6e20..b1bf1cfeb2 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -28,7 +29,6 @@ import java.util.Arrays;
import org.apache.spark.ml.feature.ChiSqSelector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -55,7 +55,7 @@ public class JavaChiSqSelectorExample {
new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
ChiSqSelector selector = new ChiSqSelector()
.setNumTopFeatures(1)
@@ -63,7 +63,7 @@ public class JavaChiSqSelectorExample {
.setLabelCol("clicked")
.setOutputCol("selectedFeatures");
- DataFrame result = selector.fit(df).transform(df);
+ Dataset<Row> result = selector.fit(df).transform(df);
result.show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java
index 872e5a07d1..ec3ac202be 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java
@@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -48,7 +48,7 @@ public class JavaCountVectorizerExample {
StructType schema = new StructType(new StructField [] {
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
// fit a CountVectorizerModel from the corpus
CountVectorizerModel cvModel = new CountVectorizer()
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 9bbc14ea40..fb6c47be39 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -34,6 +34,7 @@ import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -71,7 +72,8 @@ public class JavaCrossValidatorExample {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
- DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
+ Dataset<Row> training = jsql.createDataFrame(
+ jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -112,11 +114,11 @@ public class JavaCrossValidatorExample {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
+ Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- DataFrame predictions = cvModel.transform(test);
- for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
+ Dataset<Row> predictions = cvModel.transform(test);
+ for (Row r: predictions.select("id", "text", "probability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
index 35c0d534a4..4b15fde9c3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
@@ -19,6 +19,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -28,7 +29,6 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.DCT;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.Metadata;
@@ -51,12 +51,12 @@ public class JavaDCTExample {
StructType schema = new StructType(new StructField[]{
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
- DataFrame df = jsql.createDataFrame(data, schema);
+ Dataset<Row> df = jsql.createDataFrame(data, schema);
DCT dct = new DCT()
.setInputCol("features")
.setOutputCol("featuresDCT")
.setInverse(false);
- DataFrame dctDf = dct.transform(df);
+ Dataset<Row> dctDf = dct.transform(df);
dctDf.select("featuresDCT").show(3);
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
index b5347b7650..5bd61fe508 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
@@ -26,7 +26,8 @@ import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -38,7 +39,7 @@ public class JavaDecisionTreeClassificationExample {
// $example on$
// Load the data stored in LIBSVM format as a DataFrame.
- DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
@@ -55,9 +56,9 @@ public class JavaDecisionTreeClassificationExample {
.fit(data);
// Split the data into training and test sets (30% held out for testing)
- DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
- DataFrame trainingData = splits[0];
- DataFrame testData = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
+ Dataset<Row> trainingData = splits[0];
+ Dataset<Row> testData = splits[1];
// Train a DecisionTree model.
DecisionTreeClassifier dt = new DecisionTreeClassifier()
@@ -78,7 +79,7 @@ public class JavaDecisionTreeClassificationExample {
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
- DataFrame predictions = model.transform(testData);
+ Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java
index 9cb67be04a..a4f3e97bf3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java
@@ -27,7 +27,8 @@ import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -38,7 +39,7 @@ public class JavaDecisionTreeRegressionExample {
SQLContext sqlContext = new SQLContext(jsc);
// $example on$
// Load the data stored in LIBSVM format as a DataFrame.
- DataFrame data = sqlContext.read().format("libsvm")
+ Dataset<Row> data = sqlContext.read().format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");
// Automatically identify categorical features, and index them.
@@ -50,9 +51,9 @@ public class JavaDecisionTreeRegressionExample {
.fit(data);
// Split the data into training and test sets (30% held out for testing)
- DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
- DataFrame trainingData = splits[0];
- DataFrame testData = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
+ Dataset<Row> trainingData = splits[0];
+ Dataset<Row> testData = splits[1];
// Train a DecisionTree model.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
@@ -66,7 +67,7 @@ public class JavaDecisionTreeRegressionExample {
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
- DataFrame predictions = model.transform(testData);
+ Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("label", "features").show(5);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index da2012ad51..e568bea607 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -61,7 +62,8 @@ public class JavaDeveloperApiExample {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
+ Dataset<Row> training = jsql.createDataFrame(
+ jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
@@ -79,12 +81,12 @@ public class JavaDeveloperApiExample {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
+ Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- DataFrame results = model.transform(test);
+ Dataset<Row> results = model.transform(test);
double sumPredictions = 0;
- for (Row r : results.select("features", "label", "prediction").collect()) {
+ for (Row r : results.select("features", "label", "prediction").collectRows()) {
sumPredictions += r.getDouble(2);
}
if (sumPredictions != 0.0) {
@@ -145,7 +147,7 @@ class MyJavaLogisticRegression
// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
- public MyJavaLogisticRegressionModel train(DataFrame dataset) {
+ public MyJavaLogisticRegressionModel train(Dataset<Row> dataset) {
// Extract columns from data using helper method.
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java
index c1f00dde0e..37de9cf359 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java
@@ -19,6 +19,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -31,7 +32,6 @@ import org.apache.spark.ml.feature.ElementwiseProduct;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -58,7 +58,7 @@ public class JavaElementwiseProductExample {
StructType schema = DataTypes.createStructType(fields);
- DataFrame dataFrame = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> dataFrame = sqlContext.createDataFrame(jrdd, schema);
Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0);
@@ -72,4 +72,4 @@ public class JavaElementwiseProductExample {
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
index 44cf3507f3..8a02f60aa4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
@@ -30,6 +30,7 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SQLContext;
@@ -48,7 +49,7 @@ public class JavaEstimatorTransformerParamExample {
// Prepare training data.
// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into
// DataFrames, where it uses the bean metadata to infer the schema.
- DataFrame training = sqlContext.createDataFrame(
+ Dataset<Row> training = sqlContext.createDataFrame(
Arrays.asList(
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -89,7 +90,7 @@ public class JavaEstimatorTransformerParamExample {
System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
- DataFrame test = sqlContext.createDataFrame(Arrays.asList(
+ Dataset<Row> test = sqlContext.createDataFrame(Arrays.asList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
@@ -99,8 +100,8 @@ public class JavaEstimatorTransformerParamExample {
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
- DataFrame results = model2.transform(test);
- for (Row r : results.select("features", "label", "myProbability", "prediction").collect()) {
+ Dataset<Row> results = model2.transform(test);
+ for (Row r : results.select("features", "label", "myProbability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
index 848fe6566c..c2cb955385 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
@@ -27,7 +27,8 @@ import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -39,7 +40,7 @@ public class JavaGradientBoostedTreeClassifierExample {
// $example on$
// Load and parse the data file, converting it to a DataFrame.
- DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
@@ -56,9 +57,9 @@ public class JavaGradientBoostedTreeClassifierExample {
.fit(data);
// Split the data into training and test sets (30% held out for testing)
- DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
- DataFrame trainingData = splits[0];
- DataFrame testData = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
+ Dataset<Row> trainingData = splits[0];
+ Dataset<Row> testData = splits[1];
// Train a GBT model.
GBTClassifier gbt = new GBTClassifier()
@@ -80,7 +81,7 @@ public class JavaGradientBoostedTreeClassifierExample {
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
- DataFrame predictions = model.transform(testData);
+ Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java
index 1f67b0842d..83fd89e3bd 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java
@@ -28,7 +28,8 @@ import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.GBTRegressionModel;
import org.apache.spark.ml.regression.GBTRegressor;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -40,7 +41,8 @@ public class JavaGradientBoostedTreeRegressorExample {
// $example on$
// Load and parse the data file, converting it to a DataFrame.
- DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data =
+ sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
@@ -51,9 +53,9 @@ public class JavaGradientBoostedTreeRegressorExample {
.fit(data);
// Split the data into training and test sets (30% held out for testing)
- DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
- DataFrame trainingData = splits[0];
- DataFrame testData = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
+ Dataset<Row> trainingData = splits[0];
+ Dataset<Row> testData = splits[1];
// Train a GBT model.
GBTRegressor gbt = new GBTRegressor()
@@ -68,7 +70,7 @@ public class JavaGradientBoostedTreeRegressorExample {
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
- DataFrame predictions = model.transform(testData);
+ Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("prediction", "label", "features").show(5);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
index 3ccd699326..9b8c22f3bd 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -28,7 +29,6 @@ import java.util.Arrays;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -56,18 +56,18 @@ public class JavaIndexToStringExample {
new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("category", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
StringIndexerModel indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df);
- DataFrame indexed = indexer.transform(df);
+ Dataset<Row> indexed = indexer.transform(df);
IndexToString converter = new IndexToString()
.setInputCol("categoryIndex")
.setOutputCol("originalCategory");
- DataFrame converted = converter.transform(indexed);
+ Dataset<Row> converted = converter.transform(indexed);
converted.select("id", "originalCategory").show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
index 96481d882a..30ccf30885 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -23,6 +23,7 @@ import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
// $example on$
@@ -81,7 +82,7 @@ public class JavaKMeansExample {
JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
StructType schema = new StructType(fields);
- DataFrame dataset = sqlContext.createDataFrame(points, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(points, schema);
// Trains a k-means model
KMeans kmeans = new KMeans()
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java
index 3a5d3237c8..c70d44c297 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java
@@ -29,6 +29,7 @@ import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
@@ -75,7 +76,7 @@ public class JavaLDAExample {
JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParseVector());
StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
StructType schema = new StructType(fields);
- DataFrame dataset = sqlContext.createDataFrame(points, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(points, schema);
// Trains a LDA model
LDA lda = new LDA()
@@ -87,7 +88,7 @@ public class JavaLDAExample {
System.out.println(model.logPerplexity(dataset));
// Shows the result
- DataFrame topics = model.describeTopics(3);
+ Dataset<Row> topics = model.describeTopics(3);
topics.show(false);
model.transform(dataset).show(false);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
index 4ad7676c8d..08fce89359 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
@@ -24,7 +24,8 @@ import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -36,7 +37,7 @@ public class JavaLinearRegressionWithElasticNetExample {
// $example on$
// Load training data
- DataFrame training = sqlContext.read().format("libsvm")
+ Dataset<Row> training = sqlContext.read().format("libsvm")
.load("data/mllib/sample_linear_regression_data.txt");
LinearRegression lr = new LinearRegression()
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
index 986f3b3b28..73b028fb44 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
@@ -24,7 +24,8 @@ import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
// $example off$
@@ -36,7 +37,7 @@ public class JavaLogisticRegressionSummaryExample {
SQLContext sqlContext = new SQLContext(jsc);
// Load training data
- DataFrame training = sqlContext.read().format("libsvm")
+ Dataset<Row> training = sqlContext.read().format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");
LogisticRegression lr = new LogisticRegression()
@@ -65,14 +66,14 @@ public class JavaLogisticRegressionSummaryExample {
(BinaryLogisticRegressionSummary) trainingSummary;
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
- DataFrame roc = binarySummary.roc();
+ Dataset<Row> roc = binarySummary.roc();
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());
// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
- DataFrame fMeasure = binarySummary.fMeasureByThreshold();
+ Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
.select("threshold").head().getDouble(0);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java
index 1d28279d72..6911668522 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java
@@ -22,7 +22,8 @@ import org.apache.spark.api.java.JavaSparkContext;
// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -34,7 +35,7 @@ public class JavaLogisticRegressionWithElasticNetExample {
// $example on$
// Load training data
- DataFrame training = sqlContext.read().format("libsvm")
+ Dataset<Row> training = sqlContext.read().format("libsvm")
.load("data/mllib/sample_libsvm_data.txt");
LogisticRegression lr = new LogisticRegression()
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
index 2d50ba7faa..4aee18eeab 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
@@ -24,7 +24,8 @@ import org.apache.spark.sql.SQLContext;
// $example on$
import org.apache.spark.ml.feature.MinMaxScaler;
import org.apache.spark.ml.feature.MinMaxScalerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
// $example off$
public class JavaMinMaxScalerExample {
@@ -34,7 +35,7 @@ public class JavaMinMaxScalerExample {
SQLContext jsql = new SQLContext(jsc);
// $example on$
- DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
MinMaxScaler scaler = new MinMaxScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures");
@@ -43,9 +44,9 @@ public class JavaMinMaxScalerExample {
MinMaxScalerModel scalerModel = scaler.fit(dataFrame);
// rescale each feature to range [min, max].
- DataFrame scaledData = scalerModel.transform(dataFrame);
+ Dataset<Row> scaledData = scalerModel.transform(dataFrame);
scaledData.show();
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
index 87ad119491..e394605db7 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
@@ -34,7 +34,7 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SQLContext;
@@ -51,7 +51,7 @@ public class JavaModelSelectionViaCrossValidationExample {
// $example on$
// Prepare training documents, which are labeled.
- DataFrame training = sqlContext.createDataFrame(Arrays.asList(
+ Dataset<Row> training = sqlContext.createDataFrame(Arrays.asList(
new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
new JavaLabeledDocument(1L, "b d", 0.0),
new JavaLabeledDocument(2L,"spark f g h", 1.0),
@@ -102,7 +102,7 @@ public class JavaModelSelectionViaCrossValidationExample {
CrossValidatorModel cvModel = cv.fit(training);
// Prepare test documents, which are unlabeled.
- DataFrame test = sqlContext.createDataFrame(Arrays.asList(
+ Dataset<Row> test = sqlContext.createDataFrame(Arrays.asList(
new JavaDocument(4L, "spark i j k"),
new JavaDocument(5L, "l m n"),
new JavaDocument(6L, "mapreduce spark"),
@@ -110,8 +110,8 @@ public class JavaModelSelectionViaCrossValidationExample {
), JavaDocument.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- DataFrame predictions = cvModel.transform(test);
- for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) {
+ Dataset<Row> predictions = cvModel.transform(test);
+ for (Row r : predictions.select("id", "text", "probability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
index 77adb02dfd..6ac4aea3c4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
@@ -26,7 +26,8 @@ import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SQLContext;
@@ -41,13 +42,13 @@ public class JavaModelSelectionViaTrainValidationSplitExample {
SQLContext jsql = new SQLContext(sc);
// $example on$
- DataFrame data = jsql.read().format("libsvm")
+ Dataset<Row> data = jsql.read().format("libsvm")
.load("data/mllib/sample_linear_regression_data.txt");
// Prepare training and test data.
- DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
- DataFrame training = splits[0];
- DataFrame test = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
+ Dataset<Row> training = splits[0];
+ Dataset<Row> test = splits[1];
LinearRegression lr = new LinearRegression();
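
`randomSplit` now hands Java callers a `Dataset<Row>[]` rather than a `DataFrame[]`; the destructuring pattern stays the same. A sketch of that split step in isolation (method and variable names are illustrative):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class RandomSplitSketch {
  // 90/10 train/test split with a fixed seed, as in the hunk above.
  static void splitAndCount(Dataset<Row> data) {
    Dataset<Row>[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
    Dataset<Row> training = splits[0];
    Dataset<Row> test = splits[1];
    System.out.println("training rows: " + training.count() + ", test rows: " + test.count());
  }
}
```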
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
index 84369f6681..0ca528d8cd 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
@@ -20,11 +20,12 @@ package org.apache.spark.examples.ml;
// $example on$
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
-import org.apache.spark.sql.DataFrame;
// $example off$
/**
@@ -40,11 +41,11 @@ public class JavaMultilayerPerceptronClassifierExample {
// $example on$
// Load training data
String path = "data/mllib/sample_multiclass_classification_data.txt";
- DataFrame dataFrame = jsql.read().format("libsvm").load(path);
+ Dataset<Row> dataFrame = jsql.read().format("libsvm").load(path);
// Split the data into train and test
- DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
- DataFrame train = splits[0];
- DataFrame test = splits[1];
+ Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
+ Dataset<Row> train = splits[0];
+ Dataset<Row> test = splits[1];
// specify layers for the neural network:
// input layer of size 4 (features), two intermediate of size 5 and 4
// and output of size 3 (classes)
@@ -58,8 +59,8 @@ public class JavaMultilayerPerceptronClassifierExample {
// train the model
MultilayerPerceptronClassificationModel model = trainer.fit(train);
// compute precision on the test set
- DataFrame result = model.transform(test);
- DataFrame predictionAndLabels = result.select("prediction", "label");
+ Dataset<Row> result = model.transform(test);
+ Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setMetricName("precision");
System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
index 8fd75ed8b5..0305f737ca 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
@@ -19,6 +19,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
// $example on$
@@ -26,7 +27,6 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.NGram;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -54,13 +54,13 @@ public class JavaNGramExample {
"words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
- DataFrame wordDataFrame = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> wordDataFrame = sqlContext.createDataFrame(jrdd, schema);
NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams");
- DataFrame ngramDataFrame = ngramTransformer.transform(wordDataFrame);
+ Dataset<Row> ngramDataFrame = ngramTransformer.transform(wordDataFrame);
- for (Row r : ngramDataFrame.select("ngrams", "label").take(3)) {
+ for (Row r : ngramDataFrame.select("ngrams", "label").takeRows(3)) {
java.util.List<String> ngrams = r.getList(0);
for (String ngram : ngrams) System.out.print(ngram + " --- ");
System.out.println();
@@ -68,4 +68,4 @@ public class JavaNGramExample {
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
index ed3f6163c0..31cd752136 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
@@ -23,7 +23,8 @@ import org.apache.spark.sql.SQLContext;
// $example on$
import org.apache.spark.ml.feature.Normalizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
// $example off$
public class JavaNormalizerExample {
@@ -33,7 +34,7 @@ public class JavaNormalizerExample {
SQLContext jsql = new SQLContext(jsc);
// $example on$
- DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Normalize each Vector using $L^1$ norm.
Normalizer normalizer = new Normalizer()
@@ -41,14 +42,14 @@ public class JavaNormalizerExample {
.setOutputCol("normFeatures")
.setP(1.0);
- DataFrame l1NormData = normalizer.transform(dataFrame);
+ Dataset<Row> l1NormData = normalizer.transform(dataFrame);
l1NormData.show();
// Normalize each Vector using $L^\infty$ norm.
- DataFrame lInfNormData =
+ Dataset<Row> lInfNormData =
normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
lInfNormData.show();
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
index bc50960708..882438ca28 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -58,18 +58,18 @@ public class JavaOneHotEncoderExample {
new StructField("category", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
StringIndexerModel indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
.fit(df);
- DataFrame indexed = indexer.transform(df);
+ Dataset<Row> indexed = indexer.transform(df);
OneHotEncoder encoder = new OneHotEncoder()
.setInputCol("categoryIndex")
.setOutputCol("categoryVec");
- DataFrame encoded = encoder.transform(indexed);
+ Dataset<Row> encoded = encoder.transform(indexed);
encoded.select("id", "categoryVec").show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
index 42374e77ac..8288f73c1b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -30,6 +30,8 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructField;
// $example off$
@@ -81,9 +83,9 @@ public class JavaOneVsRestExample {
OneVsRest ovr = new OneVsRest().setClassifier(classifier);
String input = params.input;
- DataFrame inputData = jsql.read().format("libsvm").load(input);
- DataFrame train;
- DataFrame test;
+ Dataset<Row> inputData = jsql.read().format("libsvm").load(input);
+ Dataset<Row> train;
+ Dataset<Row> test;
// compute the train/ test split: if testInput is not provided use part of input
String testInput = params.testInput;
@@ -95,7 +97,7 @@ public class JavaOneVsRestExample {
String.valueOf(numFeatures)).load(testInput);
} else {
double f = params.fracTest;
- DataFrame[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
+ Dataset<Row>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
train = tmp[0];
test = tmp[1];
}
@@ -104,7 +106,7 @@ public class JavaOneVsRestExample {
OneVsRestModel ovrModel = ovr.fit(train.cache());
// score the model on test data
- DataFrame predictions = ovrModel.transform(test.cache())
+ Dataset<Row> predictions = ovrModel.transform(test.cache())
.select("prediction", "label");
// obtain metrics
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
index 8282fab084..a792fd7d47 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
@@ -29,7 +29,7 @@ import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.PCAModel;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.Metadata;
@@ -54,7 +54,7 @@ public class JavaPCAExample {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
- DataFrame df = jsql.createDataFrame(data, schema);
+ Dataset<Row> df = jsql.createDataFrame(data, schema);
PCAModel pca = new PCA()
.setInputCol("features")
@@ -62,7 +62,7 @@ public class JavaPCAExample {
.setK(3)
.fit(df);
- DataFrame result = pca.transform(df).select("pcaFeatures");
+ Dataset<Row> result = pca.transform(df).select("pcaFeatures");
result.show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java
index 3407c25c83..6ae418d564 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java
@@ -30,7 +30,7 @@ import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SQLContext;
@@ -46,7 +46,7 @@ public class JavaPipelineExample {
// $example on$
// Prepare training documents, which are labeled.
- DataFrame training = sqlContext.createDataFrame(Arrays.asList(
+ Dataset<Row> training = sqlContext.createDataFrame(Arrays.asList(
new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
new JavaLabeledDocument(1L, "b d", 0.0),
new JavaLabeledDocument(2L, "spark f g h", 1.0),
@@ -71,7 +71,7 @@ public class JavaPipelineExample {
PipelineModel model = pipeline.fit(training);
// Prepare test documents, which are unlabeled.
- DataFrame test = sqlContext.createDataFrame(Arrays.asList(
+ Dataset<Row> test = sqlContext.createDataFrame(Arrays.asList(
new JavaDocument(4L, "spark i j k"),
new JavaDocument(5L, "l m n"),
new JavaDocument(6L, "mapreduce spark"),
@@ -79,8 +79,8 @@ public class JavaPipelineExample {
), JavaDocument.class);
// Make predictions on test documents.
- DataFrame predictions = model.transform(test);
- for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) {
+ Dataset<Row> predictions = model.transform(test);
+ for (Row r : predictions.select("id", "text", "probability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
index 668f71e640..5a4064c604 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.PolynomialExpansion;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.Metadata;
@@ -58,14 +58,14 @@ public class JavaPolynomialExpansionExample {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
- DataFrame df = jsql.createDataFrame(data, schema);
- DataFrame polyDF = polyExpansion.transform(df);
+ Dataset<Row> df = jsql.createDataFrame(data, schema);
+ Dataset<Row> polyDF = polyExpansion.transform(df);
- Row[] row = polyDF.select("polyFeatures").take(3);
+ Row[] row = polyDF.select("polyFeatures").takeRows(3);
for (Row r : row) {
System.out.println(r.get(0));
}
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
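
`takeRows(n)` is the array-returning counterpart used here; for callers that prefer a `java.util.List`, `takeAsList(n)` (carried over from the old DataFrame API, so assumed to still be available) serves the same purpose. A small sketch of both, with an illustrative helper name:

```java
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class TakeRowsSketch {
  // Prints the first three values of the "polyFeatures" column, as above.
  static void printFirstThree(Dataset<Row> polyDF) {
    Row[] asArray = polyDF.select("polyFeatures").takeRows(3);
    for (Row r : asArray) {
      System.out.println(r.get(0));
    }
    // List-based alternative; takeAsList is assumed, not shown in this patch.
    List<Row> asList = polyDF.select("polyFeatures").takeAsList(3);
    System.out.println("rows fetched: " + asList.size());
  }
}
```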
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java
index 251ae79d9a..7b226fede9 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java
@@ -25,7 +25,7 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.QuantileDiscretizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -56,14 +56,14 @@ public class JavaQuantileDiscretizerExample {
new StructField("hour", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
QuantileDiscretizer discretizer = new QuantileDiscretizer()
.setInputCol("hour")
.setOutputCol("result")
.setNumBuckets(3);
- DataFrame result = discretizer.fit(df).transform(df);
+ Dataset<Row> result = discretizer.fit(df).transform(df);
result.show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java
index 1e1062b541..8c453bf80d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.RFormula;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructField;
@@ -55,12 +55,12 @@ public class JavaRFormulaExample {
RowFactory.create(9, "NZ", 15, 0.0)
));
- DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(rdd, schema);
RFormula formula = new RFormula()
.setFormula("clicked ~ country + hour")
.setFeaturesCol("features")
.setLabelCol("label");
- DataFrame output = formula.fit(dataset).transform(dataset);
+ Dataset<Row> output = formula.fit(dataset).transform(dataset);
output.select("features", "label").show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java
index 5a62496660..05c2bc9622 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java
@@ -27,7 +27,8 @@ import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -39,7 +40,8 @@ public class JavaRandomForestClassifierExample {
// $example on$
// Load and parse the data file, converting it to a DataFrame.
- DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data =
+ sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
@@ -56,9 +58,9 @@ public class JavaRandomForestClassifierExample {
.fit(data);
// Split the data into training and test sets (30% held out for testing)
- DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
- DataFrame trainingData = splits[0];
- DataFrame testData = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
+ Dataset<Row> trainingData = splits[0];
+ Dataset<Row> testData = splits[1];
// Train a RandomForest model.
RandomForestClassifier rf = new RandomForestClassifier()
@@ -79,7 +81,7 @@ public class JavaRandomForestClassifierExample {
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
- DataFrame predictions = model.transform(testData);
+ Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java
index 05782a0724..d366967083 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java
@@ -28,7 +28,8 @@ import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -40,7 +41,8 @@ public class JavaRandomForestRegressorExample {
// $example on$
// Load and parse the data file, converting it to a DataFrame.
- DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data =
+ sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
@@ -51,9 +53,9 @@ public class JavaRandomForestRegressorExample {
.fit(data);
// Split the data into training and test sets (30% held out for testing)
- DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
- DataFrame trainingData = splits[0];
- DataFrame testData = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
+ Dataset<Row> trainingData = splits[0];
+ Dataset<Row> testData = splits[1];
// Train a RandomForest model.
RandomForestRegressor rf = new RandomForestRegressor()
@@ -68,7 +70,7 @@ public class JavaRandomForestRegressorExample {
PipelineModel model = pipeline.fit(trainingData);
// Make predictions.
- DataFrame predictions = model.transform(testData);
+ Dataset<Row> predictions = model.transform(testData);
// Select example rows to display.
predictions.select("prediction", "label", "features").show(5);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java
index a9d64d5e3f..e413cbaf71 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java
@@ -25,6 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.SQLTransformer;
import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -48,7 +49,7 @@ public class JavaSQLTransformerExample {
new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("v2", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
SQLTransformer sqlTrans = new SQLTransformer().setStatement(
"SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__");
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index ea83e8fef9..52bb4ec050 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -54,7 +54,8 @@ public class JavaSimpleParamsExample {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
+ Dataset<Row> training =
+ jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -95,14 +96,14 @@ public class JavaSimpleParamsExample {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
+ Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegressionModel.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
- DataFrame results = model2.transform(test);
- for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) {
+ Dataset<Row> results = model2.transform(test);
+ for (Row r: results.select("features", "label", "myProbability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 54738813d0..9bd543c44f 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -29,7 +29,7 @@ import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -54,7 +54,8 @@ public class JavaSimpleTextClassificationPipeline {
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
- DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
+ Dataset<Row> training =
+ jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -79,11 +80,11 @@ public class JavaSimpleTextClassificationPipeline {
new Document(5L, "l m n"),
new Document(6L, "spark hadoop spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
+ Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
- DataFrame predictions = model.transform(test);
- for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
+ Dataset<Row> predictions = model.transform(test);
+ for (Row r: predictions.select("id", "text", "probability", "prediction").collectRows()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
index da4756643f..e2dd759c0a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java
@@ -24,7 +24,8 @@ import org.apache.spark.sql.SQLContext;
// $example on$
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.ml.feature.StandardScalerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
// $example off$
public class JavaStandardScalerExample {
@@ -34,7 +35,7 @@ public class JavaStandardScalerExample {
SQLContext jsql = new SQLContext(jsc);
// $example on$
- DataFrame dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
@@ -46,9 +47,9 @@ public class JavaStandardScalerExample {
StandardScalerModel scalerModel = scaler.fit(dataFrame);
// Normalize each feature to have unit standard deviation.
- DataFrame scaledData = scalerModel.transform(dataFrame);
+ Dataset<Row> scaledData = scalerModel.transform(dataFrame);
scaledData.show();
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
index b6b201c6b6..0ff3782cb3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.StopWordsRemover;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -57,7 +57,7 @@ public class JavaStopWordsRemoverExample {
"raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(rdd, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(rdd, schema);
remover.transform(dataset).show();
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
index 05d12c1e70..ceacbb4fb3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
@@ -26,7 +26,7 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.StringIndexer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructField;
@@ -54,13 +54,13 @@ public class JavaStringIndexerExample {
createStructField("id", IntegerType, false),
createStructField("category", StringType, false)
});
- DataFrame df = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(jrdd, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex");
- DataFrame indexed = indexer.fit(df).transform(df);
+ Dataset<Row> indexed = indexer.fit(df).transform(df);
indexed.show();
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
index a41a5ec9bf..fd1ce424bf 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -54,19 +54,19 @@ public class JavaTfIdfExample {
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> sentenceData = sqlContext.createDataFrame(jrdd, schema);
Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
- DataFrame wordsData = tokenizer.transform(sentenceData);
+ Dataset<Row> wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
- DataFrame featurizedData = hashingTF.transform(wordsData);
+ Dataset<Row> featurizedData = hashingTF.transform(wordsData);
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
- DataFrame rescaledData = idfModel.transform(featurizedData);
- for (Row r : rescaledData.select("features", "label").take(3)) {
+ Dataset<Row> rescaledData = idfModel.transform(featurizedData);
+ for (Row r : rescaledData.select("features", "label").takeRows(3)) {
Vector features = r.getAs(0);
Double label = r.getDouble(1);
System.out.println(features);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
index 617dc3f66e..a2f8c436e3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
@@ -27,7 +27,7 @@ import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
@@ -54,12 +54,12 @@ public class JavaTokenizerExample {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
- DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
- for (Row r : wordsDataFrame.select("words", "label"). take(3)) {
+ Dataset<Row> wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+ for (Row r : wordsDataFrame.select("words", "label").takeRows(3)) {
java.util.List<String> words = r.getList(0);
for (String word : words) System.out.print(word + " ");
System.out.println();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
index d433905fc8..09bbc39c01 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
@@ -23,7 +23,8 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.*;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
/**
@@ -44,12 +45,12 @@ public class JavaTrainValidationSplitExample {
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);
- DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Prepare training and test data.
- DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
- DataFrame training = splits[0];
- DataFrame test = splits[1];
+ Dataset<Row>[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
+ Dataset<Row> training = splits[0];
+ Dataset<Row> test = splits[1];
LinearRegression lr = new LinearRegression();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
index 7e230b5897..953ad455b1 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.*;
@@ -52,13 +52,13 @@ public class JavaVectorAssemblerExample {
});
Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0);
JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
- DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(rdd, schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"hour", "mobile", "userFeatures"})
.setOutputCol("features");
- DataFrame output = assembler.transform(dataset);
+ Dataset<Row> output = assembler.transform(dataset);
System.out.println(output.select("features", "clicked").first());
// $example off$
jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
index 545758e31d..b3b5953ee7 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java
@@ -26,7 +26,8 @@ import java.util.Map;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
// $example off$
public class JavaVectorIndexerExample {
@@ -36,7 +37,7 @@ public class JavaVectorIndexerExample {
SQLContext jsql = new SQLContext(jsc);
// $example on$
- DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
@@ -53,9 +54,9 @@ public class JavaVectorIndexerExample {
System.out.println();
// Create new column "indexed" with categorical values transformed to indices
- DataFrame indexedData = indexerModel.transform(data);
+ Dataset<Row> indexedData = indexerModel.transform(data);
indexedData.show();
// $example off$
jsc.stop();
}
-} \ No newline at end of file
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
index 4d5cb04ff5..2ae57c3577 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
@@ -30,7 +30,7 @@ import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.feature.VectorSlicer;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.*;
@@ -55,7 +55,8 @@ public class JavaVectorSlicerExample {
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
));
- DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+ Dataset<Row> dataset =
+ jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
@@ -63,7 +64,7 @@ public class JavaVectorSlicerExample {
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"})
- DataFrame output = vectorSlicer.transform(dataset);
+ Dataset<Row> output = vectorSlicer.transform(dataset);
System.out.println(output.select("userFeatures", "features").first());
// $example off$
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
index a4a05af7c6..2dce8c2168 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
@@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.Word2Vec;
import org.apache.spark.ml.feature.Word2VecModel;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -49,7 +49,7 @@ public class JavaWord2VecExample {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
+ Dataset<Row> documentDF = sqlContext.createDataFrame(jrdd, schema);
// Learn a mapping from words to Vectors.
Word2Vec word2Vec = new Word2Vec()
@@ -58,8 +58,8 @@ public class JavaWord2VecExample {
.setVectorSize(3)
.setMinCount(0);
Word2VecModel model = word2Vec.fit(documentDF);
- DataFrame result = model.transform(documentDF);
- for (Row r : result.select("result").take(3)) {
+ Dataset<Row> result = model.transform(documentDF);
+ for (Row r : result.select("result").takeRows(3)) {
System.out.println(r);
}
// $example off$
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index afee279ec3..354a5306ed 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -74,11 +74,12 @@ public class JavaSparkSQL {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class);
+ Dataset<Row> schemaPeople = sqlContext.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
- DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ Dataset<Row> teenagers =
+ sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
@@ -99,11 +100,11 @@ public class JavaSparkSQL {
// Read in the parquet file created above.
// Parquet files are self-describing so the schema is preserved.
// The result of loading a parquet file is also a DataFrame.
- DataFrame parquetFile = sqlContext.read().parquet("people.parquet");
+ Dataset<Row> parquetFile = sqlContext.read().parquet("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
- DataFrame teenagers2 =
+ Dataset<Row> teenagers2 =
sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
teenagerNames = teenagers2.toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -120,7 +121,7 @@ public class JavaSparkSQL {
// The path can be either a single text file or a directory storing text files.
String path = "examples/src/main/resources/people.json";
// Create a DataFrame from the file(s) pointed by path
- DataFrame peopleFromJsonFile = sqlContext.read().json(path);
+ Dataset<Row> peopleFromJsonFile = sqlContext.read().json(path);
// Because the schema of a JSON dataset is automatically inferred, to write queries,
// it is better to take a look at what is the schema.
@@ -134,7 +135,8 @@ public class JavaSparkSQL {
peopleFromJsonFile.registerTempTable("people");
// SQL statements can be run by using the sql methods provided by sqlContext.
- DataFrame teenagers3 = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+ Dataset<Row> teenagers3 =
+ sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
// The results of SQL queries are DataFrame and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
@@ -151,7 +153,7 @@ public class JavaSparkSQL {
List<String> jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD<String> anotherPeopleRDD = ctx.parallelize(jsonData);
- DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd());
+ Dataset<Row> peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd());
// Take a look at the schema of this new DataFrame.
peopleFromJsonRDD.printSchema();
@@ -164,7 +166,7 @@ public class JavaSparkSQL {
peopleFromJsonRDD.registerTempTable("people2");
- DataFrame peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2");
+ Dataset<Row> peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2");
List<String> nameAndCity = peopleWithCity.toJavaRDD().map(new Function<Row, String>() {
@Override
public String call(Row row) {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
index f0228f5e63..4b9d9efc85 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
@@ -27,8 +27,9 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.VoidFunction2;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.Time;
@@ -92,13 +93,13 @@ public final class JavaSqlNetworkWordCount {
return record;
}
});
- DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class);
+ Dataset<Row> wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class);
// Register as table
wordsDataFrame.registerTempTable("words");
// Do word count on table using SQL and print it
- DataFrame wordCountsDataFrame =
+ Dataset<Row> wordCountsDataFrame =
sqlContext.sql("select word, count(*) as total from words group by word");
System.out.println("========= " + time + "=========");
wordCountsDataFrame.show();
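
Temp-table registration and `sqlContext.sql(...)` are untouched by the migration; only the declared type of the query result changes. A sketch of that step on its own, with illustrative names:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class SqlOverTempTableSketch {
  // Registers a Dataset<Row> as a temp table and queries it with SQL,
  // following the word-count hunk above.
  static Dataset<Row> wordCounts(SQLContext sqlContext, Dataset<Row> wordsDataFrame) {
    wordsDataFrame.registerTempTable("words");
    return sqlContext.sql("select word, count(*) as total from words group by word");
  }
}
```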
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 0a8c9e5954..60a4a1d2ea 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.ml;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -26,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -37,7 +38,7 @@ public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
@Before
public void setUp() {
@@ -65,7 +66,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 40b9c35adc..0d923dfeff 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,6 +21,8 @@ import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -30,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeClassifier dt = new DecisionTreeClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 59b6fba7a9..f470f4ada6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaGBTClassifierSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaGBTClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
GBTClassifier rf = new GBTClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fd22eb6dca..536f0dc58f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
private double eps = 1e-5;
@@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
- DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+ Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
- DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
for (Row r: predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
@@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
- DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
- for (Row row: trans1.collect()) {
+ Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row: trans1.collectAsList()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
Assert.assertEquals(raw.size(), 2);
@@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
- DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
- for (Row row: trans2.collect()) {
+ Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+ for (Row row: trans2.collectAsList()) {
double pred = row.getDouble(0);
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index ec6b4bf3c0..d499d363f1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -28,7 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
@Test
public void testMLPC() {
- DataFrame dataFrame = sqlContext.createDataFrame(
+ Dataset<Row> dataFrame = sqlContext.createDataFrame(
jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
@@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
.setSeed(11L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
- DataFrame result = model.transform(dataFrame);
- Row[] predictionAndLabels = result.select("prediction", "label").collect();
+ Dataset<Row> result = model.transform(dataFrame);
+ List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
for (Row r: predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 07936eb79b..45101f286c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable {
jsc = null;
}
- public void validatePrediction(DataFrame predictionAndLabels) {
- for (Row r : predictionAndLabels.collect()) {
+ public void validatePrediction(Dataset<Row> predictionAndLabels) {
+ for (Row r : predictionAndLabels.collectAsList()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
assertEquals(label, prediction, 1E-5);
@@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
- DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+ Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label");
validatePrediction(predictionAndLabels);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index cbabafe1b5..d493a7fcec 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
+import org.apache.spark.sql.Row;
import scala.collection.JavaConverters;
import org.junit.After;
@@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
public class JavaOneVsRestSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
@@ -75,7 +76,7 @@ public class JavaOneVsRestSuite implements Serializable {
Assert.assertEquals(ova.getLabelCol() , "label");
Assert.assertEquals(ova.getPredictionCol() , "prediction");
OneVsRestModel ovaModel = ova.fit(dataset);
- DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+ Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
predictions.collectAsList();
Assert.assertEquals(ovaModel.getLabelCol(), "label");
Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 5485fcbf01..9a63cef2a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaRandomForestClassifierSuite implements Serializable {
@@ -58,7 +59,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
RandomForestClassifier rf = new RandomForestClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index cc5a4ef4c2..a3fcdb54ee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
private transient JavaSparkContext sc;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient SQLContext sql;
@Before
@@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector[] centers = model.clusterCenters();
assertEquals(k, centers.length);
- DataFrame transformed = model.transform(dataset);
+ Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
for (String column: expectedColumns) {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index d707bdee99..77e3a489a9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -25,7 +26,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaBucketizerSuite {
StructType schema = new StructType(new StructField[] {
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = jsql.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
@@ -70,7 +71,7 @@ public class JavaBucketizerSuite {
.setOutputCol("result")
.setSplits(splits);
- Row[] result = bucketizer.transform(dataset).select("result").collect();
+ List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
for (Row r : result) {
double index = r.getDouble(0);
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 63e5c93798..ed1ad4c3a3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
import org.junit.After;
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -56,7 +57,7 @@ public class JavaDCTSuite {
@Test
public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D};
- DataFrame dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = jsql.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
@@ -69,8 +70,8 @@ public class JavaDCTSuite {
.setInputCol("vec")
.setOutputCol("resultVec");
- Row[] result = dct.transform(dataset).select("resultVec").collect();
- Vector resultVec = result[0].getAs("resultVec");
+ List<Row> result = dct.transform(dataset).select("resultVec").collectAsList();
+ Vector resultVec = result.get(0).getAs("resultVec");
Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 5932017f8f..6e2cc7e887 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -65,21 +65,21 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceData = jsql.createDataFrame(data, schema);
+ Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
- DataFrame wordsData = tokenizer.transform(sentenceData);
+ Dataset<Row> wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
- DataFrame featurizedData = hashingTF.transform(wordsData);
+ Dataset<Row> featurizedData = hashingTF.transform(wordsData);
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
- DataFrame rescaledData = idfModel.transform(featurizedData);
- for (Row r : rescaledData.select("features", "label").take(3)) {
+ Dataset<Row> rescaledData = idfModel.transform(featurizedData);
+ for (Row r : rescaledData.select("features", "label").takeAsList(3)) {
Vector features = r.getAs(0);
Assert.assertEquals(features.size(), numFeatures);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index e17d549c50..5bbd9634b2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaNormalizerSuite {
@@ -53,17 +54,17 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
));
- DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
+ Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normFeatures");
// Normalize each Vector using $L^2$ norm.
- DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
+ Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
l2NormData.count();
// Normalize each Vector using $L^\infty$ norm.
- DataFrame lInfNormData =
+ Dataset<Row> lInfNormData =
normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
lInfNormData.count();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index e8f329f9cf..1389d17e7e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -100,7 +100,7 @@ public class JavaPCASuite implements Serializable {
}
);
- DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+ Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index e22d117032..6a8bb64801 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -77,11 +77,11 @@ public class JavaPolynomialExpansionSuite {
new StructField("expected", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
- Row[] pairs = polyExpansion.transform(dataset)
+ List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected")
- .collect();
+ .collectAsList();
for (Row r : pairs) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index ed74363f59..3f6fc333e4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaStandardScalerSuite {
@@ -53,7 +54,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
);
- DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
@@ -65,7 +66,7 @@ public class JavaStandardScalerSuite {
StandardScalerModel scalerModel = scaler.fit(dataFrame);
// Normalize each feature to have unit standard deviation.
- DataFrame scaledData = scalerModel.transform(dataFrame);
+ Dataset<Row> scaledData = scalerModel.transform(dataFrame);
scaledData.count();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 139d1d005a..5812037dee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -25,7 +25,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -65,7 +65,7 @@ public class JavaStopWordsRemoverSuite {
StructType schema = new StructType(new StructField[] {
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
remover.transform(dataset).collect();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 153a08a4cd..431779cd2e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -26,7 +26,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -58,16 +58,16 @@ public class JavaStringIndexerSuite {
});
List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
- DataFrame dataset = sqlContext.createDataFrame(data, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex");
- DataFrame output = indexer.fit(dataset).transform(dataset);
+ Dataset<Row> output = indexer.fit(dataset).transform(dataset);
- Assert.assertArrayEquals(
- new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) },
- output.orderBy("id").select("id", "labelIndex").collect());
+ Assert.assertEquals(
+ Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)),
+ output.orderBy("id").select("id", "labelIndex").collectAsList());
}
/** An alias for RowFactory.create. */
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index c407d98f1b..83d16cbd0e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -26,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -61,11 +62,11 @@ public class JavaTokenizerSuite {
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
));
- DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+ Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
- Row[] pairs = myRegExTokenizer.transform(dataset)
+ List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")
- .collect();
+ .collectAsList();
for (Row r : pairs) {
Assert.assertEquals(r.get(0), r.get(1));
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index f8ba84ef77..e45e198043 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -64,11 +64,11 @@ public class JavaVectorAssemblerSuite {
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
- DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"})
.setOutputCol("features");
- DataFrame output = assembler.transform(dataset);
+ Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals(
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0));
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index bfcca62fa1..fec6cac8be 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -30,7 +30,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 4.0))
);
SQLContext sqlContext = new SQLContext(sc);
- DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexed")
@@ -66,6 +67,6 @@ public class JavaVectorIndexerSuite implements Serializable {
Assert.assertEquals(model.numFeatures(), 2);
Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps();
Assert.assertEquals(categoryMaps.size(), 1);
- DataFrame indexedData = model.transform(data);
+ Dataset<Row> indexedData = model.transform(data);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index 786c11c412..b87605ebfd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite {
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
);
- DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+ Dataset<Row> dataset =
+ jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
- DataFrame output = vectorSlicer.transform(dataset);
+ Dataset<Row> output = vectorSlicer.transform(dataset);
- for (Row r : output.select("userFeatures", "features").take(2)) {
+ for (Row r : output.select("userFeatures", "features").takeRows(2)) {
Vector features = r.getAs(1);
Assert.assertEquals(features.size(), 2);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index b292b1b06d..7517b70cc9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -26,7 +26,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- DataFrame documentDF = sqlContext.createDataFrame(
+ Dataset<Row> documentDF = sqlContext.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@@ -66,9 +66,9 @@ public class JavaWord2VecSuite {
.setVectorSize(3)
.setMinCount(0);
Word2VecModel model = word2Vec.fit(documentDF);
- DataFrame result = model.transform(documentDF);
+ Dataset<Row> result = model.transform(documentDF);
- for (Row r: result.select("result").collect()) {
+ for (Row r: result.select("result").collectAsList()) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index d5c9d120c5..a1575300a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 38d15dc2b7..9477e8d2bf 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaGBTRegressorSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaGBTRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
GBTRegressor rf = new GBTRegressor()
.setMaxDepth(2)
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 4fb0b0d109..9f817515eb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -28,7 +28,8 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
.generateLogisticInputAsList;
@@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
@@ -64,7 +65,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assertEquals("features", model.getFeaturesCol());
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 31be8880c2..a90535d11a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaRandomForestRegressorSuite implements Serializable {
@@ -58,7 +59,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
RandomForestRegressor rf = new RandomForestRegressor()
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 2976b38e45..b8ddf907d0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -31,7 +31,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;
@@ -68,7 +68,7 @@ public class JavaLibSVMRelationSuite {
@Test
public void verifyLibSVMDF() {
- DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
.load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 08eeca53f0..24b0097454 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
@Before
public void setUp() {
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 9fda1b1682..6bc2b1e646 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -101,7 +101,7 @@ def _java2py(sc, r, encoding="bytes"):
jrdd = sc._jvm.SerDe.javaToPython(r)
return RDD(jrdd, sc)
- if clsName == 'DataFrame':
+ if clsName == 'Dataset':
return DataFrame(r, SQLContext.getOrCreate(sc))
if clsName in _picklable_classes:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 97f28fad62..d2003fd689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
// TODO: don't swallow original stack trace if it exists
@@ -30,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi
class AnalysisException protected[sql] (
val message: String,
val line: Option[Int] = None,
- val startPosition: Option[Int] = None)
+ val startPosition: Option[Int] = None,
+ val plan: Option[LogicalPlan] = None)
extends Exception with Serializable {
def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
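
The extra `plan` field gives callers access to the partially analyzed plan when analysis fails. A minimal sketch of how it might be inspected, assuming a `SQLContext` named `sqlContext` and hypothetical table/column names that do not resolve:

```scala
import org.apache.spark.sql.AnalysisException

try {
  // Analysis fails because the column does not exist (hypothetical names).
  sqlContext.sql("SELECT missing_column FROM some_table")
} catch {
  case e: AnalysisException =>
    // `plan` is the new Option[LogicalPlan] argument added above; it may be empty.
    e.plan.foreach(p => println(p.treeString))
}
```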
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d8f755a39c..902644e735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -50,7 +50,9 @@ object RowEncoder {
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => inputObject
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+
+ case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -137,6 +139,7 @@ object RowEncoder {
private def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
+ case CalendarIntervalType => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -150,19 +153,23 @@ object RowEncoder {
private def constructorFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
- val field = BoundReference(i, f.dataType, f.nullable)
+ val dt = f.dataType match {
+ case p: PythonUserDefinedType => p.sqlType
+ case other => other
+ }
+ val field = BoundReference(i, dt, f.nullable)
If(
IsNull(field),
- Literal.create(null, externalDataTypeFor(f.dataType)),
+ Literal.create(null, externalDataTypeFor(dt)),
constructorFor(field)
)
}
- CreateExternalRow(fields)
+ CreateExternalRow(fields, schema)
}
private def constructorFor(input: Expression): Expression = input.dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => input
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => input
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -216,7 +223,7 @@ object RowEncoder {
"toScalaMap",
keyData :: valueData :: Nil)
- case StructType(fields) =>
+ case schema @ StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
@@ -225,6 +232,6 @@ object RowEncoder {
}
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
- CreateExternalRow(convertedFields))
+ CreateExternalRow(convertedFields, schema))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 75ecbaa453..b95c5dd892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -388,6 +388,8 @@ case class MapObjects private(
case a: ArrayType => (i: String) => s".getArray($i)"
case _: MapType => (i: String) => s".getMap($i)"
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+ case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
+ case DateType => (i: String) => s".getInt($i)"
}
private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
@@ -485,7 +487,9 @@ case class MapObjects private(
*
* @param children A list of expressions to use as the content of the external row.
*/
-case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
+case class CreateExternalRow(children: Seq[Expression], schema: StructType)
+ extends Expression with NonSQLExpression {
+
override def dataType: DataType = ObjectType(classOf[Row])
override def nullable: Boolean = false
@@ -494,8 +498,9 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
- val rowClass = classOf[GenericRow].getName
+ val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
+ val schemaField = ctx.addReferenceObj("schema", schema)
s"""
boolean ${ev.isNull} = false;
final Object[] $values = new Object[${children.size}];
@@ -510,7 +515,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
}
"""
}.mkString("\n") +
- s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+ s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 24f61992d4..17a91975f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
+import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -26,30 +27,38 @@ import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.function._
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable,
- QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
+ val qe = sqlContext.executePlan(logicalPlan)
+ qe.assertAnalyzed()
+ new Dataset[Row](sqlContext, logicalPlan, RowEncoder(qe.analyzed.schema))
+ }
+}
+
+private[sql] object Dataset {
+ def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = {
+ new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]])
}
}
@@ -112,28 +121,19 @@ private[sql] object DataFrame {
* @since 1.3.0
*/
@Experimental
-class DataFrame private[sql](
+class Dataset[T] private[sql](
@transient override val sqlContext: SQLContext,
- @DeveloperApi @transient override val queryExecution: QueryExecution)
+ @DeveloperApi @transient override val queryExecution: QueryExecution,
+ encoder: Encoder[T])
extends Queryable with Serializable {
+ queryExecution.assertAnalyzed()
+
// Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure
// you wrap it with `withNewExecutionId` if this action doesn't call other actions.
- /**
- * A constructor that automatically analyzes the logical plan.
- *
- * This reports error eagerly as the [[DataFrame]] is constructed, unless
- * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
- */
- def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
- this(sqlContext, {
- val qe = sqlContext.executePlan(logicalPlan)
- if (sqlContext.conf.dataFrameEagerAnalysis) {
- qe.assertAnalyzed() // This should force analysis and throw errors if there are any
- }
- qe
- })
+ def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
+ this(sqlContext, sqlContext.executePlan(logicalPlan), encoder)
}
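
Since the `dataFrameEagerAnalysis` switch is gone, `assertAnalyzed()` runs in the constructor, so resolution errors surface while the new Dataset is being built rather than at the first action. A small sketch under that assumption, with a hypothetical `sqlContext` in scope:

```scala
import org.apache.spark.sql.AnalysisException
import sqlContext.implicits._   // assumes a SQLContext named `sqlContext`

val df = Seq((1, "a"), (2, "b")).toDF("id", "value")

// Referencing a non-existent column now throws during construction of the new
// Dataset (assertAnalyzed in the constructor), not when an action is run.
val failedEagerly =
  try { df.select("bogus"); false } catch { case _: AnalysisException => true }
assert(failedEagerly)
```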
@transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
@@ -147,6 +147,26 @@ class DataFrame private[sql](
queryExecution.analyzed
}
+ /**
+ * An unresolved version of the internal encoder for the type of this [[Dataset]]. It is
+ * marked implicit so that it can be used when constructing new [[Dataset]] objects that have the
+ * same object type (which may be resolved against a different schema).
+ */
+ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder)
+ unresolvedTEncoder.validate(logicalPlan.output)
+
+ /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+ private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
+ unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+ /**
+ * The encoder where the expressions used to construct an object from an input row have been
+ * bound to the ordinals of this [[Dataset]]'s output schema.
+ */
+ private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
+
+ private implicit def classTag = unresolvedTEncoder.clsTag
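
For orientation, the three encoder values above correspond to successive resolution stages. The sketch below mirrors them with the same `private[sql]` calls, purely for illustration; `personEncoder` (an `Encoder[Person]`) and `plan` (an analyzed `LogicalPlan`) are hypothetical:

```scala
import org.apache.spark.sql.catalyst.encoders.{encoderFor, OuterScopes}

// 1. Unresolved: attribute references are not yet tied to a concrete schema.
val unresolved = encoderFor(personEncoder)
// 2. Resolved: expressions are matched against the plan's output attributes.
val resolved = unresolved.resolve(plan.output, OuterScopes.outerScopes)
// 3. Bound: construction expressions refer to output ordinals and can decode rows.
val bound = resolved.bind(plan.output)
```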
+
protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
throw new AnalysisException(
@@ -173,7 +193,11 @@ class DataFrame private[sql](
// For array values, replace Seq and Array with square brackets
// For cells that are beyond 20 characters, replace it with the first 17 and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map {
+ case r: Row => r
+ case tuple: Product => Row.fromTuple(tuple)
+ case o => Row(o)
+ }.map { row =>
row.toSeq.map { cell =>
val str = cell match {
case null => "null"
@@ -196,7 +220,7 @@ class DataFrame private[sql](
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DataFrame = this
+ def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema))
/**
* :: Experimental ::
@@ -206,7 +230,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@Experimental
- def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
+ def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan)
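
Because `toDF()` now re-wraps the same plan with a `RowEncoder` and `as[U]` goes through `Dataset[U](...)`, switching between the untyped and typed views is cheap. A hedged usage sketch; the `Person` case class and the `sqlContext` are assumptions for illustration:

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import sqlContext.implicits._   // assumes a SQLContext named `sqlContext`

case class Person(name: String, age: Long)

val people: Dataset[Person] = Seq(Person("Ann", 30L), Person("Bob", 25L)).toDS()
val rows: DataFrame = people.toDF()                 // untyped view: Dataset[Row]
val typedAgain: Dataset[Person] = rows.as[Person]   // back to the typed view
```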
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
@@ -360,7 +384,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.1
*/
- def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
+ def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF())
/**
* Returns a [[DataFrameStatFunctions]] for working statistic functions support.
@@ -372,7 +396,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)
+ def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
/**
* Cartesian join with another [[DataFrame]].
@@ -573,6 +597,62 @@ class DataFrame private[sql](
}
/**
+ * Joins this [[Dataset]] with another [[Dataset]], returning a [[Tuple2]] for each pair where
+ * `condition` evaluates to true.
+ *
+ * This is similar to the relation `join` function with one important difference in the
+ * result schema. Since `joinWith` preserves objects present on either side of the join, the
+ * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ *
+ * This type of join can be useful both for preserving type-safety with the original object
+ * types and for working with relational data where either side of the join has column
+ * names in common.
+ *
+ * @param other Right side of the join.
+ * @param condition Join expression.
+ * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
+ * @since 1.6.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
+ val left = this.logicalPlan
+ val right = other.logicalPlan
+
+ val joined = sqlContext.executePlan(Join(left, right, joinType =
+ JoinType(joinType), Some(condition.expr)))
+ val leftOutput = joined.analyzed.output.take(left.output.length)
+ val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+
+ val leftData = this.unresolvedTEncoder match {
+ case e if e.flat => Alias(leftOutput.head, "_1")()
+ case _ => Alias(CreateStruct(leftOutput), "_1")()
+ }
+ val rightData = other.unresolvedTEncoder match {
+ case e if e.flat => Alias(rightOutput.head, "_2")()
+ case _ => Alias(CreateStruct(rightOutput), "_2")()
+ }
+
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
+ withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+ Project(
+ leftData :: rightData :: Nil,
+ joined.analyzed)
+ }
+ }
+
+ /**
+ * Joins this [[Dataset]] with another [[Dataset]] using an inner equi-join, returning a
+ * [[Tuple2]] for each pair where `condition` evaluates to true.
+ *
+ * @param other Right side of the join.
+ * @param condition Join expression.
+ * @since 1.6.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ joinWith(other, condition, "inner")
+ }
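
A short usage sketch of the two `joinWith` overloads above; the case classes, data, and `sqlContext` are assumptions for illustration:

```scala
import sqlContext.implicits._   // assumes a SQLContext named `sqlContext`

case class Employee(name: String, deptId: Int)
case class Dept(id: Int, title: String)

val employees = Seq(Employee("Ann", 1), Employee("Bob", 2)).toDS()
val depts     = Seq(Dept(1, "Eng"), Dept(2, "Sales")).toDS()

// Each result element keeps both original objects, nested under `_1` and `_2`.
val joined = employees.joinWith(depts, employees("deptId") === depts("id"), "inner")

// The two-argument overload defaults to an inner join.
val same = employees.joinWith(depts, employees("deptId") === depts("id"))
```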
+
+ /**
* Returns a new [[DataFrame]] with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
@@ -581,7 +661,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = {
+ def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
}
@@ -594,7 +674,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def sortWithinPartitions(sortExprs: Column*): DataFrame = {
+ def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
sortInternal(global = false, sortExprs)
}
@@ -610,7 +690,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def sort(sortCol: String, sortCols: String*): DataFrame = {
+ def sort(sortCol: String, sortCols: String*): Dataset[T] = {
sort((sortCol +: sortCols).map(apply) : _*)
}
@@ -623,7 +703,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def sort(sortExprs: Column*): DataFrame = {
+ def sort(sortExprs: Column*): Dataset[T] = {
sortInternal(global = true, sortExprs)
}
@@ -634,7 +714,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*)
+ def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*)
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
@@ -643,7 +723,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
+ def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*)
/**
* Selects column based on the column name and return it as a [[Column]].
@@ -672,7 +752,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def as(alias: String): DataFrame = withPlan {
+ def as(alias: String): Dataset[T] = withTypedPlan {
SubqueryAlias(alias, logicalPlan)
}
@@ -681,21 +761,21 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def as(alias: Symbol): DataFrame = as(alias.name)
+ def as(alias: Symbol): Dataset[T] = as(alias.name)
/**
* Returns a new [[DataFrame]] with an alias set. Same as `as`.
* @group dfops
* @since 1.6.0
*/
- def alias(alias: String): DataFrame = as(alias)
+ def alias(alias: String): Dataset[T] = as(alias)
/**
* (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`.
* @group dfops
* @since 1.6.0
*/
- def alias(alias: Symbol): DataFrame = as(alias)
+ def alias(alias: Symbol): Dataset[T] = as(alias)
/**
* Selects a set of column based expressions.
@@ -745,6 +825,80 @@ class DataFrame private[sql](
}
/**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
+ *
+ * {{{
+ * val ds = Seq(1, 2, 3).toDS()
+ * val newDS = ds.select(expr("value + 1").as[Int])
+ * }}}
+ * @since 1.6.0
+ */
+ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
+ new Dataset[U1](
+ sqlContext,
+ Project(
+ c1.withInputType(
+ boundTEncoder,
+ logicalPlan.output).named :: Nil,
+ logicalPlan),
+ implicitly[Encoder[U1]])
+ }
+
+ /**
+ * Internal helper function for building typed selects that return tuples. For simplicity and
+ * code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ */
+ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
+ val encoders = columns.map(_.encoder)
+ val namedColumns =
+ columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
+ val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
+
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+ }
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3](
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4](
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3],
+ c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
+
+ /**
+ * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+ * @since 1.6.0
+ */
+ def select[U1, U2, U3, U4, U5](
+ c1: TypedColumn[T, U1],
+ c2: TypedColumn[T, U2],
+ c3: TypedColumn[T, U3],
+ c4: TypedColumn[T, U4],
+ c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
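
A brief sketch of the typed `select` variants defined above, assuming a hypothetical `Person` Dataset, a `sqlContext` in scope, and the standard `functions.expr` helper:

```scala
import org.apache.spark.sql.functions.expr
import sqlContext.implicits._   // assumes a SQLContext named `sqlContext`

case class Person(name: String, age: Long)
val people = Seq(Person("Ann", 30L), Person("Bob", 25L)).toDS()

// Single typed column: Dataset[Person] => Dataset[Long]
val agesPlusOne = people.select(expr("age + 1").as[Long])

// Multiple typed columns come back as a tuple Dataset (built via selectUntyped).
val pairs = people.select(expr("name").as[String], expr("age").as[Long])
// pairs: Dataset[(String, Long)]
```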
+
+ /**
* Filters rows using the given condition.
* {{{
* // The following are equivalent:
@@ -754,7 +908,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def filter(condition: Column): DataFrame = withPlan {
+ def filter(condition: Column): Dataset[T] = withTypedPlan {
Filter(condition.expr, logicalPlan)
}
@@ -766,7 +920,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def filter(conditionExpr: String): DataFrame = {
+ def filter(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
@@ -780,7 +934,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def where(condition: Column): DataFrame = filter(condition)
+ def where(condition: Column): Dataset[T] = filter(condition)
/**
* Filters rows using the given SQL expression.
@@ -790,7 +944,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.5.0
*/
- def where(conditionExpr: String): DataFrame = {
+ def where(conditionExpr: String): Dataset[T] = {
filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
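Usage sketch (illustrative): `filter`/`where` now return `Dataset[T]`, so predicate chains no longer drop the element type.

```scala
import sqlContext.implicits._
import org.apache.spark.sql.Dataset

val nums = Seq(1, 2, 3, 4).toDS()

val small: Dataset[Int]     = nums.filter($"value" < 3)   // Column predicate
val alsoSmall: Dataset[Int] = nums.where("value < 3")     // SQL expression string
```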
@@ -813,7 +967,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def groupBy(cols: Column*): GroupedData = {
- GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+ GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType)
}
/**
@@ -836,7 +990,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def rollup(cols: Column*): GroupedData = {
- GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+ GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType)
}
/**
@@ -858,7 +1012,7 @@ class DataFrame private[sql](
* @since 1.4.0
*/
@scala.annotation.varargs
- def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
+ def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType)
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -883,10 +1037,73 @@ class DataFrame private[sql](
@scala.annotation.varargs
def groupBy(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+ GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+ }
+
+ /**
+ * (Scala-specific)
+ * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: (T, T) => T): T = rdd.reduce(func)
+
+ /**
+ * (Java-specific)
+ * Reduces the elements of this Dataset using the specified binary function. The given `func`
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
+
+ /**
+ * (Scala-specific)
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
+ * @since 1.6.0
+ */
+ def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = {
+ val inputPlan = logicalPlan
+ val withGroupingKey = AppendColumns(func, inputPlan)
+ val executed = sqlContext.executePlan(withGroupingKey)
+
+ new GroupedDataset(
+ encoderFor[K],
+ encoderFor[T],
+ executed,
+ inputPlan.output,
+ withGroupingKey.newColumns)
+ }
+
+ /**
+ * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def groupByKey(cols: Column*): GroupedDataset[Row, T] = {
+ val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
+ val withKey = Project(withKeyColumns, logicalPlan)
+ val executed = sqlContext.executePlan(withKey)
+
+ val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+ val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+ new GroupedDataset(
+ RowEncoder(keyAttributes.toStructType),
+ encoderFor[T],
+ executed,
+ dataAttributes,
+ keyAttributes)
}
/**
+ * (Java-specific)
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
+ * @since 1.6.0
+ */
+ def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ groupByKey(func.call(_))(encoder)
+
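Usage sketch (illustrative): typed grouping is now spelled `groupByKey`, while the column-based `groupBy` above still returns `GroupedData`; `reduce` folds the elements with a binary function. The `count()` call below is `GroupedDataset`'s existing aggregate, not something added by this patch.

```scala
import sqlContext.implicits._
import org.apache.spark.sql.Dataset

val words = Seq("a", "b", "a", "c", "a").toDS()

// Typed grouping by an arbitrary key function.
val counts: Dataset[(String, Long)] = words.groupByKey(w => w).count()

// Commutative, associative reduction over all elements.
val concatenated: String = words.reduce(_ + _)
```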
+ /**
* Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
* so we can run aggregation on them.
* See [[GroupedData]] for all the available aggregate functions.
@@ -910,7 +1127,7 @@ class DataFrame private[sql](
@scala.annotation.varargs
def rollup(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+ GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType)
}
/**
@@ -937,7 +1154,7 @@ class DataFrame private[sql](
@scala.annotation.varargs
def cube(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
+ GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType)
}
/**
@@ -997,7 +1214,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def limit(n: Int): DataFrame = withPlan {
+ def limit(n: Int): Dataset[T] = withTypedPlan {
Limit(Literal(n), logicalPlan)
}
@@ -1007,19 +1224,21 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def unionAll(other: DataFrame): DataFrame = withPlan {
+ def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan {
// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
CombineUnions(Union(logicalPlan, other.logicalPlan))
}
+ def union(other: Dataset[T]): Dataset[T] = unionAll(other)
+
/**
* Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
* This is equivalent to `INTERSECT` in SQL.
* @group dfops
* @since 1.3.0
*/
- def intersect(other: DataFrame): DataFrame = withPlan {
+ def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan {
Intersect(logicalPlan, other.logicalPlan)
}
@@ -1029,10 +1248,12 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def except(other: DataFrame): DataFrame = withPlan {
+ def except(other: Dataset[T]): Dataset[T] = withTypedPlan {
Except(logicalPlan, other.logicalPlan)
}
+ def subtract(other: Dataset[T]): Dataset[T] = except(other)
+
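Usage sketch (illustrative): the set-like operators now take and return `Dataset[T]`, with `union` and `subtract` kept as aliases of `unionAll` and `except`.

```scala
import sqlContext.implicits._
import org.apache.spark.sql.Dataset

val a = Seq(1, 2, 3).toDS()
val b = Seq(3, 4).toDS()

val all: Dataset[Int]    = a.union(b)      // alias of unionAll; keeps duplicates
val common: Dataset[Int] = a.intersect(b)
val aOnly: Dataset[Int]  = a.subtract(b)   // alias of except
```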
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows.
*
@@ -1042,7 +1263,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
}
@@ -1054,7 +1275,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+ def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
sample(withReplacement, fraction, Utils.random.nextLong)
}
@@ -1066,7 +1287,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
@@ -1075,7 +1296,8 @@ class DataFrame private[sql](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
+ new Dataset[T](
+ sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
}.toArray
}
@@ -1086,7 +1308,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+ def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
randomSplit(weights, Utils.random.nextLong)
}
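Usage sketch (illustrative): `randomSplit` now yields `Array[Dataset[T]]`, so the splits keep the element type; the weights are normalized internally as shown in the hunk above.

```scala
import sqlContext.implicits._

val ds = Seq.range(0, 100).toDS()

// Both splits are Dataset[Int].
val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)
val trainSize = train.count()
```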
@@ -1097,7 +1319,7 @@ class DataFrame private[sql](
* @param seed Seed for sampling.
* @group dfops
*/
- private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+ private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = {
randomSplit(weights.toArray, seed)
}
@@ -1238,7 +1460,7 @@ class DataFrame private[sql](
}
select(columns : _*)
} else {
- this
+ toDF()
}
}
@@ -1264,7 +1486,7 @@ class DataFrame private[sql](
val remainingCols =
schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
if (remainingCols.size == this.schema.size) {
- this
+ toDF()
} else {
this.select(remainingCols: _*)
}
@@ -1297,7 +1519,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def dropDuplicates(): DataFrame = dropDuplicates(this.columns)
+ def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns)
/**
* (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only
@@ -1306,7 +1528,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan {
+ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
val groupCols = colNames.map(resolve)
val groupColExprIds = groupCols.map(_.exprId)
val aggCols = logicalPlan.output.map { attr =>
@@ -1326,7 +1548,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.4.0
*/
- def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq)
+ def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
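Usage sketch (illustrative): all three `dropDuplicates` overloads now return `Dataset[T]`.

```scala
import sqlContext.implicits._

val ds = Seq(("a", 1), ("a", 1), ("a", 2)).toDS()   // Dataset[(String, Int)]

val exact  = ds.dropDuplicates()           // removes fully identical elements
val perKey = ds.dropDuplicates(Seq("_1"))  // keeps one element per first field
```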
/**
* Computes statistics for numeric columns, including count, mean, stddev, min, and max.
@@ -1396,7 +1618,7 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df =>
+ def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df =>
df.collect(needCallback = false)
}
@@ -1405,14 +1627,14 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def head(): Row = head(1).head
+ def head(): T = head(1).head
/**
* Returns the first row. Alias for head().
* @group action
* @since 1.3.0
*/
- def first(): Row = head()
+ def first(): T = head()
/**
* Concise syntax for chaining custom transformations.
@@ -1425,27 +1647,113 @@ class DataFrame private[sql](
* }}}
* @since 1.6.0
*/
- def transform[U](t: DataFrame => DataFrame): DataFrame = t(this)
+ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ map(t => func.call(t))(encoder)
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+ * @since 1.6.0
+ */
+ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+ new Dataset[U](
+ sqlContext,
+ MapPartitions[T, U](func, logicalPlan),
+ implicitly[Encoder[U]])
+ }
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+ * @since 1.6.0
+ */
+ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
+ mapPartitions(func)(encoder)
+ }
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+ * and then flattening the results.
+ * @since 1.6.0
+ */
+ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+ mapPartitions(_.flatMap(func))
+
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+ * and then flattening the results.
+ * @since 1.6.0
+ */
+ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (T) => Iterator[U] = x => f.call(x).asScala
+ flatMap(func)(encoder)
+ }
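Usage sketch (illustrative): the typed functional operators (`map`, `flatMap`, `filter`, `mapPartitions`) are now defined directly on the unified class.

```scala
import sqlContext.implicits._
import org.apache.spark.sql.Dataset

val lines = Seq("spark sql", "spark core").toDS()

val words: Dataset[String] = lines.flatMap(_.split(" "))
val upper: Dataset[String] = words.map(_.toUpperCase)
val short: Dataset[String] = upper.filter(_.length <= 5)

// mapPartitions processes a whole partition per call.
val tagged: Dataset[String] = short.mapPartitions(iter => iter.map("word: " + _))
```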
/**
* Applies a function `f` to all rows.
* @group rdd
* @since 1.3.0
*/
- def foreach(f: Row => Unit): Unit = withNewExecutionId {
+ def foreach(f: T => Unit): Unit = withNewExecutionId {
rdd.foreach(f)
}
/**
+ * (Java-specific)
+ * Runs `func` on each element of this [[Dataset]].
+ * @since 1.6.0
+ */
+ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
+
+ /**
* Applies a function f to each partition of this [[DataFrame]].
* @group rdd
* @since 1.3.0
*/
- def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
+ def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId {
rdd.foreachPartition(f)
}
/**
+ * (Java-specific)
+ * Runs `func` on each partition of this [[Dataset]].
+ * @since 1.6.0
+ */
+ def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
+ foreachPartition(it => func.call(it.asJava))
+
+ /**
* Returns the first `n` rows in the [[DataFrame]].
*
* Running take requires moving data into the application's driver process, and doing so with
@@ -1454,7 +1762,11 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def take(n: Int): Array[Row] = head(n)
+ def take(n: Int): Array[T] = head(n)
+
+ def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds =>
+ ds.collectRows(needCallback = false)
+ }
/**
* Returns the first `n` rows in the [[DataFrame]] as a list.
@@ -1465,7 +1777,7 @@ class DataFrame private[sql](
* @group action
* @since 1.6.0
*/
- def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
+ def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*)
/**
* Returns an array that contains all [[Row]]s in this [[DataFrame]].
@@ -1478,7 +1790,9 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collect(): Array[Row] = collect(needCallback = true)
+ def collect(): Array[T] = collect(needCallback = true)
+
+ def collectRows(): Array[Row] = collectRows(needCallback = true)
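Usage sketch (illustrative): `take`/`collect` now return typed arrays, while the new `takeRows`/`collectRows` keep the `Row`-based results.

```scala
import sqlContext.implicits._
import org.apache.spark.sql.Row

val ds = Seq(1, 2, 3).toDS()

val typedTake: Array[Int] = ds.take(2)
val rowTake: Array[Row]   = ds.takeRows(2)
val typedAll: Array[Int]  = ds.collect()
val rowAll: Array[Row]    = ds.collectRows()
```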
/**
* Returns a Java list that contains all [[Row]]s in this [[DataFrame]].
@@ -1489,19 +1803,32 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
withNewExecutionId {
- java.util.Arrays.asList(rdd.collect() : _*)
+ val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+ java.util.Arrays.asList(values : _*)
}
}
- private def collect(needCallback: Boolean): Array[Row] = {
+ private def collect(needCallback: Boolean): Array[T] = {
+ def execute(): Array[T] = withNewExecutionId {
+ queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+ }
+
+ if (needCallback) {
+ withCallback("collect", toDF())(_ => execute())
+ } else {
+ execute()
+ }
+ }
+
+ private def collectRows(needCallback: Boolean): Array[Row] = {
def execute(): Array[Row] = withNewExecutionId {
queryExecution.executedPlan.executeCollectPublic()
}
if (needCallback) {
- withCallback("collect", this)(_ => execute())
+ withCallback("collect", toDF())(_ => execute())
} else {
execute()
}
@@ -1521,7 +1848,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def repartition(numPartitions: Int): DataFrame = withPlan {
+ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
Repartition(numPartitions, shuffle = true, logicalPlan)
}
@@ -1535,7 +1862,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan {
+ def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
}
@@ -1549,7 +1876,7 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
- def repartition(partitionExprs: Column*): DataFrame = withPlan {
+ def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
}
@@ -1561,7 +1888,7 @@ class DataFrame private[sql](
* @group rdd
* @since 1.4.0
*/
- def coalesce(numPartitions: Int): DataFrame = withPlan {
+ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
Repartition(numPartitions, shuffle = false, logicalPlan)
}
@@ -1571,7 +1898,7 @@ class DataFrame private[sql](
* @group dfops
* @since 1.3.0
*/
- def distinct(): DataFrame = dropDuplicates()
+ def distinct(): Dataset[T] = dropDuplicates()
/**
* Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
@@ -1632,12 +1959,11 @@ class DataFrame private[sql](
* @group rdd
* @since 1.3.0
*/
- lazy val rdd: RDD[Row] = {
+ lazy val rdd: RDD[T] = {
// use a local variable to make sure the map closure doesn't capture the whole DataFrame
val schema = this.schema
queryExecution.toRdd.mapPartitions { rows =>
- val converter = CatalystTypeConverters.createToScalaConverter(schema)
- rows.map(converter(_).asInstanceOf[Row])
+ rows.map(boundTEncoder.fromRow)
}
}
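Usage sketch (illustrative): `rdd` now decodes through the Dataset's encoder and yields `RDD[T]` instead of `RDD[Row]`.

```scala
import sqlContext.implicits._
import org.apache.spark.rdd.RDD

val asRdd: RDD[Int] = Seq(1, 2, 3).toDS().rdd
val doubled = asRdd.map(_ * 2).collect()
```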
@@ -1646,14 +1972,14 @@ class DataFrame private[sql](
* @group rdd
* @since 1.3.0
*/
- def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD()
+ def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD()
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
* @group rdd
* @since 1.3.0
*/
- def javaRDD: JavaRDD[Row] = toJavaRDD
+ def javaRDD: JavaRDD[T] = toJavaRDD
/**
* Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this
@@ -1663,7 +1989,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def registerTempTable(tableName: String): Unit = {
- sqlContext.registerDataFrameAsTable(this, tableName)
+ sqlContext.registerDataFrameAsTable(toDF(), tableName)
}
/**
@@ -1674,7 +2000,7 @@ class DataFrame private[sql](
* @since 1.4.0
*/
@Experimental
- def write: DataFrameWriter = new DataFrameWriter(this)
+ def write: DataFrameWriter = new DataFrameWriter(toDF())
/**
* Returns the content of the [[DataFrame]] as a RDD of JSON strings.
@@ -1745,7 +2071,7 @@ class DataFrame private[sql](
* Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with
* an execution.
*/
- private[sql] def withNewExecutionId[T](body: => T): T = {
+ private[sql] def withNewExecutionId[U](body: => U): U = {
SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
}
@@ -1753,7 +2079,7 @@ class DataFrame private[sql](
* Wrap a DataFrame action to track the QueryExecution and time cost, then report to the
* user-registered callback functions.
*/
- private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
+ private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = {
try {
df.queryExecution.executedPlan.foreach { plan =>
plan.resetMetrics()
@@ -1770,7 +2096,24 @@ class DataFrame private[sql](
}
}
- private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
+ private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = {
+ try {
+ ds.queryExecution.executedPlan.foreach { plan =>
+ plan.resetMetrics()
+ }
+ val start = System.nanoTime()
+ val result = action(ds)
+ val end = System.nanoTime()
+ sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start)
+ result
+ } catch {
+ case e: Exception =>
+ sqlContext.listenerManager.onFailure(name, ds.queryExecution, e)
+ throw e
+ }
+ }
+
+ private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
@@ -1779,14 +2122,23 @@ class DataFrame private[sql](
SortOrder(expr, Ascending)
}
}
- withPlan {
+ withTypedPlan {
Sort(sortOrder, global = global, logicalPlan)
}
}
/** A convenient function to wrap a logical plan and produce a DataFrame. */
@inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
+ DataFrame(sqlContext, logicalPlan)
+ }
+
+ /** A convenient function to wrap a logical plan and produce a DataFrame. */
+ @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
+ new Dataset[T](sqlContext, logicalPlan, encoder)
}
+ private[sql] def withTypedPlan[R](
+ other: Dataset[_], encoder: Encoder[R])(
+ f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
+ new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 509b29956f..822702429d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -345,7 +345,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
InferSchema.infer(jsonRDD, sqlContext.conf.columnNameOfCorruptRecord, parsedOptions)
}
- new DataFrame(
+ DataFrame(
sqlContext,
LogicalRDD(
schema.toAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
deleted file mode 100644
index daddf6e0c5..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ /dev/null
@@ -1,794 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.function._
-import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{Queryable, QueryExecution}
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-/**
- * :: Experimental ::
- * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel
- * using functional or relational operations.
- *
- * A [[Dataset]] differs from an [[RDD]] in the following ways:
- * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored
- * in the encoded form. This representation allows for additional logical operations and
- * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to
- * an object.
- * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be
- * used to serialize the object into a binary format. Encoders are also capable of mapping the
- * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime
- * reflection based serialization. Operations that change the type of object stored in the
- * dataset also need an encoder for the new type.
- *
- * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific
- * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into
- * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed
- * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`.
- *
- * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However,
- * making this change to the class hierarchy would break the function signatures for the existing
- * functional operations (map, flatMap, etc). As such, this class should be considered a preview
- * of the final API. Changes will be made to the interface after Spark 1.6.
- *
- * @since 1.6.0
- */
-@Experimental
-class Dataset[T] private[sql](
- @transient override val sqlContext: SQLContext,
- @transient override val queryExecution: QueryExecution,
- tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
-
- /**
- * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
- * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
- * same object type (that will be possibly resolved to a different schema).
- */
- private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
- unresolvedTEncoder.validate(logicalPlan.output)
-
- /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
- private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
-
- /**
- * The encoder where the expressions used to construct an object from an input row have been
- * bound to the ordinals of this [[Dataset]]'s output schema.
- */
- private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
-
- private implicit def classTag = unresolvedTEncoder.clsTag
-
- private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
- this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
-
- /**
- * Returns the schema of the encoded form of the objects in this [[Dataset]].
- * @since 1.6.0
- */
- override def schema: StructType = resolvedTEncoder.schema
-
- /**
- * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
- * @since 1.6.0
- */
- override def printSchema(): Unit = toDF().printSchema()
-
- /**
- * Prints the plans (logical and physical) to the console for debugging purposes.
- * @since 1.6.0
- */
- override def explain(extended: Boolean): Unit = toDF().explain(extended)
-
- /**
- * Prints the physical plan to the console for debugging purposes.
- * @since 1.6.0
- */
- override def explain(): Unit = toDF().explain()
-
- /* ************* *
- * Conversions *
- * ************* */
-
- /**
- * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
- * method used to map columns depend on the type of `U`:
- * - When `U` is a class, fields for the class will be mapped to columns of the same name
- * (case sensitivity is determined by `spark.sql.caseSensitive`)
- * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will
- * be assigned to `_1`).
- * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the
- * [[DataFrame]] will be used.
- *
- * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
- * along with `alias` or `as` to rearrange or rename as required.
- * @since 1.6.0
- */
- def as[U : Encoder]: Dataset[U] = {
- new Dataset(sqlContext, queryExecution, encoderFor[U])
- }
-
- /**
- * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
- * the same name after two Datasets have been joined.
- * @since 1.6.0
- */
- def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))
-
- /**
- * Converts this strongly typed collection of data to generic Dataframe. In contrast to the
- * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]]
- * objects that allow fields to be accessed by ordinal or name.
- */
- // This is declared with parentheses to prevent the Scala compiler from treating
- // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
-
- /**
- * Returns this [[Dataset]].
- * @since 1.6.0
- */
- // This is declared with parentheses to prevent the Scala compiler from treating
- // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset.
- def toDS(): Dataset[T] = this
-
- /**
- * Converts this [[Dataset]] to an [[RDD]].
- * @since 1.6.0
- */
- def rdd: RDD[T] = {
- queryExecution.toRdd.mapPartitions { iter =>
- iter.map(boundTEncoder.fromRow)
- }
- }
-
- /**
- * Returns the number of elements in the [[Dataset]].
- * @since 1.6.0
- */
- def count(): Long = toDF().count()
-
- /**
- * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters
- * will be truncated, and all cells will be aligned right. For example:
- * {{{
- * year month AVG('Adj Close) MAX('Adj Close)
- * 1980 12 0.503218 0.595103
- * 1981 01 0.523289 0.570307
- * 1982 02 0.436504 0.475256
- * 1983 03 0.410516 0.442194
- * 1984 04 0.450090 0.483521
- * }}}
- * @param numRows Number of rows to show
- *
- * @since 1.6.0
- */
- def show(numRows: Int): Unit = show(numRows, truncate = true)
-
- /**
- * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
- * will be truncated, and all cells will be aligned right.
- *
- * @since 1.6.0
- */
- def show(): Unit = show(20)
-
- /**
- * Displays the top 20 rows of [[Dataset]] in a tabular form.
- *
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
- *
- * @since 1.6.0
- */
- def show(truncate: Boolean): Unit = show(20, truncate)
-
- /**
- * Displays the [[Dataset]] in a tabular form. For example:
- * {{{
- * year month AVG('Adj Close) MAX('Adj Close)
- * 1980 12 0.503218 0.595103
- * 1981 01 0.523289 0.570307
- * 1982 02 0.436504 0.475256
- * 1983 03 0.410516 0.442194
- * 1984 04 0.450090 0.483521
- * }}}
- * @param numRows Number of rows to show
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
- *
- * @since 1.6.0
- */
- // scalastyle:off println
- def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate))
- // scalastyle:on println
-
- /**
- * Compose the string representing rows for output
- * @param _numRows Number of rows to show
- * @param truncate Whether truncate long strings and align cells right
- */
- override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
- val numRows = _numRows.max(0)
- val takeResult = take(numRows + 1)
- val hasMoreData = takeResult.length > numRows
- val data = takeResult.take(numRows)
-
- // For array values, replace Seq and Array with square brackets
- // For cells that are beyond 20 characters, replace it with the first 17 and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map {
- case r: Row => r
- case tuple: Product => Row.fromTuple(tuple)
- case o => Row(o)
- } map { row =>
- row.toSeq.map { cell =>
- val str = cell match {
- case null => "null"
- case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
- case array: Array[_] => array.mkString("[", ", ", "]")
- case seq: Seq[_] => seq.mkString("[", ", ", "]")
- case _ => cell.toString
- }
- if (truncate && str.length > 20) str.substring(0, 17) + "..." else str
- }: Seq[String]
- })
-
- formatString ( rows, numRows, hasMoreData, truncate )
- }
-
- /**
- * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
- * @since 1.6.0
- */
- def repartition(numPartitions: Int): Dataset[T] = withPlan {
- Repartition(numPartitions, shuffle = true, _)
- }
-
- /**
- * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
- * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
- * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
- * the 100 new partitions will claim 10 of the current partitions.
- * @since 1.6.0
- */
- def coalesce(numPartitions: Int): Dataset[T] = withPlan {
- Repartition(numPartitions, shuffle = false, _)
- }
-
- /* *********************** *
- * Functional Operations *
- * *********************** */
-
- /**
- * Concise syntax for chaining custom transformations.
- * {{{
- * def featurize(ds: Dataset[T]) = ...
- *
- * dataset
- * .transform(featurize)
- * .transform(...)
- * }}}
- * @since 1.6.0
- */
- def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
- * @since 1.6.0
- */
- def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
- * @since 1.6.0
- */
- def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
- * @since 1.6.0
- */
- def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
- * @since 1.6.0
- */
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
- map(t => func.call(t))(encoder)
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
- * @since 1.6.0
- */
- def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
- new Dataset[U](
- sqlContext,
- MapPartitions[T, U](func, logicalPlan))
- }
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
- * @since 1.6.0
- */
- def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
- mapPartitions(func)(encoder)
- }
-
- /**
- * (Scala-specific)
- * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
- * and then flattening the results.
- * @since 1.6.0
- */
- def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
- mapPartitions(_.flatMap(func))
-
- /**
- * (Java-specific)
- * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
- * and then flattening the results.
- * @since 1.6.0
- */
- def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- val func: (T) => Iterator[U] = x => f.call(x).asScala
- flatMap(func)(encoder)
- }
-
- /* ************** *
- * Side effects *
- * ************** */
-
- /**
- * (Scala-specific)
- * Runs `func` on each element of this [[Dataset]].
- * @since 1.6.0
- */
- def foreach(func: T => Unit): Unit = rdd.foreach(func)
-
- /**
- * (Java-specific)
- * Runs `func` on each element of this [[Dataset]].
- * @since 1.6.0
- */
- def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
-
- /**
- * (Scala-specific)
- * Runs `func` on each partition of this [[Dataset]].
- * @since 1.6.0
- */
- def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
-
- /**
- * (Java-specific)
- * Runs `func` on each partition of this [[Dataset]].
- * @since 1.6.0
- */
- def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
- foreachPartition(it => func.call(it.asJava))
-
- /* ************* *
- * Aggregation *
- * ************* */
-
- /**
- * (Scala-specific)
- * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
- * must be commutative and associative or the result may be non-deterministic.
- * @since 1.6.0
- */
- def reduce(func: (T, T) => T): T = rdd.reduce(func)
-
- /**
- * (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given `func`
- * must be commutative and associative or the result may be non-deterministic.
- * @since 1.6.0
- */
- def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
-
- /**
- * (Scala-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
- * @since 1.6.0
- */
- def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
- val inputPlan = logicalPlan
- val withGroupingKey = AppendColumns(func, inputPlan)
- val executed = sqlContext.executePlan(withGroupingKey)
-
- new GroupedDataset(
- encoderFor[K],
- encoderFor[T],
- executed,
- inputPlan.output,
- withGroupingKey.newColumns)
- }
-
- /**
- * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
- * @since 1.6.0
- */
- @scala.annotation.varargs
- def groupBy(cols: Column*): GroupedDataset[Row, T] = {
- val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
- val withKey = Project(withKeyColumns, logicalPlan)
- val executed = sqlContext.executePlan(withKey)
-
- val dataAttributes = executed.analyzed.output.dropRight(cols.size)
- val keyAttributes = executed.analyzed.output.takeRight(cols.size)
-
- new GroupedDataset(
- RowEncoder(keyAttributes.toStructType),
- encoderFor[T],
- executed,
- dataAttributes,
- keyAttributes)
- }
-
- /**
- * (Java-specific)
- * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
- * @since 1.6.0
- */
- def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
- groupBy(func.call(_))(encoder)
-
- /* ****************** *
- * Typed Relational *
- * ****************** */
-
- /**
- * Returns a new [[DataFrame]] by selecting a set of column based expressions.
- * {{{
- * df.select($"colA", $"colB" + 1)
- * }}}
- * @since 1.6.0
- */
- // Copied from Dataframe to make sure we don't have invalid overloads.
- @scala.annotation.varargs
- protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
- *
- * {{{
- * val ds = Seq(1, 2, 3).toDS()
- * val newDS = ds.select(expr("value + 1").as[Int])
- * }}}
- * @since 1.6.0
- */
- def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- new Dataset[U1](
- sqlContext,
- Project(
- c1.withInputType(
- boundTEncoder,
- logicalPlan.output).named :: Nil,
- logicalPlan))
- }
-
- /**
- * Internal helper function for building typed selects that return tuples. For simplicity and
- * code reuse, we do this without the help of the type system and then use helper functions
- * that cast appropriately for the user facing interface.
- */
- protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoders = columns.map(_.encoder)
- val namedColumns =
- columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
- val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
-
- new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
- }
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
- selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2, U3](
- c1: TypedColumn[T, U1],
- c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
- selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2, U3, U4](
- c1: TypedColumn[T, U1],
- c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3],
- c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
- selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
-
- /**
- * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
- * @since 1.6.0
- */
- def select[U1, U2, U3, U4, U5](
- c1: TypedColumn[T, U1],
- c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3],
- c4: TypedColumn[T, U4],
- c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
- selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
-
- /**
- * Returns a new [[Dataset]] by sampling a fraction of records.
- * @since 1.6.0
- */
- def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
- withPlan(Sample(0.0, fraction, withReplacement, seed, _)())
-
- /**
- * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.
- * @since 1.6.0
- */
- def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = {
- sample(withReplacement, fraction, Utils.random.nextLong)
- }
-
- /* **************** *
- * Set operations *
- * **************** */
-
- /**
- * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]].
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- * @since 1.6.0
- */
- def distinct: Dataset[T] = withPlan(Distinct)
-
- /**
- * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also
- * present in `other`.
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- * @since 1.6.0
- */
- def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
-
- /**
- * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
- * combined.
- *
- * Note that, this function is not a typical set union operation, in that it does not eliminate
- * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
- * @since 1.6.0
- */
- def union(other: Dataset[T]): Dataset[T] = withPlan[T](other) { (left, right) =>
- // This breaks caching, but it's usually ok because it addresses a very specific use case:
- // using union to union many files or partitions.
- CombineUnions(Union(left, right))
- }
-
- /**
- * Returns a new [[Dataset]] where any elements present in `other` have been removed.
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- * @since 1.6.0
- */
- def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
-
- /* ****** *
- * Joins *
- * ****** */
-
- /**
- * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
- * true.
- *
- * This is similar to the relation `join` function with one important difference in the
- * result schema. Since `joinWith` preserves objects present on either side of the join, the
- * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
- *
- * This type of join can be useful both for preserving type-safety with the original object
- * types as well as working with relational data where either side of the join has column
- * names in common.
- *
- * @param other Right side of the join.
- * @param condition Join expression.
- * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
- * @since 1.6.0
- */
- def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
- val left = this.logicalPlan
- val right = other.logicalPlan
-
- val joined = sqlContext.executePlan(Join(left, right, joinType =
- JoinType(joinType), Some(condition.expr)))
- val leftOutput = joined.analyzed.output.take(left.output.length)
- val rightOutput = joined.analyzed.output.takeRight(right.output.length)
-
- val leftData = this.unresolvedTEncoder match {
- case e if e.flat => Alias(leftOutput.head, "_1")()
- case _ => Alias(CreateStruct(leftOutput), "_1")()
- }
- val rightData = other.unresolvedTEncoder match {
- case e if e.flat => Alias(rightOutput.head, "_2")()
- case _ => Alias(CreateStruct(rightOutput), "_2")()
- }
-
- implicit val tuple2Encoder: Encoder[(T, U)] =
- ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withPlan[(T, U)](other) { (left, right) =>
- Project(
- leftData :: rightData :: Nil,
- joined.analyzed)
- }
- }
-
- /**
- * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair
- * where `condition` evaluates to true.
- *
- * @param other Right side of the join.
- * @param condition Join expression.
- * @since 1.6.0
- */
- def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
- joinWith(other, condition, "inner")
- }
-
- /* ************************** *
- * Gather to Driver Actions *
- * ************************** */
-
- /**
- * Returns the first element in this [[Dataset]].
- * @since 1.6.0
- */
- def first(): T = take(1).head
-
- /**
- * Returns an array that contains all the elements in this [[Dataset]].
- *
- * Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
- *
- * For Java API, use [[collectAsList]].
- * @since 1.6.0
- */
- def collect(): Array[T] = {
- // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
- // to convert the rows into objects of type T.
- queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
- }
-
- /**
- * Returns an array that contains all the elements in this [[Dataset]].
- *
- * Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
- *
- * For Java API, use [[collectAsList]].
- * @since 1.6.0
- */
- def collectAsList(): java.util.List[T] = collect().toSeq.asJava
-
- /**
- * Returns the first `num` elements of this [[Dataset]] as an array.
- *
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `num` can crash the driver process with OutOfMemoryError.
- * @since 1.6.0
- */
- def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
-
- /**
- * Returns the first `num` elements of this [[Dataset]] as an array.
- *
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `num` can crash the driver process with OutOfMemoryError.
- * @since 1.6.0
- */
- def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
-
- /**
- * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
- * @since 1.6.0
- */
- def persist(): this.type = {
- sqlContext.cacheManager.cacheQuery(this)
- this
- }
-
- /**
- * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
- * @since 1.6.0
- */
- def cache(): this.type = persist()
-
- /**
- * Persist this [[Dataset]] with the given storage level.
- * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
- * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
- * `MEMORY_AND_DISK_2`, etc.
- * @group basic
- * @since 1.6.0
- */
- def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheManager.cacheQuery(this, None, newLevel)
- this
- }
-
- /**
- * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
- * @param blocking Whether to block until all blocks are deleted.
- * @since 1.6.0
- */
- def unpersist(blocking: Boolean): this.type = {
- sqlContext.cacheManager.tryUncacheQuery(this, blocking)
- this
- }
-
- /**
- * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
- * @since 1.6.0
- */
- def unpersist(): this.type = unpersist(blocking = false)
-
- /* ******************** *
- * Internal Functions *
- * ******************** */
-
- private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
-
- private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
-
- private[sql] def withPlan[R : Encoder](
- other: Dataset[_])(
- f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index a7258d742a..2a0f77349a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.NumericType
/**
* :: Experimental ::
- * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
+ * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]].
*
* The main method is the agg function, which has multiple variants. This class also contains
* convenience methods for some first-order statistics such as mean and sum.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index cd8ed472ec..1639cc8db6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -64,7 +64,7 @@ class GroupedDataset[K, V] private[sql](
private def groupedData =
new GroupedData(
- new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+ DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
/**
* Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
@@ -86,7 +86,7 @@ class GroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def keys: Dataset[K] = {
- new Dataset[K](
+ Dataset[K](
sqlContext,
Distinct(
Project(groupingAttributes, logicalPlan)))
@@ -111,7 +111,7 @@ class GroupedDataset[K, V] private[sql](
* @since 1.6.0
*/
def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
- new Dataset[U](
+ Dataset[U](
sqlContext,
MapGroups(
f,
@@ -308,7 +308,7 @@ class GroupedDataset[K, V] private[sql](
other: GroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
implicit val uEncoder = other.unresolvedVEncoder
- new Dataset[R](
+ Dataset[R](
sqlContext,
CoGroup(
f,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c742bf2f89..54dbd6bda5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -464,7 +464,7 @@ class SQLContext private[sql](
val encoded = data.map(d => enc.toRow(d).copy())
val plan = new LocalRelation(attributes, encoded)
- new Dataset[T](this, plan)
+ Dataset[T](this, plan)
}
def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
@@ -473,7 +473,7 @@ class SQLContext private[sql](
val encoded = data.map(d => enc.toRow(d))
val plan = LogicalRDD(attributes, encoded)(self)
- new Dataset[T](this, plan)
+ Dataset[T](this, plan)
}
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 16c4095db7..e23d5e1261 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -126,6 +126,7 @@ abstract class SQLImplicits {
/**
* Creates a [[Dataset]] from an RDD.
+ *
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 8616fe3170..19ab3ea132 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
@@ -31,7 +31,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
*/
class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
- def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed)
+ def assertAnalyzed(): Unit = try sqlContext.analyzer.checkAnalysis(analyzed) catch {
+ case e: AnalysisException =>
+ throw new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
+ }
lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
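Sketch of test-side usage (illustrative, with placeholder table and column names): the rethrown `AnalysisException` carries the partially analyzed plan; this assumes the extra constructor argument is exposed as a field named `plan`.

```scala
import org.apache.spark.sql.AnalysisException

try {
  // Placeholder query that fails analysis (unresolved column/table).
  sqlContext.sql("SELECT missing_column FROM some_table").collect()
} catch {
  case e: AnalysisException =>
    // Assumed accessor for the partially analyzed logical plan.
    e.plan.foreach(p => println(p.treeString))
}
```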
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index e048ee1441..60ec67c8f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -154,7 +154,7 @@ case class DataSource(
}
def dataFrameBuilder(files: Array[String]): DataFrame = {
- new DataFrame(
+ DataFrame(
sqlContext,
LogicalRelation(
DataSource(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index a191759813..0dc34814fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging {
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
- new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
+ DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 26e4eda542..daa065e5cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging {
}
val schema = StructType(StructField(tableName, StringType) +: headerNames)
- new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+ DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index bc7c520930..7d7c51b158 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -211,7 +211,7 @@ class StreamExecution(
// Construct the batch and send it to the sink.
val batchOffset = streamProgress.toCompositeOffset(sources)
- val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan))
+ val nextBatch = new Batch(batchOffset, DataFrame(sqlContext, newPlan))
sink.addBatch(nextBatch)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 8124df15af..3b764c5558 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -55,11 +55,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def schema: StructType = encoder.schema
def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
- new Dataset(sqlContext, logicalPlan)
+ Dataset(sqlContext, logicalPlan)
}
def toDF()(implicit sqlContext: SQLContext): DataFrame = {
- new DataFrame(sqlContext, logicalPlan)
+ DataFrame(sqlContext, logicalPlan)
}
def addData(data: A*): Offset = {
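
Across these files, call sites switch from `new DataFrame(...)` / `new Dataset(...)` constructors to companion-object `apply` factories; the factories themselves live elsewhere in the patch and are not shown here. A toy sketch of the general pattern, with stand-in types rather than Spark's, assuming the factory exists to supply whatever extra constructor argument the underlying class now needs:

```scala
// Toy stand-ins (not Spark types): Box plays the role of Dataset[T], the codec
// plays the role of the extra piece the real constructor now requires.
object FactoryAliasSketch {
  class Box[T](val value: T, val codec: T => String)

  type StringBox = Box[String]

  object StringBox {
    // Factory in the role of DataFrame(sqlContext, plan): it fills in the extra
    // argument so call sites stay a one-liner without `new`.
    def apply(value: String): StringBox = new Box[String](value, s => s)
  }

  def demo(): StringBox = StringBox("hello")
}
```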
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 6eea924517..844f3051fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -46,7 +46,6 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
* @tparam I The input type for the aggregation.
* @tparam B The type of the intermediate value of the reduction.
* @tparam O The type of the final output result.
- *
* @since 1.6.0
*/
abstract class Aggregator[-I, B, O] extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index bd73a36fd4..97e35bb104 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -42,4 +42,5 @@ package object sql {
@DeveloperApi
type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
+ type DataFrame = Dataset[Row]
}
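
A minimal illustration of what the alias added above buys on the Scala side: `DataFrame` and `Dataset[Row]` are now the same type, so existing signatures keep compiling with no conversion. (`"testData"` below stands for any registered temp table, as in the suites later in this diff.)

```scala
import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}

// Sketch: a function written against DataFrame accepts a Dataset[Row] directly.
object AliasSketch {
  def describe(df: DataFrame): Unit = df.printSchema()

  def use(sqlContext: SQLContext): Unit = {
    val asDataset: Dataset[Row] = sqlContext.table("testData") // declared as DataFrame
    describe(asDataset)                                        // same type, no conversion
  }
}
```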
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 51f987fda9..42af813bc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- Row[] actual = sqlContext.sql("SELECT * FROM people").collect();
+ Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();
List<Row> expected = new ArrayList<>(2);
expected.add(RowFactory.create("Michael", 29));
@@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+ Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
@Override
@@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- DataFrame df1 = sqlContext.read().json(jsonRDD);
+ Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
df1.registerTempTable("jsonTable1");
List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+ Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
df2.registerTempTable("jsonTable2");
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ee85626435..47cc74dbc1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -64,13 +64,13 @@ public class JavaDataFrameSuite {
@Test
public void testExecution() {
- DataFrame df = context.table("testData").filter("key = 1");
- Assert.assertEquals(1, df.select("key").collect()[0].get(0));
+ Dataset<Row> df = context.table("testData").filter("key = 1");
+ Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
}
@Test
public void testCollectAndTake() {
- DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
Assert.assertEquals(3, df.select("key").collectAsList().size());
Assert.assertEquals(2, df.select("key").takeAsList(2).size());
}
@@ -80,7 +80,7 @@ public class JavaDataFrameSuite {
*/
@Test
public void testVarargMethods() {
- DataFrame df = context.table("testData");
+ Dataset<Row> df = context.table("testData");
df.toDF("key1", "value1");
@@ -109,7 +109,7 @@ public class JavaDataFrameSuite {
df.select(coalesce(col("key")));
// Varargs with mathfunctions
- DataFrame df2 = context.table("testData2");
+ Dataset<Row> df2 = context.table("testData2");
df2.select(exp("a"), exp("b"));
df2.select(exp(log("a")));
df2.select(pow("a", "a"), pow("b", 2.0));
@@ -123,7 +123,7 @@ public class JavaDataFrameSuite {
@Ignore
public void testShow() {
// This test case is intended ignored, but to make sure it compiles correctly
- DataFrame df = context.table("testData");
+ Dataset<Row> df = context.table("testData");
df.show();
df.show(1000);
}
@@ -151,7 +151,7 @@ public class JavaDataFrameSuite {
}
}
- void validateDataFrameWithBeans(Bean bean, DataFrame df) {
+ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
StructType schema = df.schema();
Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
schema.apply("a"));
@@ -191,7 +191,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromLocalJavaBeans() {
Bean bean = new Bean();
List<Bean> data = Arrays.asList(bean);
- DataFrame df = context.createDataFrame(data, Bean.class);
+ Dataset<Row> df = context.createDataFrame(data, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -199,7 +199,7 @@ public class JavaDataFrameSuite {
public void testCreateDataFrameFromJavaBeans() {
Bean bean = new Bean();
JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
- DataFrame df = context.createDataFrame(rdd, Bean.class);
+ Dataset<Row> df = context.createDataFrame(rdd, Bean.class);
validateDataFrameWithBeans(bean, df);
}
@@ -207,8 +207,8 @@ public class JavaDataFrameSuite {
public void testCreateDataFromFromList() {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));
- DataFrame df = context.createDataFrame(rows, schema);
- Row[] result = df.collect();
+ Dataset<Row> df = context.createDataFrame(rows, schema);
+ Row[] result = df.collectRows();
Assert.assertEquals(1, result.length);
}
@@ -235,13 +235,13 @@ public class JavaDataFrameSuite {
@Test
public void testCrosstab() {
- DataFrame df = context.table("testData2");
- DataFrame crosstab = df.stat().crosstab("a", "b");
+ Dataset<Row> df = context.table("testData2");
+ Dataset<Row> crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
Assert.assertEquals("2", columnNames[1]);
Assert.assertEquals("1", columnNames[2]);
- Row[] rows = crosstab.collect();
+ Row[] rows = crosstab.collectRows();
Arrays.sort(rows, crosstabRowComparator);
Integer count = 1;
for (Row row : rows) {
@@ -254,31 +254,31 @@ public class JavaDataFrameSuite {
@Test
public void testFrequentItems() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
String[] cols = {"a"};
- DataFrame results = df.stat().freqItems(cols, 0.2);
- Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+ Dataset<Row> results = df.stat().freqItems(cols, 0.2);
+ Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
}
@Test
public void testCorrelation() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
Double pearsonCorr = df.stat().corr("a", "b", "pearson");
Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6);
}
@Test
public void testCovariance() {
- DataFrame df = context.table("testData2");
+ Dataset<Row> df = context.table("testData2");
Double result = df.stat().cov("a", "b");
Assert.assertTrue(Math.abs(result) < 1.0e-6);
}
@Test
public void testSampleBy() {
- DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
- DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
- Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+ Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
+ Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+ Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
Assert.assertEquals(0, actual[0].getLong(0));
Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
Assert.assertEquals(1, actual[1].getLong(0));
@@ -287,10 +287,10 @@ public class JavaDataFrameSuite {
@Test
public void pivot() {
- DataFrame df = context.table("courseSales");
+ Dataset<Row> df = context.table("courseSales");
Row[] actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
- .agg(sum("earnings")).orderBy("year").collect();
+ .agg(sum("earnings")).orderBy("year").collectRows();
Assert.assertEquals(2012, actual[0].getInt(0));
Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
@@ -303,11 +303,11 @@ public class JavaDataFrameSuite {
@Test
public void testGenericLoad() {
- DataFrame df1 = context.read().format("text").load(
+ Dataset<Row> df1 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());
- DataFrame df2 = context.read().format("text").load(
+ Dataset<Row> df2 = context.read().format("text").load(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
@@ -315,11 +315,11 @@ public class JavaDataFrameSuite {
@Test
public void testTextLoad() {
- DataFrame df1 = context.read().text(
+ Dataset<Row> df1 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString());
Assert.assertEquals(4L, df1.count());
- DataFrame df2 = context.read().text(
+ Dataset<Row> df2 = context.read().text(
Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString(),
Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
Assert.assertEquals(5L, df2.count());
@@ -327,7 +327,7 @@ public class JavaDataFrameSuite {
@Test
public void testCountMinSketch() {
- DataFrame df = context.range(1000);
+ Dataset<Row> df = context.range(1000);
CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
Assert.assertEquals(sketch1.totalCount(), 1000);
@@ -352,7 +352,7 @@ public class JavaDataFrameSuite {
@Test
public void testBloomFilter() {
- DataFrame df = context.range(1000);
+ Dataset<Row> df = context.range(1000);
BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index b054b1095b..79b6e61767 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -169,7 +169,7 @@ public class JavaDatasetSuite implements Serializable {
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
+ GroupedDataset<Integer, String> grouped = ds.groupByKey(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -217,7 +217,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
- GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;
@@ -250,7 +250,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped =
- ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
+ ds.groupByKey(length(col("value"))).keyAs(Encoders.INT());
Dataset<String> mapped = grouped.mapGroups(
new MapGroupsFunction<Integer, String, String>() {
@@ -410,7 +410,7 @@ public class JavaDatasetSuite implements Serializable {
Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
- GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+ GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
new MapFunction<Tuple2<String, Integer>, String>() {
@Override
public String call(Tuple2<String, Integer> value) throws Exception {
@@ -828,7 +828,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
SmallBean smallBean = new SmallBean();
@@ -845,7 +845,7 @@ public class JavaDatasetSuite implements Serializable {
{
Row row = new GenericRow(new Object[] { null });
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
NestedSmallBean nestedSmallBean = new NestedSmallBean();
@@ -862,7 +862,7 @@ public class JavaDatasetSuite implements Serializable {
})
});
- DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+ Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema);
Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
ds.collect();
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 9e241f2098..0f9e453d26 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -42,9 +42,9 @@ public class JavaSaveLoadSuite {
String originalDefaultSource;
File path;
- DataFrame df;
+ Dataset<Row> df;
- private static void checkAnswer(DataFrame actual, List<Row> expected) {
+ private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
if (errorMessage != null) {
Assert.fail(errorMessage);
@@ -85,7 +85,7 @@ public class JavaSaveLoadSuite {
Map<String, String> options = new HashMap<>();
options.put("path", path.toString());
df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save();
- DataFrame loadedDF = sqlContext.read().format("json").options(options).load();
+ Dataset<Row> loadedDF = sqlContext.read().format("json").options(options).load();
checkAnswer(loadedDF, df.collectAsList());
}
@@ -98,7 +98,7 @@ public class JavaSaveLoadSuite {
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
- DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
+ Dataset<Row> loadedDF = sqlContext.read().format("json").schema(schema).options(options).load();
checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 26775c3700..f4a5107eaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -38,23 +38,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("analysis error should be eagerly reported") {
- // Eager analysis.
- withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") {
- intercept[Exception] { testData.select('nonExistentName) }
- intercept[Exception] {
- testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
- }
- intercept[Exception] {
- testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
- }
- intercept[Exception] {
- testData.groupBy($"abcd").agg(Map("key" -> "sum"))
- }
+ intercept[Exception] { testData.select('nonExistentName) }
+ intercept[Exception] {
+ testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
}
-
- // No more eager analysis once the flag is turned off
- withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") {
- testData.select('nonExistentName)
+ intercept[Exception] {
+ testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
+ }
+ intercept[Exception] {
+ testData.groupBy($"abcd").agg(Map("key" -> "sum"))
}
}
@@ -72,7 +64,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(1, 1) :: Nil)
}
- test("invalid plan toString, debug mode") {
+ ignore("invalid plan toString, debug mode") {
// Turn on debug mode so we can see invalid query plans.
import org.apache.spark.sql.execution.debug._
@@ -941,7 +933,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
// error case: insert into an OneRowRelation
- new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row")
+ DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row")
val e3 = intercept[AnalysisException] {
insertion.write.insertInto("one_row")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 3258f3782d..84770169f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -119,16 +119,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(sum(_._2)),
+ checkDataset(
+ ds.groupByKey(_._1).agg(sum(_._2)),
("a", 30), ("b", 3), ("c", 1))
}
test("typed aggregation: TypedAggregator, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(
+ checkDataset(
+ ds.groupByKey(_._1).agg(
sum(_._2),
expr("sum(_2)").as[Long],
count("*")),
@@ -138,8 +138,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: complex case") {
val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(
+ checkDataset(
+ ds.groupByKey(_._1).agg(
expr("avg(_2)").as[Double],
TypedAverage.toColumn),
("a", 2.0, 2.0), ("b", 3.0, 3.0))
@@ -148,8 +148,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: complex result type") {
val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(
+ checkDataset(
+ ds.groupByKey(_._1).agg(
expr("avg(_2)").as[Double],
ComplexResultAgg.toColumn),
("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
@@ -158,10 +158,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: in project list") {
val ds = Seq(1, 3, 2, 5).toDS()
- checkAnswer(
+ checkDataset(
ds.select(sum((i: Int) => i)),
11)
- checkAnswer(
+ checkDataset(
ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
11 -> 22)
}
@@ -169,7 +169,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: class input") {
val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
- checkAnswer(
+ checkDataset(
ds.select(ClassInputAgg.toColumn),
3)
}
@@ -177,33 +177,33 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: class input with reordering") {
val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData]
- checkAnswer(
+ checkDataset(
ds.select(ClassInputAgg.toColumn),
1)
- checkAnswer(
+ checkDataset(
ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn),
(1.0, 1))
- checkAnswer(
- ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
+ checkDataset(
+ ds.groupByKey(_.b).agg(ClassInputAgg.toColumn),
("one", 1))
}
test("typed aggregation: complex input") {
val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
- checkAnswer(
+ checkDataset(
ds.select(ComplexBufferAgg.toColumn),
2
)
- checkAnswer(
+ checkDataset(
ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
(1.5, 2))
- checkAnswer(
- ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
+ checkDataset(
+ ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
}
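
The typed aggregations above all follow the same call shape: `groupByKey` on the Dataset, then `agg` with an aggregator's `toColumn`. A hypothetical aggregator written against the `Aggregator[-I, B, O]` contract shown earlier in this diff, plus a usage sketch mirroring the suite:

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical aggregator (not from the patch): sums the Int half of (String, Int) pairs.
object SumOfValues extends Aggregator[(String, Int), Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: (String, Int)): Long = b + a._2
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
}

object TypedAggSketch {
  // Mirrors the suite's call shape: group with groupByKey, aggregate with toColumn.
  def demo(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._
    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    val agged = ds.groupByKey(_._1).agg(SumOfValues.toColumn) // Dataset[(String, Long)]
    agged.collect()
  }
}
```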
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 848f1af655..2e5179a8d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -34,7 +34,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
// Make sure, the Dataset is indeed cached.
assertCached(cached)
// Check result.
- checkAnswer(
+ checkDataset(
cached,
2, 3, 4)
// Drop the cache.
@@ -52,7 +52,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
assertCached(ds2)
val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
- checkAnswer(joined, ("2", 2))
+ checkDataset(joined, ("2", 2))
assertCached(joined, 2)
ds1.unpersist()
@@ -63,11 +63,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
test("persist and then groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1").keyAs[String]
+ val grouped = ds.groupByKey($"_1").keyAs[String]
val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
agged.persist()
- checkAnswer(
+ checkDataset(
agged.filter(_._1 == "b"),
("b", 3))
assertCached(agged.filter(_._1 == "b"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 243d13b19d..6e9840e4a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -28,14 +28,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("toDS") {
val data = Seq(1, 2, 3, 4, 5, 6)
- checkAnswer(
+ checkDataset(
data.toDS(),
data: _*)
}
test("as case class / collect") {
val ds = Seq(1, 2, 3).toDS().as[IntClass]
- checkAnswer(
+ checkDataset(
ds,
IntClass(1), IntClass(2), IntClass(3))
@@ -44,14 +44,14 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("map") {
val ds = Seq(1, 2, 3).toDS()
- checkAnswer(
+ checkDataset(
ds.map(_ + 1),
2, 3, 4)
}
test("filter") {
val ds = Seq(1, 2, 3, 4).toDS()
- checkAnswer(
+ checkDataset(
ds.filter(_ % 2 == 0),
2, 4)
}
@@ -77,54 +77,54 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, keys") {
val ds = Seq(1, 2, 3, 4, 5).toDS()
- val grouped = ds.groupBy(_ % 2)
- checkAnswer(
+ val grouped = ds.groupByKey(_ % 2)
+ checkDataset(
grouped.keys,
0, 1)
}
test("groupBy function, map") {
val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
- val grouped = ds.groupBy(_ % 2)
+ val grouped = ds.groupByKey(_ % 2)
val agged = grouped.mapGroups { case (g, iter) =>
val name = if (g == 0) "even" else "odd"
(name, iter.size)
}
- checkAnswer(
+ checkDataset(
agged,
("even", 5), ("odd", 6))
}
test("groupBy function, flatMap") {
val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
- val grouped = ds.groupBy(_.length)
+ val grouped = ds.groupByKey(_.length)
val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) }
- checkAnswer(
+ checkDataset(
agged,
"1", "abc", "3", "xyz", "5", "hello")
}
test("Arrays and Lists") {
- checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
- checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
- checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
- checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
- checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
- checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
- checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
- checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
- checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
-
- checkAnswer(Seq(Array(1)).toDS(), Array(1))
- checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
- checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
- checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
- checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
- checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
- checkAnswer(Seq(Array(true)).toDS(), Array(true))
- checkAnswer(Seq(Array("test")).toDS(), Array("test"))
- checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
+ checkDataset(Seq(Seq(1)).toDS(), Seq(1))
+ checkDataset(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
+ checkDataset(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
+ checkDataset(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
+ checkDataset(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
+ checkDataset(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
+ checkDataset(Seq(Seq(true)).toDS(), Seq(true))
+ checkDataset(Seq(Seq("test")).toDS(), Seq("test"))
+ checkDataset(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
+
+ checkDataset(Seq(Array(1)).toDS(), Array(1))
+ checkDataset(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
+ checkDataset(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
+ checkDataset(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
+ checkDataset(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
+ checkDataset(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
+ checkDataset(Seq(Array(true)).toDS(), Array(true))
+ checkDataset(Seq(Array("test")).toDS(), Array("test"))
+ checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}
}
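
The suites above consistently swap `groupBy` for `groupByKey` wherever the grouping key is a function or a typed key. A small sketch of the two call shapes as they look after this rename, assuming the usual implicits are in scope:

```scala
import org.apache.spark.sql.SQLContext

object GroupBySketch {
  def demo(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

    // Typed grouping (previously ds.groupBy(_._1)): the key is computed by a
    // function and the result stays a typed GroupedDataset.
    val typedCounts = ds.groupByKey(_._1).count()

    // Untyped, column-based grouping keeps the old name and returns a DataFrame.
    val untypedCounts = ds.toDF("k", "v").groupBy($"k").count()

    typedCounts.collect()
    untypedCounts.collect()
  }
}
```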
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 79e10215f4..9f32c8bf95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,14 +34,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("toDS") {
val data = Seq(("a", 1), ("b", 2), ("c", 3))
- checkAnswer(
+ checkDataset(
data.toDS(),
data: _*)
}
test("toDS with RDD") {
val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS()
- checkAnswer(
+ checkDataset(
ds.mapPartitions(_ => Iterator(1)),
1, 1, 1)
}
@@ -71,26 +71,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = data.toDS()
assert(ds.repartition(10).rdd.partitions.length == 10)
- checkAnswer(
+ checkDataset(
ds.repartition(10),
data: _*)
assert(ds.coalesce(1).rdd.partitions.length == 1)
- checkAnswer(
+ checkDataset(
ds.coalesce(1),
data: _*)
}
test("as tuple") {
val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
- checkAnswer(
+ checkDataset(
data.as[(String, Int)],
("a", 1), ("b", 2))
}
test("as case class / collect") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
- checkAnswer(
+ checkDataset(
ds,
ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
assert(ds.collect().head == ClassData("a", 1))
@@ -108,7 +108,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("map") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.map(v => (v._1, v._2 + 1)),
("a", 2), ("b", 3), ("c", 4))
}
@@ -116,7 +116,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("map with type change with the exact matched number of attributes") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.map(identity[(String, Int)])
.as[OtherTuple]
.map(identity[OtherTuple]),
@@ -126,7 +126,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("map with type change with less attributes") {
val ds = Seq(("a", 1, 3), ("b", 2, 4), ("c", 3, 5)).toDS()
- checkAnswer(
+ checkDataset(
ds.as[OtherTuple]
.map(identity[OtherTuple]),
OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3))
@@ -137,23 +137,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
// when we implement better pipelining and local execution mode.
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
.map(c => ClassData(c.a, c.b + 1))
- .groupBy(p => p).count()
+ .groupByKey(p => p).count()
- checkAnswer(
+ checkDataset(
ds,
(ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
}
test("select") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.select(expr("_2 + 1").as[Int]),
2, 3, 4)
}
test("select 2") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.select(
expr("_1").as[String],
expr("_2").as[Int]) : Dataset[(String, Int)],
@@ -162,7 +162,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("select 2, primitive and tuple") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.select(
expr("_1").as[String],
expr("struct(_2, _2)").as[(Int, Int)]),
@@ -171,7 +171,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("select 2, primitive and class") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.select(
expr("_1").as[String],
expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
@@ -189,7 +189,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("filter") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
- checkAnswer(
+ checkDataset(
ds.filter(_._1 == "b"),
("b", 2))
}
@@ -217,7 +217,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds1 = Seq(1, 2, 3).toDS().as("a")
val ds2 = Seq(1, 2).toDS().as("b")
- checkAnswer(
+ checkDataset(
ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"),
(1, 1), (2, 2))
}
@@ -230,7 +230,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq(("a", new Integer(1)),
("b", new Integer(2))).toDS()
- checkAnswer(
+ checkDataset(
ds1.joinWith(ds2, $"_1" === $"a", "outer"),
(ClassNullableData("a", 1), ("a", new Integer(1))),
(ClassNullableData("c", 3), (nullString, nullInteger)),
@@ -241,7 +241,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()
- checkAnswer(
+ checkDataset(
ds1.joinWith(ds2, $"value" === $"_2"),
(1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
}
@@ -260,7 +260,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
- checkAnswer(
+ checkDataset(
ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
((("a", 1), ("a", 1)), ("a", 1)),
((("b", 2), ("b", 2)), ("b", 2)))
@@ -268,48 +268,48 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
- val grouped = ds.groupBy(v => (1, v._2))
- checkAnswer(
+ val grouped = ds.groupByKey(v => (1, v._2))
+ checkDataset(
grouped.keys,
(1, 1))
}
test("groupBy function, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy(v => (v._1, "word"))
+ val grouped = ds.groupByKey(v => (v._1, "word"))
val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }
- checkAnswer(
+ checkDataset(
agged,
("a", 30), ("b", 3), ("c", 1))
}
test("groupBy function, flatMap") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy(v => (v._1, "word"))
+ val grouped = ds.groupByKey(v => (v._1, "word"))
val agged = grouped.flatMapGroups { case (g, iter) =>
Iterator(g._1, iter.map(_._2).sum.toString)
}
- checkAnswer(
+ checkDataset(
agged,
"a", "30", "b", "3", "c", "1")
}
test("groupBy function, reduce") {
val ds = Seq("abc", "xyz", "hello").toDS()
- val agged = ds.groupBy(_.length).reduce(_ + _)
+ val agged = ds.groupByKey(_.length).reduce(_ + _)
- checkAnswer(
+ checkDataset(
agged,
3 -> "abcxyz", 5 -> "hello")
}
test("groupBy single field class, count") {
val ds = Seq("abc", "xyz", "hello").toDS()
- val count = ds.groupBy(s => Tuple1(s.length)).count()
+ val count = ds.groupByKey(s => Tuple1(s.length)).count()
- checkAnswer(
+ checkDataset(
count,
(Tuple1(3), 2L), (Tuple1(5), 1L)
)
@@ -317,49 +317,49 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1")
+ val grouped = ds.groupByKey($"_1")
val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
- checkAnswer(
+ checkDataset(
agged,
("a", 30), ("b", 3), ("c", 1))
}
test("groupBy columns, count") {
val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
- val count = ds.groupBy($"_1").count()
+ val count = ds.groupByKey($"_1").count()
- checkAnswer(
+ checkDataset(
count,
(Row("a"), 2L), (Row("b"), 1L))
}
test("groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1").keyAs[String]
+ val grouped = ds.groupByKey($"_1").keyAs[String]
val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
- checkAnswer(
+ checkDataset(
agged,
("a", 30), ("b", 3), ("c", 1))
}
test("groupBy columns asKey tuple, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)]
+ val grouped = ds.groupByKey($"_1", lit(1)).keyAs[(String, Int)]
val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
- checkAnswer(
+ checkDataset(
agged,
(("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
}
test("groupBy columns asKey class, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
+ val grouped = ds.groupByKey($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
- checkAnswer(
+ checkDataset(
agged,
(ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
}
@@ -367,32 +367,32 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(sum("_2").as[Long]),
+ checkDataset(
+ ds.groupByKey(_._1).agg(sum("_2").as[Long]),
("a", 30L), ("b", 3L), ("c", 1L))
}
test("typed aggregation: expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
+ checkDataset(
+ ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
}
test("typed aggregation: expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
+ checkDataset(
+ ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
}
test("typed aggregation: expr, expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
- checkAnswer(
- ds.groupBy(_._1).agg(
+ checkDataset(
+ ds.groupByKey(_._1).agg(
sum("_2").as[Long],
sum($"_2" + 1).as[Long],
count("*").as[Long],
@@ -403,11 +403,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("cogroup") {
val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
- val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+ val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) =>
Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
}
- checkAnswer(
+ checkDataset(
cogrouped,
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
}
@@ -415,11 +415,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("cogroup with complex data") {
val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS()
val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS()
- val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+ val cogrouped = ds1.groupByKey(_._1).cogroup(ds2.groupByKey(_._1)) { case (key, data1, data2) =>
Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
}
- checkAnswer(
+ checkDataset(
cogrouped,
1 -> "a", 2 -> "bc", 3 -> "d")
}
@@ -427,7 +427,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("sample with replacement") {
val n = 100
val data = sparkContext.parallelize(1 to n, 2).toDS()
- checkAnswer(
+ checkDataset(
data.sample(withReplacement = true, 0.05, seed = 13),
5, 10, 52, 73)
}
@@ -435,7 +435,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("sample without replacement") {
val n = 100
val data = sparkContext.parallelize(1 to n, 2).toDS()
- checkAnswer(
+ checkDataset(
data.sample(withReplacement = false, 0.05, seed = 13),
3, 17, 27, 58, 62)
}
@@ -445,13 +445,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq(2, 3).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
- checkAnswer(joined, ("2", 2))
+ checkDataset(joined, ("2", 2))
}
test("self join") {
val ds = Seq("1", "2").toDS().as("a")
val joined = ds.joinWith(ds, lit(true))
- checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
+ checkDataset(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
}
test("toString") {
@@ -477,7 +477,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
- assert(ds.groupBy(p => p).count().collect().toSet ==
+ assert(ds.groupByKey(p => p).count().collect().toSet ==
Set((KryoData(1), 1L), (KryoData(2), 1L)))
}
@@ -496,7 +496,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
- assert(ds.groupBy(p => p).count().collect().toSeq ==
+ assert(ds.groupByKey(p => p).count().collect().toSeq ==
Seq((JavaData(1), 1L), (JavaData(2), 1L)))
}
@@ -516,7 +516,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
- checkAnswer(
+ checkDataset(
ds1.joinWith(ds2, lit(true)),
((nullInt, "1"), (nullInt, "1")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
@@ -550,7 +550,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
}
- checkAnswer(
+ checkDataset(
buildDataset(Row(Row("hello", 1))),
NestedStruct(ClassData("hello", 1))
)
@@ -567,11 +567,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("SPARK-12478: top level null field") {
val ds0 = Seq(NestedStruct(null)).toDS()
- checkAnswer(ds0, NestedStruct(null))
+ checkDataset(ds0, NestedStruct(null))
checkAnswer(ds0.toDF(), Row(null))
val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
- checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+ checkDataset(ds1, DeepNestedStruct(NestedStruct(null)))
checkAnswer(ds1.toDF(), Row(Row(null)))
}
@@ -579,26 +579,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val outer = new OuterClass
OuterScopes.addOuterScope(outer)
val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
- checkAnswer(ds.map(_.a), "1", "2")
+ checkDataset(ds.map(_.a), "1", "2")
}
test("grouping key and grouped value has field with same name") {
val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
- val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
+ val agged = ds.groupByKey(d => ClassNullableData(d.a, null)).mapGroups {
case (key, values) => key.a + values.map(_.b).sum
}
- checkAnswer(agged, "a3")
+ checkDataset(agged, "a3")
}
test("cogroup's left and right side has field with same name") {
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()
- val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
+ val cogrouped = left.groupByKey(_.a).cogroup(right.groupByKey(_.a)) {
case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
}
- checkAnswer(cogrouped, "a13", "b24")
+ checkDataset(cogrouped, "a13", "b24")
}
test("give nice error message when the real number of fields doesn't match encoder schema") {
@@ -626,13 +626,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("SPARK-13440: Resolving option fields") {
val df = Seq(1, 2, 3).toDS()
val ds = df.as[Option[Int]]
- checkAnswer(
+ checkDataset(
ds.filter(_ => true),
Some(1), Some(2), Some(3))
}
test("SPARK-13540 Dataset of nested class defined in Scala object") {
- checkAnswer(
+ checkDataset(
Seq(OuterObject.InnerClass("foo")).toDS(),
OuterObject.InnerClass("foo"))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index c05aa5486a..855295d5f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -72,7 +72,7 @@ abstract class QueryTest extends PlanTest {
* for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
* which performs a subset of the checks done by this function.
*/
- protected def checkAnswer[T](
+ protected def checkDataset[T](
ds: Dataset[T],
expectedAnswer: T*): Unit = {
checkAnswer(
@@ -123,17 +123,17 @@ abstract class QueryTest extends PlanTest {
protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
val analyzedDF = try df catch {
case ae: AnalysisException =>
- val currentValue = sqlContext.conf.dataFrameEagerAnalysis
- sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
- val partiallyAnalzyedPlan = df.queryExecution.analyzed
- sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue)
- fail(
- s"""
- |Failed to analyze query: $ae
- |$partiallyAnalzyedPlan
- |
- |${stackTraceToString(ae)}
- |""".stripMargin)
+ if (ae.plan.isDefined) {
+ fail(
+ s"""
+ |Failed to analyze query: $ae
+ |${ae.plan.get}
+ |
+ |${stackTraceToString(ae)}
+ |""".stripMargin)
+ } else {
+ throw ae
+ }
}
checkJsonFormat(analyzedDF)
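
With `DataFrame` now just `Dataset[Row]`, keeping both helpers under one `checkAnswer` name would leave a `Dataset[Row]` argument matching the typed and the row-based overload alike, which is presumably why the typed variant becomes `checkDataset`. A usage sketch of where each helper applies, in a hypothetical suite shaped like the ones above:

```scala
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite sketch (not part of the patch).
class HelperNamingSketch extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("typed vs row-based checks") {
    val ds = Seq(("a", 1), ("b", 2)).toDS()
    checkDataset(ds, ("a", 1), ("b", 2))                      // compares decoded tuples
    checkAnswer(ds.toDF(), Row("a", 1) :: Row("b", 2) :: Nil) // compares Rows
  }
}
```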
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index bb5135826e..493a5a6437 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -65,9 +65,9 @@ import org.apache.spark.sql.execution.streaming._
trait StreamTest extends QueryTest with Timeouts {
implicit class RichSource(s: Source) {
- def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s))
+ def toDF(): DataFrame = DataFrame(sqlContext, StreamingRelation(s))
- def toDS[A: Encoder](): Dataset[A] = new Dataset(sqlContext, StreamingRelation(s))
+ def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
}
/** How long to wait for an active stream to catch up when checking a result. */
@@ -168,10 +168,6 @@ trait StreamTest extends QueryTest with Timeouts {
}
}
- /** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
- def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
- testStream(stream.toDF())(actions: _*)
-
/**
* Executes the specified actions on the given streaming DataFrame and provides helpful
* error messages in the case of failures or incorrect answers.
@@ -179,7 +175,8 @@ trait StreamTest extends QueryTest with Timeouts {
* Note that if the stream is not explicitly started before an action that requires it to be
* running then it will be automatically started before performing any other actions.
*/
- def testStream(stream: DataFrame)(actions: StreamAction*): Unit = {
+ def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = {
+ val stream = _stream.toDF()
var pos = 0
var currentPlan: LogicalPlan = stream.logicalPlan
var currentStream: StreamExecution = null
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index b4bf9eef8f..63fb4b7cf7 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -38,9 +38,9 @@ public class JavaDataFrameSuite {
private transient JavaSparkContext sc;
private transient HiveContext hc;
- DataFrame df;
+ Dataset<Row> df;
- private static void checkAnswer(DataFrame actual, List<Row> expected) {
+ private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
if (errorMessage != null) {
Assert.fail(errorMessage);
@@ -82,12 +82,12 @@ public class JavaDataFrameSuite {
@Test
public void testUDAF() {
- DataFrame df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
+ Dataset<Row> df = hc.range(0, 100).unionAll(hc.range(0, 100)).select(col("id").as("value"));
UserDefinedAggregateFunction udaf = new MyDoubleSum();
UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);
// Create Columns for the UDAF. For now, callUDF does not take an argument to specific if
// we want to use distinct aggregation.
- DataFrame aggregatedDF =
+ Dataset<Row> aggregatedDF =
df.groupBy()
.agg(
udaf.distinct(col("value")),
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 8c4af1b8ea..5a539eaec7 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -33,7 +33,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.QueryTest$;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.hive.test.TestHive$;
@@ -52,9 +52,9 @@ public class JavaMetastoreDataSourcesSuite {
File path;
Path hiveManagedPath;
FileSystem fs;
- DataFrame df;
+ Dataset<Row> df;
- private static void checkAnswer(DataFrame actual, List<Row> expected) {
+ private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
if (errorMessage != null) {
Assert.fail(errorMessage);
@@ -111,7 +111,7 @@ public class JavaMetastoreDataSourcesSuite {
sqlContext.sql("SELECT * FROM javaSavedTable"),
df.collectAsList());
- DataFrame loadedDF =
+ Dataset<Row> loadedDF =
sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options);
checkAnswer(loadedDF, df.collectAsList());
@@ -137,7 +137,7 @@ public class JavaMetastoreDataSourcesSuite {
List<StructField> fields = new ArrayList<>();
fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
StructType schema = DataTypes.createStructType(fields);
- DataFrame loadedDF =
+ Dataset<Row> loadedDF =
sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options);
checkAnswer(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
index 4adc5c1116..a0a0d134da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala
@@ -63,7 +63,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
""".stripMargin)
}
- checkAnswer(sqlContext.sql(generatedSQL), new DataFrame(sqlContext, plan))
+ checkAnswer(sqlContext.sql(generatedSQL), DataFrame(sqlContext, plan))
}
protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 45634a4475..d5a4295d62 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -128,6 +128,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
import testImplicits._
override def beforeAll(): Unit = {
+ super.beforeAll()
val data1 = Seq[(Integer, Integer)](
(1, 10),
(null, -60),