about summary refs log tree commit diff
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-07-30 23:03:48 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-07-30 23:03:48 -0700
commit 69b62f76fced18efa35a107c9be4bc22eba72878 (patch)
tree 9cef7ff52d64a096694765badf01e4ea7352d881
parent 4e5919bfb47a58bcbda90ae01c1bed2128ded983 (diff)
download spark-69b62f76fced18efa35a107c9be4bc22eba72878.tar.gz
spark-69b62f76fced18efa35a107c9be4bc22eba72878.tar.bz2
spark-69b62f76fced18efa35a107c9be4bc22eba72878.zip
[SPARK-9214] [ML] [PySpark] support ml.NaiveBayes for Python
support ml.NaiveBayes for Python

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7568 from yanboliang/spark-9214 and squashes the following commits:

5ee3fd6 [Yanbo Liang] fix typos
3ecd046 [Yanbo Liang] fix typos
f9c94d1 [Yanbo Liang] change lambda_ to smoothing and fix other issues
180452a [Yanbo Liang] fix typos
7dda1f4 [Yanbo Liang] support ml.NaiveBayes for Python
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala | 10
-rw-r--r-- mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 6
-rw-r--r-- python/pyspark/ml/classification.py | 116
4 files changed, 125 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 1f547e4a98..5be35fe209 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* (default = 1.0).
* @group param
*/
- final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
+ final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.",
ParamValidators.gtEq(0))
/** @group getParam */
- final def getLambda: Double = $(lambda)
+ final def getSmoothing: Double = $(smoothing)
/**
* The model type which is a string (case-sensitive).
@@ -79,8 +79,8 @@ class NaiveBayes(override val uid: String)
* Default is 1.0.
* @group setParam
*/
- def setLambda(value: Double): this.type = set(lambda, value)
- setDefault(lambda -> 1.0)
+ def setSmoothing(value: Double): this.type = set(smoothing, value)
+ setDefault(smoothing -> 1.0)
/**
* Set the model type using a string (case-sensitive).
@@ -92,7 +92,7 @@ class NaiveBayes(override val uid: String)
override protected def train(dataset: DataFrame): NaiveBayesModel = {
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
- val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
+ val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 09a9fba0c1..a700c9cddb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -68,7 +68,7 @@ public class JavaNaiveBayesSuite implements Serializable {
assert(nb.getLabelCol() == "label");
assert(nb.getFeaturesCol() == "features");
assert(nb.getPredictionCol() == "prediction");
- assert(nb.getLambda() == 1.0);
+ assert(nb.getSmoothing() == 1.0);
assert(nb.getModelType() == "multinomial");
}
@@ -89,7 +89,7 @@ public class JavaNaiveBayesSuite implements Serializable {
});
DataFrame dataset = jsql.createDataFrame(jrdd, schema);
- NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
+ NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 76381a2741..264bde3703 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -58,7 +58,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(nb.getLabelCol === "label")
assert(nb.getFeaturesCol === "features")
assert(nb.getPredictionCol === "prediction")
- assert(nb.getLambda === 1.0)
+ assert(nb.getSmoothing === 1.0)
assert(nb.getModelType === "multinomial")
}
@@ -75,7 +75,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 42, "multinomial"))
- val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
+ val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset)
validateModelFit(pi, theta, model)
@@ -101,7 +101,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 45, "bernoulli"))
- val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
+ val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset)
validateModelFit(pi, theta, model)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 5a82bc286d..93ffcd4094 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -25,7 +25,8 @@ from pyspark.mllib.common import inherit_doc
__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier',
'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel',
- 'RandomForestClassifier', 'RandomForestClassificationModel']
+ 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes',
+ 'NaiveBayesModel']
@inherit_doc
@@ -576,6 +577,119 @@ class GBTClassificationModel(TreeEnsembleModels):
"""
+@inherit_doc
+class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+ """
+ Naive Bayes Classifiers.
+
+ >>> from pyspark.sql import Row
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> df = sqlContext.createDataFrame([
+ ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
+ ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
+ ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
+ >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
+ >>> model = nb.fit(df)
+ >>> model.pi
+ DenseVector([-0.51..., -0.91...])
+ >>> model.theta
+ DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1)
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
+ >>> model.transform(test0).head().prediction
+ 1.0
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
+ >>> model.transform(test1).head().prediction
+ 1.0
+ """
+
+ # a placeholder to make it appear in the generated doc
+ smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
+ "default is 1.0")
+ modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
+ "(case-sensitive). Supported options: multinomial (default) and bernoulli.")
+
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ smoothing=1.0, modelType="multinomial"):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ smoothing=1.0, modelType="multinomial")
+ """
+ super(NaiveBayes, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.NaiveBayes", self.uid)
+ #: param for the smoothing parameter.
+ self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " +
+ "default is 1.0")
+ #: param for the model type.
+ self.modelType = Param(self, "modelType", "The model type which is a string " +
+ "(case-sensitive). Supported options: multinomial (default) " +
+ "and bernoulli.")
+ self._setDefault(smoothing=1.0, modelType="multinomial")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ smoothing=1.0, modelType="multinomial"):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ smoothing=1.0, modelType="multinomial")
+ Sets params for Naive Bayes.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return NaiveBayesModel(java_model)
+
+ def setSmoothing(self, value):
+ """
+ Sets the value of :py:attr:`smoothing`.
+ """
+ self._paramMap[self.smoothing] = value
+ return self
+
+ def getSmoothing(self):
+ """
+ Gets the value of smoothing or its default value.
+ """
+ return self.getOrDefault(self.smoothing)
+
+ def setModelType(self, value):
+ """
+ Sets the value of :py:attr:`modelType`.
+ """
+ self._paramMap[self.modelType] = value
+ return self
+
+ def getModelType(self):
+ """
+ Gets the value of modelType or its default value.
+ """
+ return self.getOrDefault(self.modelType)
+
+
+class NaiveBayesModel(JavaModel):
+ """
+ Model fitted by NaiveBayes.
+ """
+
+ @property
+ def pi(self):
+ """
+ log of class priors.
+ """
+ return self._call_java("pi")
+
+ @property
+ def theta(self):
+ """
+ log of class conditional probabilities.
+ """
+ return self._call_java("theta")
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext