Diffstat (limited to 'examples/src')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java   95
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala     88
2 files changed, 183 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
new file mode 100644
index 0000000000..22ba68d8c3
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.HashingTF;
+import org.apache.spark.ml.feature.Tokenizer;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.SparkConf;
+
+/**
+ * A simple text classification pipeline that recognizes the word "spark" in input text. It uses
+ * the Java bean classes {@link LabeledDocument} and {@link Document}, defined in the Scala
+ * counterpart of this example, {@link SimpleTextClassificationPipeline}. Run with
+ * <pre>
+ * bin/run-example ml.JavaSimpleTextClassificationPipeline
+ * </pre>
+ */
+public class JavaSimpleTextClassificationPipeline {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ JavaSQLContext jsql = new JavaSQLContext(jsc);
+
+ // Prepare training documents, which are labeled.
+ List<LabeledDocument> localTraining = Lists.newArrayList(
+ new LabeledDocument(0L, "a b c d e spark", 1.0),
+ new LabeledDocument(1L, "b d", 0.0),
+ new LabeledDocument(2L, "spark f g h", 1.0),
+ new LabeledDocument(3L, "hadoop mapreduce", 0.0));
+ JavaSchemaRDD training =
+ jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ Tokenizer tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words");
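+    // HashingTF hashes each word into a fixed-length term-frequency vector (the "hashing trick").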
+ HashingTF hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol())
+ .setOutputCol("features");
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01);
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
+
+ // Fit the pipeline to training documents.
+ PipelineModel model = pipeline.fit(training);
+
+ // Prepare test documents, which are unlabeled.
+ List<Document> localTest = Lists.newArrayList(
+ new Document(4L, "spark i j k"),
+ new Document(5L, "l m n"),
+ new Document(6L, "mapreduce spark"),
+ new Document(7L, "apache hadoop"));
+ JavaSchemaRDD test =
+ jsql.applySchema(jsc.parallelize(localTest), Document.class);
+
+ // Make predictions on test documents.
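+    // The fitted model appends "score" (probability of the positive class) and "prediction" columns.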
+ model.transform(test).registerAsTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+    for (Row r : predictions.collect()) {
+ System.out.println(r);
+ }
+ }
+}
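Note: the LabeledDocument and Document bean classes used by the Java example above are the
@BeanInfo-annotated case classes defined in the Scala file below; the @BeanInfo annotation
exposes the case-class fields to applySchema's JavaBean-style reflection.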
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
new file mode 100644
index 0000000000..ee7897d906
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.sql.SQLContext
+
+@BeanInfo
+case class LabeledDocument(id: Long, text: String, label: Double)
+
+@BeanInfo
+case class Document(id: Long, text: String)
+
+/**
+ * A simple text classification pipeline that recognizes the word "spark" in input text. It shows
+ * how to create and configure an ML pipeline. Run with
+ * {{{
+ * bin/run-example ml.SimpleTextClassificationPipeline
+ * }}}
+ */
+object SimpleTextClassificationPipeline {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+    import sqlContext._  // for the implicit conversion from RDDs of case classes to SchemaRDDs
+
+ // Prepare training documents, which are labeled.
+    val training = sc.parallelize(Seq(
+ LabeledDocument(0L, "a b c d e spark", 1.0),
+ LabeledDocument(1L, "b d", 0.0),
+ LabeledDocument(2L, "spark f g h", 1.0),
+ LabeledDocument(3L, "hadoop mapreduce", 0.0)))
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ val tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words")
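+    // HashingTF hashes each word into a fixed-length term-frequency vector (the "hashing trick").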
+ val hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol)
+ .setOutputCol("features")
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01)
+ val pipeline = new Pipeline()
+ .setStages(Array(tokenizer, hashingTF, lr))
+
+ // Fit the pipeline to training documents.
+ val model = pipeline.fit(training)
+
+ // Prepare test documents, which are unlabeled.
+    val test = sc.parallelize(Seq(
+ Document(4L, "spark i j k"),
+ Document(5L, "l m n"),
+ Document(6L, "mapreduce spark"),
+ Document(7L, "apache hadoop")))
+
+ // Make predictions on test documents.
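+    // The fitted model appends "score" (probability of the positive class) and "prediction" columns.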
+ model.transform(test)
+ .select('id, 'text, 'score, 'prediction)
+ .collect()
+ .foreach(println)
+ }
+}
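For readers new to the spark.ml Pipeline abstraction, the fit call above is conceptually
equivalent to running the stages by hand: each Transformer appends a column to the dataset,
and the final Estimator is trained on the result. A minimal sketch of that manual wiring,
reusing the tokenizer, hashingTF, and lr values from the Scala example above (this
illustrates the concept; it is not the library's internal implementation):

    // Transformers rewrite the dataset; the Estimator trains on the final form.
    // (Relies on the same implicit RDD-to-SchemaRDD conversion from `import sqlContext._`.)
    val words = tokenizer.transform(training)    // adds the "words" column
    val features = hashingTF.transform(words)    // adds the "features" column
    val lrModel = lr.fit(features)               // trains on "features" and "label"

    // Scoring replays the same transformations before applying the model:
    val scored = lrModel.transform(hashingTF.transform(tokenizer.transform(test)))

A fitted PipelineModel stores exactly this sequence of stages, which is why
model.transform(test) can replay them in order on unlabeled data.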