aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-04-13 13:58:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-13 13:58:35 -0700
commit781df499836e4216939e0febdcd5f89d30645759 (patch)
treef27ffcec85b7e7661eeefd6b02b2e8aa8df01cf5
parentfcdd69260ec75c180f4d727ff2625ca9bf0bdad7 (diff)
downloadspark-781df499836e4216939e0febdcd5f89d30645759.tar.gz
spark-781df499836e4216939e0febdcd5f89d30645759.tar.bz2
spark-781df499836e4216939e0febdcd5f89d30645759.zip
[SPARK-13089][ML] [Doc] spark.ml Naive Bayes user guide and examples
jira: https://issues.apache.org/jira/browse/SPARK-13089 Add section in ml-classification.md for NaiveBayes DataFrame-based API, plus example code (using include_example to clip code from examples/ folder files). Author: Yuhao Yang <hhbyyh@gmail.com> Closes #11015 from hhbyyh/naiveBayesDoc.
-rw-r--r--docs/ml-classification-regression.md34
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java64
-rw-r--r--examples/src/main/python/ml/naive_bayes_example.py53
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala58
4 files changed, 209 insertions, 0 deletions
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index 45155c8ad1..eaf4f6d843 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -302,6 +302,40 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe
</div>
</div>
+## Naive Bayes
+
+[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple
+probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence
+assumptions between the features. The spark.ml implementation currently supports both [multinomial
+naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
+and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
+More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib).
+
+**Example**
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details.
+
+{% include_example python/ml/naive_bayes_example.py %}
+</div>
+</div>
+
# Regression
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
new file mode 100644
index 0000000000..41d7ad75b9
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+// $example on$
+import org.apache.spark.ml.classification.NaiveBayes;
+import org.apache.spark.ml.classification.NaiveBayesModel;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+// $example off$
+
+/**
+ * An example for Naive Bayes Classification.
+ */
+public class JavaNaiveBayesExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // $example on$
+ // Load training data
+ Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ // Split the data into train and test
+ Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
+ Dataset<Row> train = splits[0];
+ Dataset<Row> test = splits[1];
+
+ // create the trainer and set its parameters
+ NaiveBayes nb = new NaiveBayes();
+ // train the model
+ NaiveBayesModel model = nb.fit(train);
+ // compute precision on the test set
+ Dataset<Row> result = model.transform(test);
+ Dataset<Row> predictionAndLabels = result.select("prediction", "label");
+ MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("precision");
+ System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
+ // $example off$
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py
new file mode 100644
index 0000000000..db8fbea9bf
--- /dev/null
+++ b/examples/src/main/python/ml/naive_bayes_example.py
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+# $example on$
+from pyspark.ml.classification import NaiveBayes
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+# $example off$
+
+if __name__ == "__main__":
+
+ sc = SparkContext(appName="naive_bayes_example")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ # Load training data
+ data = sqlContext.read.format("libsvm") \
+ .load("data/mllib/sample_libsvm_data.txt")
+ # Split the data into train and test
+ splits = data.randomSplit([0.6, 0.4], 1234)
+ train = splits[0]
+ test = splits[1]
+
+ # create the trainer and set its parameters
+ nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
+
+ # train the model
+ model = nb.fit(train)
+ # compute precision on the test set
+ result = model.transform(test)
+ predictionAndLabels = result.select("prediction", "label")
+ evaluator = MulticlassClassificationEvaluator(metricName="precision")
+ print("Precision:" + str(evaluator.evaluate(predictionAndLabels)))
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
new file mode 100644
index 0000000000..5ea1270c97
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.classification.{NaiveBayes}
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
+// $example off$
+import org.apache.spark.sql.SQLContext
+
+object NaiveBayesExample {
+ def main(args: Array[String]): Unit = {
+ val conf = new SparkConf().setAppName("NaiveBayesExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ // $example on$
+ // Load the data stored in LIBSVM format as a DataFrame.
+ val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+ // Split the data into training and test sets (30% held out for testing)
+ val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
+
+ // Train a NaiveBayes model.
+ val model = new NaiveBayes()
+ .fit(trainingData)
+
+ // Select example rows to display.
+ val predictions = model.transform(testData)
+ predictions.show()
+
+ // Select (prediction, true label) and compute test error
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setLabelCol("label")
+ .setPredictionCol("prediction")
+ .setMetricName("precision")
+ val precision = evaluator.evaluate(predictions)
+ println("Precision:" + precision)
+ // $example off$
+ }
+}
+// scalastyle:on println