diff options
Diffstat (limited to 'examples')
6 files changed, 457 insertions, 5 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java new file mode 100644 index 0000000000..3b156fa048 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Model; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; + +/** + * A simple example demonstrating model selection using CrossValidator. + * This example also demonstrates how Pipelines are Estimators. + * + * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and + * {@link org.apache.spark.examples.ml.Document} defined in the Scala example + * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}. + * + * Run with + * <pre> + * bin/run-example ml.JavaCrossValidatorExample + * </pre> + */ +public class JavaCrossValidatorExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + JavaSQLContext jsql = new JavaSQLContext(jsc); + + // Prepare training documents, which are labeled. + List<LabeledDocument> localTraining = Lists.newArrayList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0), + new LabeledDocument(4L, "b spark who", 1.0), + new LabeledDocument(5L, "g d a y", 0.0), + new LabeledDocument(6L, "spark fly", 1.0), + new LabeledDocument(7L, "was mapreduce", 0.0), + new LabeledDocument(8L, "e spark program", 1.0), + new LabeledDocument(9L, "a e c l", 0.0), + new LabeledDocument(10L, "spark compile", 1.0), + new LabeledDocument(11L, "hadoop software", 0.0)); + JavaSchemaRDD training = + jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + CrossValidator crossval = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()); + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) + .addGrid(lr.regParam(), new double[]{0.1, 0.01}) + .build(); + crossval.setEstimatorParamMaps(paramGrid); + crossval.setNumFolds(2); // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + CrossValidatorModel cvModel = crossval.fit(training); + + // Prepare test documents, which are unlabeled. + List<Document> localTest = Lists.newArrayList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop")); + JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + cvModel.transform(test).registerAsTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + for (Row r: predictions.collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + + ", prediction=" + r.get(3)); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java new file mode 100644 index 0000000000..cf58f4dfaa --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; + +/** + * A simple example demonstrating ways to specify parameters for Estimators and Transformers. + * Run with + * {{{ + * bin/run-example ml.JavaSimpleParamsExample + * }}} + */ +public class JavaSimpleParamsExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + JavaSQLContext jsql = new JavaSQLContext(jsc); + + // Prepare training data. + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans + // into SchemaRDDs, where it uses the bean metadata to infer the schema. + List<LabeledPoint> localTraining = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); + JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + LogisticRegression lr = new LogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10) + .setRegParam(0.01); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + LogisticRegressionModel model1 = lr.fit(training); + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap()); + + // We may alternatively specify parameters using a ParamMap. + ParamMap paramMap = new ParamMap(); + paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. + paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. + paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + + // One can also combine ParamMaps. + ParamMap paramMap2 = new ParamMap(); + paramMap2.put(lr.scoreCol().w("probability")); // Change output column name + ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); + System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap()); + + // Prepare test documents. + List<LabeledPoint> localTest = Lists.newArrayList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); + JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + + // Make predictions on test documents using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' + // column since we renamed the lr.scoreCol parameter previously. + model2.transform(test).registerAsTable("results"); + JavaSchemaRDD results = + jsql.sql("SELECT features, label, probability, prediction FROM results"); + for (Row r: results.collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 22ba68d8c3..54f18014e4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -80,14 +80,14 @@ public class JavaSimpleTextClassificationPipeline { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - JavaSchemaRDD test = - jsql.applySchema(jsc.parallelize(localTest), Document.class); + JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. model.transform(test).registerAsTable("prediction"); JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { - System.out.println(r); + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + + ", prediction=" + r.get(3)); } } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala new file mode 100644 index 0000000000..ce6bc066bd --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.sql.{Row, SQLContext} + +/** + * A simple example demonstrating model selection using CrossValidator. + * This example also demonstrates how Pipelines are Estimators. + * + * This example uses the [[LabeledDocument]] and [[Document]] case classes from + * [[SimpleTextClassificationPipeline]]. + * + * Run with + * {{{ + * bin/run-example ml.CrossValidatorExample + * }}} + */ +object CrossValidatorExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("CrossValidatorExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Prepare training documents, which are labeled. + val training = sparkContext.parallelize(Seq( + LabeledDocument(0L, "a b c d e spark", 1.0), + LabeledDocument(1L, "b d", 0.0), + LabeledDocument(2L, "spark f g h", 1.0), + LabeledDocument(3L, "hadoop mapreduce", 0.0), + LabeledDocument(4L, "b spark who", 1.0), + LabeledDocument(5L, "g d a y", 0.0), + LabeledDocument(6L, "spark fly", 1.0), + LabeledDocument(7L, "was mapreduce", 0.0), + LabeledDocument(8L, "e spark program", 1.0), + LabeledDocument(9L, "a e c l", 0.0), + LabeledDocument(10L, "spark compile", 1.0), + LabeledDocument(11L, "hadoop software", 0.0))) + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val crossval = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + val paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) + .addGrid(lr.regParam, Array(0.1, 0.01)) + .build() + crossval.setEstimatorParamMaps(paramGrid) + crossval.setNumFolds(2) // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + val cvModel = crossval.fit(training) + + // Prepare test documents, which are unlabeled. + val test = sparkContext.parallelize(Seq( + Document(4L, "spark i j k"), + Document(5L, "l m n"), + Document(6L, "mapreduce spark"), + Document(7L, "apache hadoop"))) + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + cvModel.transform(test) + .select('id, 'text, 'score, 'prediction) + .collect() + .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => + println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala new file mode 100644 index 0000000000..44d5b084c2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.{Row, SQLContext} + +/** + * A simple example demonstrating ways to specify parameters for Estimators and Transformers. + * Run with + * {{{ + * bin/run-example ml.SimpleParamsExample + * }}} + */ +object SimpleParamsExample { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SimpleParamsExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Prepare training data. + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans + // into SchemaRDDs, where it uses the bean metadata to infer the schema. + val training = sparkContext.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) + + // Create a LogisticRegression instance. This instance is an Estimator. + val lr = new LogisticRegression() + // Print out the parameters, documentation, and any default values. + println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + // We may set parameters using setter methods. + lr.setMaxIter(10) + .setRegParam(0.01) + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model1 = lr.fit(training) + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + println("Model 1 was fit using parameters: " + model1.fittingParamMap) + + // We may alternatively specify parameters using a ParamMap, + // which supports several methods for specifying parameters. + val paramMap = ParamMap(lr.maxIter -> 20) + paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + + // One can also combine ParamMaps. + val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name + val paramMapCombined = paramMap ++ paramMap2 + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + val model2 = lr.fit(training, paramMapCombined) + println("Model 2 was fit using parameters: " + model2.fittingParamMap) + + // Prepare test documents. + val test = sparkContext.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) + + // Make predictions on test documents using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' + // column since we renamed the lr.scoreCol parameter previously. + model2.transform(test) + .select('features, 'label, 'probability, 'prediction) + .collect() + .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => + println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index ee7897d906..92895a05e4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,10 +20,11 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{Row, SQLContext} @BeanInfo case class LabeledDocument(id: Long, text: String, label: Double) @@ -81,6 +82,8 @@ object SimpleTextClassificationPipeline { model.transform(test) .select('id, 'text, 'score, 'prediction) .collect() - .foreach(println) + .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => + println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + } } } |