diff options
author | Devaraj K <devaraj@apache.org> | 2016-02-22 17:21:37 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-22 17:21:37 -0800 |
commit | 02b1fefffb00d50c1076a26f2f3f41f3c1fa0001 (patch) | |
tree | d0012790986cca246579ce1d4a8b583fff47469a /examples/src/main/java | |
parent | 9f410871ca03f4c04bd965b2e4f80167ce543139 (diff) | |
download | spark-02b1fefffb00d50c1076a26f2f3f41f3c1fa0001.tar.gz spark-02b1fefffb00d50c1076a26f2f3f41f3c1fa0001.tar.bz2 spark-02b1fefffb00d50c1076a26f2f3f41f3c1fa0001.zip |
[SPARK-13012][DOCUMENTATION] Replace example code in ml-guide.md using include_example
Replaced example code in ml-guide.md using include_example
Author: Devaraj K <devaraj@apache.org>
Closes #11053 from devaraj-kavali/SPARK-13012.
Diffstat (limited to 'examples/src/main/java')
6 files changed, 488 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java new file mode 100644 index 0000000000..6459dabc06 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDocument.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.io.Serializable; + +/** + * Unlabeled instance type, Spark SQL can infer schema from Java Beans. + */ +@SuppressWarnings("serial") +public class JavaDocument implements Serializable { + + private long id; + private String text; + + public JavaDocument(long id, String text) { + this.id = id; + this.text = text; + } + + public long getId() { + return this.id; + } + + public String getText() { + return this.text; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java new file mode 100644 index 0000000000..44cf3507f3 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Estimator, Transformer, and Param. + */ +public class JavaEstimatorTransformerParamExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaEstimatorTransformerParamExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training data. + // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into + // DataFrames, where it uses the bean metadata to infer the schema. + DataFrame training = sqlContext.createDataFrame( + Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), + new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) + ), LabeledPoint.class); + + // Create a LogisticRegression instance. This instance is an Estimator. + LogisticRegression lr = new LogisticRegression(); + // Print out the parameters, documentation, and any default values. + System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); + + // We may set parameters using setter methods. + lr.setMaxIter(10).setRegParam(0.01); + + // Learn a LogisticRegression model. This uses the parameters stored in lr. + LogisticRegressionModel model1 = lr.fit(training); + // Since model1 is a Model (i.e., a Transformer produced by an Estimator), + // we can view the parameters it used during fit(). + // This prints the parameter (name: value) pairs, where names are unique IDs for this + // LogisticRegression instance. + System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); + + // We may alternatively specify parameters using a ParamMap. + ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. + + // One can also combine ParamMaps. + ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name + ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); + + // Now learn a new model using the paramMapCombined parameters. + // paramMapCombined overrides all parameters set earlier via lr.set* methods. + LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); + System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); + + // Prepare test documents. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) + ), LabeledPoint.class); + + // Make predictions on test documents using the Transformer.transform() method. + // LogisticRegression.transform will only use the 'features' column. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. + DataFrame results = model2.transform(test); + for (Row r : results.select("features", "label", "myProbability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java new file mode 100644 index 0000000000..68d1caf6ad --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLabeledDocument.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.io.Serializable; + +/** + * Labeled instance type, Spark SQL can infer schema from Java Beans. + */ +@SuppressWarnings("serial") +public class JavaLabeledDocument extends JavaDocument implements Serializable { + + private double label; + + public JavaLabeledDocument(long id, String text, double label) { + super(id, text); + this.label = label; + } + + public double getLabel() { + return this.label; + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java new file mode 100644 index 0000000000..87ad119491 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.tuning.CrossValidator; +import org.apache.spark.ml.tuning.CrossValidatorModel; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Cross Validation. + */ +public class JavaModelSelectionViaCrossValidationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaCrossValidationExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L,"spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0), + new JavaLabeledDocument(4L, "b spark who", 1.0), + new JavaLabeledDocument(5L, "g d a y", 0.0), + new JavaLabeledDocument(6L, "spark fly", 1.0), + new JavaLabeledDocument(7L, "was mapreduce", 0.0), + new JavaLabeledDocument(8L, "e spark program", 1.0), + new JavaLabeledDocument(9L, "a e c l", 0.0), + new JavaLabeledDocument(10L, "spark compile", 1.0), + new JavaLabeledDocument(11L, "hadoop software", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000}) + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .build(); + + // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + // This will allow us to jointly choose parameters for all Pipeline stages. + // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric + // is areaUnderROC. + CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice + + // Run cross-validation, and choose the best set of parameters. + CrossValidatorModel cvModel = cv.fit(training); + + // Prepare test documents, which are unlabeled. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. cvModel uses the best model found (lrModel). + DataFrame predictions = cvModel.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java new file mode 100644 index 0000000000..77adb02dfd --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.ParamGridBuilder; +import org.apache.spark.ml.tuning.TrainValidationSplit; +import org.apache.spark.ml.tuning.TrainValidationSplitModel; +import org.apache.spark.sql.DataFrame; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for Model Selection via Train Validation Split. + */ +public class JavaModelSelectionViaTrainValidationSplitExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("JavaModelSelectionViaTrainValidationSplitExample"); + SparkContext sc = new SparkContext(conf); + SQLContext jsql = new SQLContext(sc); + + // $example on$ + DataFrame data = jsql.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + // Prepare training and test data. + DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + // $example off$ + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java new file mode 100644 index 0000000000..3407c25c83 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +// $example off$ + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example on$ +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SQLContext; + +/** + * Java example for simple text document 'Pipeline'. + */ +public class JavaPipelineExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaPipelineExample"); + SparkContext sc = new SparkContext(conf); + SQLContext sqlContext = new SQLContext(sc); + + // $example on$ + // Prepare training documents, which are labeled. + DataFrame training = sqlContext.createDataFrame(Arrays.asList( + new JavaLabeledDocument(0L, "a b c d e spark", 1.0), + new JavaLabeledDocument(1L, "b d", 0.0), + new JavaLabeledDocument(2L, "spark f g h", 1.0), + new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0) + ), JavaLabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new JavaDocument(4L, "spark i j k"), + new JavaDocument(5L, "l m n"), + new JavaDocument(6L, "mapreduce spark"), + new JavaDocument(7L, "apache hadoop") + ), JavaDocument.class); + + // Make predictions on test documents. + DataFrame predictions = model.transform(test); + for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) { + System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + + ", prediction=" + r.get(3)); + } + // $example off$ + + sc.stop(); + } +} |