aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-07-30 08:08:33 -0700
committerSean Owen <sowen@cloudera.com>2016-07-30 08:08:33 -0700
commita6290e51e402e8434d6207d553db1f551e714fde (patch)
tree02f0ae23e903fd3f3d317fc5a3bbcb2ca863813b /examples/src/main
parentbce354c1d4e2b97b1159913085e9883a26bc605a (diff)
downloadspark-a6290e51e402e8434d6207d553db1f551e714fde.tar.gz
spark-a6290e51e402e8434d6207d553db1f551e714fde.tar.bz2
spark-a6290e51e402e8434d6207d553db1f551e714fde.zip
[SPARK-16800][EXAMPLES][ML] Fix Java examples that fail to run due to exception
## What changes were proposed in this pull request? Some Java examples are using mllib.linalg.Vectors instead of ml.linalg.Vectors and causes an exception when run. Also there are some Java examples that incorrectly specify data types in the schema, also causing an exception. ## How was this patch tested? Ran corrected examples locally Author: Bryan Cutler <cutlerb@gmail.com> Closes #14405 from BryanCutler/java-examples-ml.Vectors-fix-SPARK-16800.
Diffstat (limited to 'examples/src/main')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java2
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java43
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java2
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java2
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java8
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java2
12 files changed, 49 insertions, 38 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
index b0115756cf..3f034588c9 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
@@ -23,12 +23,16 @@ import java.util.List;
import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
-import org.apache.spark.mllib.linalg.*;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
// $example off$
/**
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
index 5f964aca92..a954dbd20c 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
@@ -47,7 +47,7 @@ public class JavaBinarizerExample {
RowFactory.create(2, 0.2)
);
StructType schema = new StructType(new StructField[]{
- new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> continuousDataFrame = spark.createDataFrame(data, schema);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
index f8f2fb14be..fcf90d8d18 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
@@ -25,8 +25,8 @@ import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.ChiSqSelector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
index eee92c77a8..66ce23b49d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
@@ -25,8 +25,8 @@ import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.DCT;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.Metadata;
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
index 889f5785df..9e07a0c2f8 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java
@@ -19,16 +19,20 @@ package org.apache.spark.examples.ml;
// $example on$
import java.util.Arrays;
-// $example off$
+import java.util.List;
-// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
// $example off$
import org.apache.spark.sql.SparkSession;
@@ -44,15 +48,17 @@ public class JavaEstimatorTransformerParamExample {
// $example on$
// Prepare training data.
- // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into
- // DataFrames, where it uses the bean metadata to infer the schema.
- Dataset<Row> training = spark.createDataFrame(
- Arrays.asList(
- new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
- new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
- new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
- new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
- ), LabeledPoint.class);
+ List<Row> dataTraining = Arrays.asList(
+ RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5))
+ );
+ StructType schema = new StructType(new StructField[]{
+ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
+ });
+ Dataset<Row> training = spark.createDataFrame(dataTraining, schema);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -87,11 +93,12 @@ public class JavaEstimatorTransformerParamExample {
System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
- Dataset<Row> test = spark.createDataFrame(Arrays.asList(
- new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
- new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
- new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
- ), LabeledPoint.class);
+ List<Row> dataTest = Arrays.asList(
+ RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5))
+ );
+ Dataset<Row> test = spark.createDataFrame(dataTest, schema);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
index dcd209e28e..a561b6d39b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
@@ -21,7 +21,7 @@ package org.apache.spark.examples.ml;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
index 5d29e54549..a15e5f84a1 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
@@ -53,7 +53,7 @@ public class JavaOneHotEncoderExample {
);
StructType schema = new StructType(new StructField[]{
- new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField("category", DataTypes.StringType, false, Metadata.empty())
});
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
index ffa979ee01..d597a9a2ed 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
@@ -25,8 +25,8 @@ import java.util.List;
import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.PCAModel;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
index 7afcd0e50c..67180df65c 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
@@ -24,8 +24,8 @@ import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.feature.PolynomialExpansion;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
index 6e0753959e..800e42c949 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
@@ -25,7 +25,7 @@ import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
@@ -45,9 +45,9 @@ public class JavaTfIdfExample {
// $example on$
List<Row> data = Arrays.asList(
- RowFactory.create(0, "Hi I heard about Spark"),
- RowFactory.create(0, "I wish Java could use case classes"),
- RowFactory.create(1, "Logistic regression models are neat")
+ RowFactory.create(0.0, "Hi I heard about Spark"),
+ RowFactory.create(0.0, "I wish Java could use case classes"),
+ RowFactory.create(1.0, "Logistic regression models are neat")
);
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
index 41f1d8750a..9bb0f93d3a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
@@ -23,8 +23,8 @@ import org.apache.spark.sql.SparkSession;
import java.util.Arrays;
import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
index 24959c0e10..19b8bc83be 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
@@ -28,7 +28,7 @@ import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.feature.VectorSlicer;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;