Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java | 127
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java | 111
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java | 6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala | 110
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala | 101
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala | 7
6 files changed, 457 insertions(+), 5 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
new file mode 100644
index 0000000000..3b156fa048
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.Model;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.HashingTF;
+import org.apache.spark.ml.feature.Tokenizer;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.tuning.CrossValidator;
+import org.apache.spark.ml.tuning.CrossValidatorModel;
+import org.apache.spark.ml.tuning.ParamGridBuilder;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+
+/**
+ * A simple example demonstrating model selection using CrossValidator.
+ * This example also demonstrates how Pipelines are Estimators.
+ *
+ * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and
+ * {@link org.apache.spark.examples.ml.Document} defined in the Scala example
+ * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}.
+ *
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaCrossValidatorExample
+ * </pre>
+ */
+public class JavaCrossValidatorExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ JavaSQLContext jsql = new JavaSQLContext(jsc);
+
+ // Prepare training documents, which are labeled.
+ List<LabeledDocument> localTraining = Lists.newArrayList(
+ new LabeledDocument(0L, "a b c d e spark", 1.0),
+ new LabeledDocument(1L, "b d", 0.0),
+ new LabeledDocument(2L, "spark f g h", 1.0),
+ new LabeledDocument(3L, "hadoop mapreduce", 0.0),
+ new LabeledDocument(4L, "b spark who", 1.0),
+ new LabeledDocument(5L, "g d a y", 0.0),
+ new LabeledDocument(6L, "spark fly", 1.0),
+ new LabeledDocument(7L, "was mapreduce", 0.0),
+ new LabeledDocument(8L, "e spark program", 1.0),
+ new LabeledDocument(9L, "a e c l", 0.0),
+ new LabeledDocument(10L, "spark compile", 1.0),
+ new LabeledDocument(11L, "hadoop software", 0.0));
+ JavaSchemaRDD training =
+ jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ Tokenizer tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words");
+ HashingTF hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol())
+ .setOutputCol("features");
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01);
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
+
+ // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
+ // This will allow us to jointly choose parameters for all Pipeline stages.
+ // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ CrossValidator crossval = new CrossValidator()
+ .setEstimator(pipeline)
+ .setEvaluator(new BinaryClassificationEvaluator());
+ // We use a ParamGridBuilder to construct a grid of parameters to search over.
+ // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
+ // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
+ ParamMap[] paramGrid = new ParamGridBuilder()
+ .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000})
+ .addGrid(lr.regParam(), new double[]{0.1, 0.01})
+ .build();
+ crossval.setEstimatorParamMaps(paramGrid);
+ crossval.setNumFolds(2); // Use 3+ in practice
+
+ // Run cross-validation, and choose the best set of parameters.
+ CrossValidatorModel cvModel = crossval.fit(training);
+
+ // Prepare test documents, which are unlabeled.
+ List<Document> localTest = Lists.newArrayList(
+ new Document(4L, "spark i j k"),
+ new Document(5L, "l m n"),
+ new Document(6L, "mapreduce spark"),
+ new Document(7L, "apache hadoop"));
+ JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+
+    // Make predictions on test documents. cvModel uses the best pipeline model found.
+ cvModel.transform(test).registerAsTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ for (Row r: predictions.collect()) {
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ + ", prediction=" + r.get(3));
+ }
+ }
+}
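For reference, the 3 x 2 grid built above expands into 6 ParamMaps, one per combination of the two grids: (numFeatures=10, regParam=0.1), (numFeatures=10, regParam=0.01), (numFeatures=100, regParam=0.1), (numFeatures=100, regParam=0.01), (numFeatures=1000, regParam=0.1), and (numFeatures=1000, regParam=0.01). CrossValidator evaluates each combination with the BinaryClassificationEvaluator over the configured folds and keeps the best-performing one.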
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
new file mode 100644
index 0000000000..cf58f4dfaa
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+
+/**
+ * A simple example demonstrating ways to specify parameters for Estimators and Transformers.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaSimpleParamsExample
+ * </pre>
+ */
+public class JavaSimpleParamsExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ JavaSQLContext jsql = new JavaSQLContext(jsc);
+
+ // Prepare training data.
+    // We use LabeledPoint, a case class that also follows the JavaBean convention. Spark SQL
+    // can convert RDDs of Java Beans into SchemaRDDs, using the bean metadata to infer the schema.
+ List<LabeledPoint> localTraining = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
+ JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ LogisticRegression lr = new LogisticRegression();
+ // Print out the parameters, documentation, and any default values.
+ System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10)
+ .setRegParam(0.01);
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ LogisticRegressionModel model1 = lr.fit(training);
+ // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+ // we can view the parameters it used during fit().
+ // This prints the parameter (name: value) pairs, where names are unique IDs for this
+ // LogisticRegression instance.
+ System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+
+ // We may alternatively specify parameters using a ParamMap.
+ ParamMap paramMap = new ParamMap();
+ paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
+ paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
+ paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
+
+ // One can also combine ParamMaps.
+ ParamMap paramMap2 = new ParamMap();
+ paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
+ ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
+
+ // Now learn a new model using the paramMapCombined parameters.
+ // paramMapCombined overrides all parameters set earlier via lr.set* methods.
+ LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
+ System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+
+    // Prepare test data.
+ List<LabeledPoint> localTest = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
+ JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+
+    // Make predictions on the test data using the Transformer.transform() method.
+ // LogisticRegression.transform will only use the 'features' column.
+ // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
+ // column since we renamed the lr.scoreCol parameter previously.
+ model2.transform(test).registerAsTable("results");
+ JavaSchemaRDD results =
+ jsql.sql("SELECT features, label, probability, prediction FROM results");
+ for (Row r: results.collect()) {
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ + ", prediction=" + r.get(3));
+ }
+ }
+}
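The ParamMap calls in this Java example are the JVM spellings of Scala operators from the spark.ml param API: lr.maxIter().w(20) builds a single param/value pair, and $plus$plus is the compiled name of Scala's ++. A minimal side-by-side sketch of the Scala equivalents (the same calls appear in SimpleParamsExample.scala later in this patch; nothing beyond that API is assumed):

    val paramMap = ParamMap(lr.maxIter -> 20)
    paramMap.put(lr.maxIter, 30)                            // overwrites the earlier maxIter
    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55)  // several params at once
    val paramMapCombined = paramMap ++ ParamMap(lr.scoreCol -> "probability")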
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 22ba68d8c3..54f18014e4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -80,14 +80,14 @@ public class JavaSimpleTextClassificationPipeline {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- JavaSchemaRDD test =
- jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
model.transform(test).registerAsTable("prediction");
JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
- System.out.println(r);
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ + ", prediction=" + r.get(3));
}
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
new file mode 100644
index 0000000000..ce6bc066bd
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * A simple example demonstrating model selection using CrossValidator.
+ * This example also demonstrates how Pipelines are Estimators.
+ *
+ * This example uses the [[LabeledDocument]] and [[Document]] case classes from
+ * [[SimpleTextClassificationPipeline]].
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.CrossValidatorExample
+ * }}}
+ */
+object CrossValidatorExample {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("CrossValidatorExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Prepare training documents, which are labeled.
+ val training = sparkContext.parallelize(Seq(
+ LabeledDocument(0L, "a b c d e spark", 1.0),
+ LabeledDocument(1L, "b d", 0.0),
+ LabeledDocument(2L, "spark f g h", 1.0),
+ LabeledDocument(3L, "hadoop mapreduce", 0.0),
+ LabeledDocument(4L, "b spark who", 1.0),
+ LabeledDocument(5L, "g d a y", 0.0),
+ LabeledDocument(6L, "spark fly", 1.0),
+ LabeledDocument(7L, "was mapreduce", 0.0),
+ LabeledDocument(8L, "e spark program", 1.0),
+ LabeledDocument(9L, "a e c l", 0.0),
+ LabeledDocument(10L, "spark compile", 1.0),
+ LabeledDocument(11L, "hadoop software", 0.0)))
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ val tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words")
+ val hashingTF = new HashingTF()
+ .setInputCol(tokenizer.getOutputCol)
+ .setOutputCol("features")
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ val pipeline = new Pipeline()
+ .setStages(Array(tokenizer, hashingTF, lr))
+
+ // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
+ // This will allow us to jointly choose parameters for all Pipeline stages.
+ // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ val crossval = new CrossValidator()
+ .setEstimator(pipeline)
+ .setEvaluator(new BinaryClassificationEvaluator)
+ // We use a ParamGridBuilder to construct a grid of parameters to search over.
+ // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
+ // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
+ val paramGrid = new ParamGridBuilder()
+ .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
+ .addGrid(lr.regParam, Array(0.1, 0.01))
+ .build()
+ crossval.setEstimatorParamMaps(paramGrid)
+ crossval.setNumFolds(2) // Use 3+ in practice
+
+ // Run cross-validation, and choose the best set of parameters.
+ val cvModel = crossval.fit(training)
+
+ // Prepare test documents, which are unlabeled.
+ val test = sparkContext.parallelize(Seq(
+ Document(4L, "spark i j k"),
+ Document(5L, "l m n"),
+ Document(6L, "mapreduce spark"),
+ Document(7L, "apache hadoop")))
+
+    // Make predictions on test documents. cvModel uses the best pipeline model found.
+ cvModel.transform(test)
+ .select('id, 'text, 'score, 'prediction)
+ .collect()
+ .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
+ println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ }
+ }
+}
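A quick cost check for the settings above: the grid holds 3 x 2 = 6 parameter combinations and numFolds is 2, so the cross-validation loop fits the tokenizer/hashingTF/lr pipeline about 6 x 2 = 12 times (typically followed by one more fit on the full training set with the winning parameters). With the recommended 3+ folds the number of fits grows proportionally, which is worth keeping in mind for non-toy datasets.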
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
new file mode 100644
index 0000000000..44d5b084c2
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * A simple example demonstrating ways to specify parameters for Estimators and Transformers.
+ * Run with
+ * {{{
+ * bin/run-example ml.SimpleParamsExample
+ * }}}
+ */
+object SimpleParamsExample {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("SimpleParamsExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Prepare training data.
+    // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
+    // into SchemaRDDs, where it uses the case class metadata to infer the schema.
+ val training = sparkContext.parallelize(Seq(
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ val lr = new LogisticRegression()
+ // Print out the parameters, documentation, and any default values.
+ println("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10)
+ .setRegParam(0.01)
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ val model1 = lr.fit(training)
+ // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
+ // we can view the parameters it used during fit().
+ // This prints the parameter (name: value) pairs, where names are unique IDs for this
+ // LogisticRegression instance.
+ println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+
+ // We may alternatively specify parameters using a ParamMap,
+ // which supports several methods for specifying parameters.
+ val paramMap = ParamMap(lr.maxIter -> 20)
+ paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
+ paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
+
+ // One can also combine ParamMaps.
+ val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
+ val paramMapCombined = paramMap ++ paramMap2
+
+ // Now learn a new model using the paramMapCombined parameters.
+ // paramMapCombined overrides all parameters set earlier via lr.set* methods.
+ val model2 = lr.fit(training, paramMapCombined)
+ println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+
+    // Prepare test data.
+ val test = sparkContext.parallelize(Seq(
+ LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
+
+    // Make predictions on the test data using the Transformer.transform() method.
+ // LogisticRegression.transform will only use the 'features' column.
+ // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
+ // column since we renamed the lr.scoreCol parameter previously.
+ model2.transform(test)
+ .select('features, 'label, 'probability, 'prediction)
+ .collect()
+ .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
+ println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
+ }
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index ee7897d906..92895a05e4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -20,10 +20,11 @@ package org.apache.spark.examples.ml
import scala.beans.BeanInfo
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{Row, SQLContext}
@BeanInfo
case class LabeledDocument(id: Long, text: String, label: Double)
@@ -81,6 +82,8 @@ object SimpleTextClassificationPipeline {
model.transform(test)
.select('id, 'text, 'score, 'prediction)
.collect()
- .foreach(println)
+ .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
+ println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ }
}
}