aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-07-17 13:55:17 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-17 13:55:17 -0700
commit9974642870404381fa425fadb966c6dd3ac4a94f (patch)
tree560b3658f6ce0276215ab2e2487b3231b75c4c7f /mllib/src/test/java
parent806c579f43ce66ac1398200cbc773fa3b69b5cb6 (diff)
downloadspark-9974642870404381fa425fadb966c6dd3ac4a94f.tar.gz
spark-9974642870404381fa425fadb966c6dd3ac4a94f.tar.bz2
spark-9974642870404381fa425fadb966c6dd3ac4a94f.zip
[SPARK-8600] [ML] Naive Bayes API for spark.ml Pipelines
Naive Bayes API for spark.ml Pipelines Author: Yanbo Liang <ybliang8@gmail.com> Closes #7284 from yanboliang/spark-8600 and squashes the following commits: bc890f7 [Yanbo Liang] remove labels valid check c3de687 [Yanbo Liang] remove labels from ml.NaiveBayesModel a2b3088 [Yanbo Liang] address comments 3220b82 [Yanbo Liang] trigger jenkins 3018a41 [Yanbo Liang] address comments 208e166 [Yanbo Liang] Naive Bayes API for spark.ml Pipelines
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java98
1 files changed, 98 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..09a9fba0c1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Java-API tests for the spark.ml {@link NaiveBayes} estimator: checks the
+ * default parameter values and fits a model on a small, linearly separable
+ * three-class dataset.
+ *
+ * All checks go through JUnit's {@code Assert} rather than the Java
+ * {@code assert} keyword, because {@code assert} statements are skipped
+ * unless the JVM is started with assertions enabled.
+ */
+public class JavaNaiveBayesSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  /** Starts a Spark context with master "local" and an SQL context on top of it. */
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  /** Stops the Spark context and drops both context references. */
+  @After
+  public void tearDown() {
+    // setUp may have failed before assigning jsc; avoid masking that failure with an NPE.
+    if (jsc != null) {
+      jsc.stop();
+    }
+    jsc = null;
+    jsql = null;
+  }
+
+  /**
+   * Asserts that every row's prediction equals its label.
+   *
+   * @param predictionAndLabels rows whose column 0 is the predicted class and
+   *                            column 1 is the true label, both read as doubles
+   */
+  public void validatePrediction(DataFrame predictionAndLabels) {
+    for (Row r : predictionAndLabels.collect()) {
+      double prediction = r.getDouble(0);
+      double label = r.getDouble(1);
+      // Class indices are exact values, so a zero tolerance is intended.
+      Assert.assertEquals(label, prediction, 0.0);
+    }
+  }
+
+  /** Verifies the parameter values of a freshly constructed {@link NaiveBayes}. */
+  @Test
+  public void naiveBayesDefaultParams() {
+    NaiveBayes nb = new NaiveBayes();
+    // assertEquals compares strings by value; == would compare references.
+    Assert.assertEquals("label", nb.getLabelCol());
+    Assert.assertEquals("features", nb.getFeaturesCol());
+    Assert.assertEquals("prediction", nb.getPredictionCol());
+    Assert.assertEquals(1.0, nb.getLambda(), 0.0);
+    Assert.assertEquals("multinomial", nb.getModelType());
+  }
+
+  /**
+   * Fits a multinomial model on six rows (two per class, each class using a
+   * distinct feature dimension) and checks that the model reproduces the
+   * training labels.
+   */
+  @Test
+  public void testNaiveBayes() {
+    JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+      RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+      RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+      RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
+      RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
+      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
+      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
+    ));
+
+    StructType schema = new StructType(new StructField[]{
+      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+      new StructField("features", new VectorUDT(), false, Metadata.empty())
+    });
+
+    DataFrame dataset = jsql.createDataFrame(jrdd, schema);
+    NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
+    NaiveBayesModel model = nb.fit(dataset);
+
+    DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+    validatePrediction(predictionAndLabels);
+  }
+}