diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-11-13 08:43:05 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-13 08:43:05 -0800 |
commit | 99693fef0a30432d94556154b81872356d921c64 (patch) | |
tree | 09d76cc0ef6cae153718982a9a1ecc827ee12d5f /examples/src/main/java | |
parent | 61a28486ccbcdd37461419df958aea222c8b9f09 (diff) | |
download | spark-99693fef0a30432d94556154b81872356d921c64.tar.gz spark-99693fef0a30432d94556154b81872356d921c64.tar.bz2 spark-99693fef0a30432d94556154b81872356d921c64.zip |
[SPARK-11723][ML][DOC] Use LibSVM data source rather than MLUtils.loadLibSVMFile to load DataFrame
Use LibSVM data source rather than MLUtils.loadLibSVMFile to load DataFrame, include:
* Use the libSVM data source for all example code under examples/ml, and remove the now-unused imports.
* Use libSVM data source for user guides under ml-*** which were omitted by #8697.
* Fix bug: We should use ```sqlContext.read().format("libsvm").load(path)``` on the Java side, but the API doc and user guides mistakenly write it as ```sqlContext.read.format("libsvm").load(path)```.
* Code cleanup.
mengxr
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9690 from yanboliang/spark-11723.
Diffstat (limited to 'examples/src/main/java')
5 files changed, 17 insertions, 35 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java index 51c1730a8a..482225e585 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -26,9 +26,6 @@ import org.apache.spark.ml.classification.DecisionTreeClassifier; import org.apache.spark.ml.classification.DecisionTreeClassificationModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; // $example off$ @@ -40,9 +37,8 @@ public class JavaDecisionTreeClassificationExample { SQLContext sqlContext = new SQLContext(jsc); // $example on$ - // Load and parse the data file, converting it to a DataFrame. - RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"); - DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); + // Load the data stored in LIBSVM format as a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. 
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java index a4098a4233..c7f1868dd1 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -27,9 +27,6 @@ import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.DecisionTreeRegressionModel; import org.apache.spark.ml.regression.DecisionTreeRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; // $example off$ @@ -40,9 +37,9 @@ public class JavaDecisionTreeRegressionExample { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); // $example on$ - // Load and parse the data file, converting it to a DataFrame. - RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"); - DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); + // Load the data stored in LIBSVM format as a DataFrame. + DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. 
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java index f48e1339c5..84369f6681 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java @@ -21,12 +21,9 @@ package org.apache.spark.examples.ml; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; // $example off$ @@ -43,8 +40,7 @@ public class JavaMultilayerPerceptronClassifierExample { // $example on$ // Load training data String path = "data/mllib/sample_multiclass_classification_data.txt"; - JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD(); - DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); + DataFrame dataFrame = jsql.read().format("libsvm").load(path); // Split the data into train and test DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); DataFrame train = splits[0]; diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index e7f2f6f615..f0d92a56be 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ 
-27,9 +27,7 @@ import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.ml.util.MetadataUtils; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructField; @@ -80,31 +78,30 @@ public class JavaOneVsRestExample { OneVsRest ovr = new OneVsRest().setClassifier(classifier); String input = params.input; - RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); - RDD<LabeledPoint> train; - RDD<LabeledPoint> test; + DataFrame inputData = jsql.read().format("libsvm").load(input); + DataFrame train; + DataFrame test; // compute the train/ test split: if testInput is not provided use part of input String testInput = params.testInput; if (testInput != null) { train = inputData; // compute the number of features in the training set. 
- int numFeatures = inputData.first().features().size(); - test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + int numFeatures = inputData.first().<Vector>getAs(1).size(); + test = jsql.read().format("libsvm").option("numFeatures", + String.valueOf(numFeatures)).load(testInput); } else { double f = params.fracTest; - RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + DataFrame[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); train = tmp[0]; test = tmp[1]; } // train the multiclass model - DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); - OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + OneVsRestModel ovrModel = ovr.fit(train.cache()); // score the model on test data - DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); - DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + DataFrame predictions = ovrModel.transform(test.cache()) .select("prediction", "label"); // obtain metrics diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java index 23f834ab43..d433905fc8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java @@ -23,8 +23,6 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.tuning.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -46,9 +44,7 @@ public class JavaTrainValidationSplitExample { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); 
- DataFrame data = jsql.createDataFrame( - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), - LabeledPoint.class); + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Prepare training and test data. DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); |