aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-07-17 13:55:17 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-17 13:55:17 -0700
commit9974642870404381fa425fadb966c6dd3ac4a94f (patch)
tree560b3658f6ce0276215ab2e2487b3231b75c4c7f
parent806c579f43ce66ac1398200cbc773fa3b69b5cb6 (diff)
downloadspark-9974642870404381fa425fadb966c6dd3ac4a94f.tar.gz
spark-9974642870404381fa425fadb966c6dd3ac4a94f.tar.bz2
spark-9974642870404381fa425fadb966c6dd3ac4a94f.zip
[SPARK-8600] [ML] Naive Bayes API for spark.ml Pipelines
Naive Bayes API for spark.ml Pipelines Author: Yanbo Liang <ybliang8@gmail.com> Closes #7284 from yanboliang/spark-8600 and squashes the following commits: bc890f7 [Yanbo Liang] remove labels valid check c3de687 [Yanbo Liang] remove labels from ml.NaiveBayesModel a2b3088 [Yanbo Liang] address comments 3220b82 [Yanbo Liang] trigger jenkins 3018a41 [Yanbo Liang] address comments 208e166 [Yanbo Liang] Naive Bayes API for spark.ml Pipelines
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala178
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala6
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java98
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala116
5 files changed, 400 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
new file mode 100644
index 0000000000..1f547e4a98
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Params for Naive Bayes Classifiers.
+ */
+private[ml] trait NaiveBayesParams extends PredictorParams {
+
+ /**
+ * The smoothing parameter.
+ * (default = 1.0).
+ * @group param
+ */
+ final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
+ ParamValidators.gtEq(0))
+
+ /** @group getParam */
+ final def getLambda: Double = $(lambda)
+
+ /**
+ * The model type which is a string (case-sensitive).
+ * Supported options: "multinomial" and "bernoulli".
+ * (default = multinomial)
+ * @group param
+ */
+ final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
+ "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray))
+
+ /** @group getParam */
+ final def getModelType: String = $(modelType)
+}
+
+/**
+ * Naive Bayes Classifiers.
+ * It supports both Multinomial NB
+ * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]])
+ * which can handle finitely supported discrete data. For example, by converting documents into
+ * TF-IDF vectors, it can be used for document classification. By making every vector a
+ * binary (0/1) data, it can also be used as Bernoulli NB
+ * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
+ * The input feature values must be nonnegative.
+ */
+class NaiveBayes(override val uid: String)
+ extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
+ with NaiveBayesParams {
+
+ def this() = this(Identifiable.randomUID("nb"))
+
+ /**
+ * Set the smoothing parameter.
+ * Default is 1.0.
+ * @group setParam
+ */
+ def setLambda(value: Double): this.type = set(lambda, value)
+ setDefault(lambda -> 1.0)
+
+ /**
+ * Set the model type using a string (case-sensitive).
+ * Supported options: "multinomial" and "bernoulli".
+ * Default is "multinomial"
+ */
+ def setModelType(value: String): this.type = set(modelType, value)
+ setDefault(modelType -> OldNaiveBayes.Multinomial)
+
+ override protected def train(dataset: DataFrame): NaiveBayesModel = {
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
+ NaiveBayesModel.fromOld(oldModel, this)
+ }
+
+ override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
+}
+
+/**
+ * Model produced by [[NaiveBayes]]
+ */
+class NaiveBayesModel private[ml] (
+ override val uid: String,
+ val pi: Vector,
+ val theta: Matrix)
+ extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+
+ import OldNaiveBayes.{Bernoulli, Multinomial}
+
+ /**
+ * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
+ * This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
+ * application of this condition (in predict function).
+ */
+ private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
+ case Multinomial => (None, None)
+ case Bernoulli =>
+ val negTheta = theta.map(value => math.log(1.0 - math.exp(value)))
+ val ones = new DenseVector(Array.fill(theta.numCols){1.0})
+ val thetaMinusNegTheta = theta.map { value =>
+ value - math.log(1.0 - math.exp(value))
+ }
+ (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ }
+
+ override protected def predict(features: Vector): Double = {
+ $(modelType) match {
+ case Multinomial =>
+ val prob = theta.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ prob.argmax
+ case Bernoulli =>
+ features.foreachActive{ (index, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
+ }
+ }
+ val prob = thetaMinusNegTheta.get.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ prob.argmax
+ case _ =>
+ // This should never happen.
+ throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
+ }
+ }
+
+ override def copy(extra: ParamMap): NaiveBayesModel = {
+ copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
+ }
+
+ override def toString: String = {
+ s"NaiveBayesModel with ${pi.size} classes"
+ }
+
+}
+
+private[ml] object NaiveBayesModel {
+
+ /** Convert a model from the old API */
+ def fromOld(
+ oldModel: OldNaiveBayesModel,
+ parent: NaiveBayes): NaiveBayesModel = {
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
+ val labels = Vectors.dense(oldModel.labels)
+ val pi = Vectors.dense(oldModel.pi)
+ val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length,
+ oldModel.theta.flatten, true)
+ new NaiveBayesModel(uid, pi, theta)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 9e379d7d74..8cf4e15efe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
* where D is number of features
* @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
*/
-class NaiveBayesModel private[mllib] (
+class NaiveBayesModel private[spark] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
@@ -382,7 +382,7 @@ class NaiveBayes private (
BLAS.axpy(1.0, c2._2, c1._2)
(c1._1 + c2._1, c1._2)
}
- ).collect()
+ ).collect().sortBy(_._1)
val numLabels = aggregated.length
var numDocuments = 0L
@@ -425,13 +425,13 @@ class NaiveBayes private (
object NaiveBayes {
/** String name for multinomial model type. */
- private[classification] val Multinomial: String = "multinomial"
+ private[spark] val Multinomial: String = "multinomial"
/** String name for Bernoulli model type. */
- private[classification] val Bernoulli: String = "bernoulli"
+ private[spark] val Bernoulli: String = "bernoulli"
/* Set of modelTypes that NaiveBayes supports */
- private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+ private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0df0766340..55da0e094d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -98,7 +98,7 @@ sealed trait Matrix extends Serializable {
/** Map the values of this matrix using a function. Generates a new matrix. Performs the
* function on only the backing array. For example, an operation such as addition or
* subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
- private[mllib] def map(f: Double => Double): Matrix
+ private[spark] def map(f: Double => Double): Matrix
/** Update all the values of this matrix using the function f. Performed in-place on the
* backing array. For example, an operation such as addition or subtraction will only be
@@ -289,7 +289,7 @@ class DenseMatrix(
override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
- private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
+ private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f),
isTransposed)
private[mllib] def update(f: Double => Double): DenseMatrix = {
@@ -555,7 +555,7 @@ class SparseMatrix(
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
}
- private[mllib] def map(f: Double => Double) =
+ private[spark] def map(f: Double => Double) =
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed)
private[mllib] def update(f: Double => Double): SparseMatrix = {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..09a9fba0c1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaNaiveBayesSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ public void validatePrediction(DataFrame predictionAndLabels) {
+ for (Row r : predictionAndLabels.collect()) {
+ double prediction = r.getAs(0);
+ double label = r.getAs(1);
+ assert(prediction == label);
+ }
+ }
+
+ @Test
+ public void naiveBayesDefaultParams() {
+ NaiveBayes nb = new NaiveBayes();
+ assert(nb.getLabelCol() == "label");
+ assert(nb.getFeaturesCol() == "features");
+ assert(nb.getPredictionCol() == "prediction");
+ assert(nb.getLambda() == 1.0);
+ assert(nb.getModelType() == "multinomial");
+ }
+
+ @Test
+ public void testNaiveBayes() {
+ JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
+ RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
+ RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
+ RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
+ RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
+ ));
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
+ });
+
+ DataFrame dataset = jsql.createDataFrame(jrdd, schema);
+ NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
+ NaiveBayesModel model = nb.fit(dataset);
+
+ DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+ validatePrediction(predictionAndLabels);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
new file mode 100644
index 0000000000..76381a2741
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ def validatePrediction(predictionAndLabels: DataFrame): Unit = {
+ val numOfErrorPredictions = predictionAndLabels.collect().count {
+ case Row(prediction: Double, label: Double) =>
+ prediction != label
+ }
+ // At least 80% of the predictions should be on.
+ assert(numOfErrorPredictions < predictionAndLabels.count() / 5)
+ }
+
+ def validateModelFit(
+ piData: Vector,
+ thetaData: Matrix,
+ model: NaiveBayesModel): Unit = {
+ assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
+ Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
+ assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new NaiveBayes)
+ val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
+ theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
+ ParamsSuite.checkParams(model)
+ }
+
+ test("naive bayes: default params") {
+ val nb = new NaiveBayes
+ assert(nb.getLabelCol === "label")
+ assert(nb.getFeaturesCol === "features")
+ assert(nb.getPredictionCol === "prediction")
+ assert(nb.getLambda === 1.0)
+ assert(nb.getModelType === "multinomial")
+ }
+
+ test("Naive Bayes Multinomial") {
+ val nPoints = 1000
+ val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+ val thetaArray = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0
+ Array(0.10, 0.70, 0.10, 0.10), // label 1
+ Array(0.10, 0.10, 0.70, 0.10) // label 2
+ ).map(_.map(math.log))
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+
+ val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 42, "multinomial"))
+ val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
+ val model = nb.fit(testDataset)
+
+ validateModelFit(pi, theta, model)
+ assert(model.hasParent)
+
+ val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 17, "multinomial"))
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+
+ validatePrediction(predictionAndLabels)
+ }
+
+ test("Naive Bayes Bernoulli") {
+ val nPoints = 10000
+ val piArray = Array(0.5, 0.3, 0.2).map(math.log)
+ val thetaArray = Array(
+ Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
+ Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1
+ Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2
+ ).map(_.map(math.log))
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 12, thetaArray.flatten, true)
+
+ val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 45, "bernoulli"))
+ val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
+ val model = nb.fit(testDataset)
+
+ validateModelFit(pi, theta, model)
+ assert(model.hasParent)
+
+ val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
+ piArray, thetaArray, nPoints, 20, "bernoulli"))
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+
+ validatePrediction(predictionAndLabels)
+ }
+}