aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-11-13 08:43:05 -0800
committer Xiangrui Meng <meng@databricks.com> 2015-11-13 08:43:05 -0800
commit 99693fef0a30432d94556154b81872356d921c64 (patch)
tree 09d76cc0ef6cae153718982a9a1ecc827ee12d5f /examples/src/main/java
parent 61a28486ccbcdd37461419df958aea222c8b9f09 (diff)
downloadspark-99693fef0a30432d94556154b81872356d921c64.tar.gz
spark-99693fef0a30432d94556154b81872356d921c64.tar.bz2
spark-99693fef0a30432d94556154b81872356d921c64.zip
[SPARK-11723][ML][DOC] Use LibSVM data source rather than MLUtils.loadLibSVMFile to load DataFrame
Use the LibSVM data source rather than MLUtils.loadLibSVMFile to load DataFrames, including: * Use the libSVM data source for all example code under examples/ml, and remove unused imports. * Use the libSVM data source in the user guides under ml-*** that were omitted by #8697. * Fix bug: on the Java side we should use ```sqlContext.read().format("libsvm").load(path)```, but the API doc and user guides mistakenly use ```sqlContext.read.format("libsvm").load(path)```. * Code cleanup. mengxr Author: Yanbo Liang <ybliang8@gmail.com> Closes #9690 from yanboliang/spark-11723.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java | 8
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java | 9
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java | 6
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java | 23
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java | 6
5 files changed, 17 insertions, 35 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
index 51c1730a8a..482225e585 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
@@ -26,9 +26,6 @@ import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -40,9 +37,8 @@ public class JavaDecisionTreeClassificationExample {
SQLContext sqlContext = new SQLContext(jsc);
// $example on$
- // Load and parse the data file, converting it to a DataFrame.
- RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
- DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);
+ // Load the data stored in LIBSVM format as a DataFrame.
+ DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java
index a4098a4233..c7f1868dd1 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java
@@ -27,9 +27,6 @@ import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
// $example off$
@@ -40,9 +37,9 @@ public class JavaDecisionTreeRegressionExample {
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext sqlContext = new SQLContext(jsc);
// $example on$
- // Load and parse the data file, converting it to a DataFrame.
- RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
- DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);
+ // Load the data stored in LIBSVM format as a DataFrame.
+ DataFrame data = sqlContext.read().format("libsvm")
+ .load("data/mllib/sample_libsvm_data.txt");
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
index f48e1339c5..84369f6681 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
@@ -21,12 +21,9 @@ package org.apache.spark.examples.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
-import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.DataFrame;
// $example off$
@@ -43,8 +40,7 @@ public class JavaMultilayerPerceptronClassifierExample {
// $example on$
// Load training data
String path = "data/mllib/sample_multiclass_classification_data.txt";
- JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();
- DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
+ DataFrame dataFrame = jsql.read().format("libsvm").load(path);
// Split the data into train and test
DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
DataFrame train = splits[0];
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
index e7f2f6f615..f0d92a56be 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -27,9 +27,7 @@ import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.util.MetadataUtils;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructField;
@@ -80,31 +78,30 @@ public class JavaOneVsRestExample {
OneVsRest ovr = new OneVsRest().setClassifier(classifier);
String input = params.input;
- RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
- RDD<LabeledPoint> train;
- RDD<LabeledPoint> test;
+ DataFrame inputData = jsql.read().format("libsvm").load(input);
+ DataFrame train;
+ DataFrame test;
// compute the train/ test split: if testInput is not provided use part of input
String testInput = params.testInput;
if (testInput != null) {
train = inputData;
// compute the number of features in the training set.
- int numFeatures = inputData.first().features().size();
- test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
+ int numFeatures = inputData.first().<Vector>getAs(1).size();
+ test = jsql.read().format("libsvm").option("numFeatures",
+ String.valueOf(numFeatures)).load(testInput);
} else {
double f = params.fracTest;
- RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
+ DataFrame[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
train = tmp[0];
test = tmp[1];
}
// train the multiclass model
- DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
- OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());
+ OneVsRestModel ovrModel = ovr.fit(train.cache());
// score the model on test data
- DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
- DataFrame predictions = ovrModel.transform(testDataFrame.cache())
+ DataFrame predictions = ovrModel.transform(test.cache())
.select("prediction", "label");
// obtain metrics
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
index 23f834ab43..d433905fc8 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
@@ -23,8 +23,6 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.*;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
@@ -46,9 +44,7 @@ public class JavaTrainValidationSplitExample {
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);
- DataFrame data = jsql.createDataFrame(
- MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
- LabeledPoint.class);
+ DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
// Prepare training and test data.
DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);