aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2016-11-12 06:13:22 -0800
committer Yanbo Liang <ybliang8@gmail.com> 2016-11-12 06:13:22 -0800
commit 22cb3a060a440205281b71686637679645454ca6 (patch)
tree 008f83b11aa6452f384264aad5324e10d72a1f48 /mllib
parent bc41d997ea287080f549219722b6d9049adef4e2 (diff)
download spark-22cb3a060a440205281b71686637679645454ca6.tar.gz
spark-22cb3a060a440205281b71686637679645454ca6.tar.bz2
spark-22cb3a060a440205281b71686637679645454ca6.zip
[SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes
## What changes were proposed in this pull request?

* Refactor out `trainWithLabelCheck` and make `mllib.NaiveBayes` call into it.
* Avoid capturing the outer object for `modelType`.
* Move `requireNonnegativeValues` and `requireZeroOneBernoulliValues` to the companion object.

## How was this patch tested?

Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15826 from yanboliang/spark-14077-2.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala 72
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala 6
2 files changed, 39 insertions, 39 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index b03a07a6bc..f1a7676c74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") (
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams with DefaultParamsWritable {
- import NaiveBayes.{Bernoulli, Multinomial}
+ import NaiveBayes._
@Since("1.5.0")
def this() = this(Identifiable.randomUID("nb"))
@@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") (
@Since("2.1.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
+ override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+ trainWithLabelCheck(dataset, positiveLabel = true)
+ }
+
/**
* ml assumes input labels in range [0, numClasses). But this implementation
* is also called by mllib NaiveBayes which allows other kinds of input labels
- * such as {-1, +1}. Here we use this parameter to switch between different processing logic.
- * It should be removed when we remove mllib NaiveBayes.
+ * such as {-1, +1}. `positiveLabel` is used to determine whether the label
+ * should be checked and it should be removed when we remove mllib NaiveBayes.
*/
- private[spark] var isML: Boolean = true
-
- private[spark] def setIsML(isML: Boolean): this.type = {
- this.isML = isML
- this
- }
-
- override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
- if (isML) {
+ private[spark] def trainWithLabelCheck(
+ dataset: Dataset[_],
+ positiveLabel: Boolean): NaiveBayesModel = {
+ if (positiveLabel) {
val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") (
}
}
- val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
- val values = v match {
- case sv: SparseVector => sv.values
- case dv: DenseVector => dv.values
- }
-
- require(values.forall(_ >= 0.0),
- s"Naive Bayes requires nonnegative feature values but found $v.")
- }
-
- val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
- val values = v match {
- case sv: SparseVector => sv.values
- case dv: DenseVector => dv.values
- }
-
- require(values.forall(v => v == 0.0 || v == 1.0),
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
- }
-
+ val modelTypeValue = $(modelType)
val requireValues: Vector => Unit = {
- $(modelType) match {
+ modelTypeValue match {
case Multinomial =>
requireNonnegativeValues
case Bernoulli =>
@@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") (
@Since("1.6.0")
object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/** String name for multinomial model type. */
- private[spark] val Multinomial: String = "multinomial"
+ private[classification] val Multinomial: String = "multinomial"
/** String name for Bernoulli model type. */
- private[spark] val Bernoulli: String = "bernoulli"
+ private[classification] val Bernoulli: String = "bernoulli"
/* Set of modelTypes that NaiveBayes supports */
- private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
+ private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+
+ private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
+ val values = v match {
+ case sv: SparseVector => sv.values
+ case dv: DenseVector => dv.values
+ }
+
+ require(values.forall(_ >= 0.0),
+ s"Naive Bayes requires nonnegative feature values but found $v.")
+ }
+
+ private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = {
+ val values = v match {
+ case sv: SparseVector => sv.values
+ case dv: DenseVector => dv.values
+ }
+
+ require(values.forall(v => v == 0.0 || v == 1.0),
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
+ }
@Since("1.6.0")
override def load(path: String): NaiveBayes = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 33561be4b5..767d056861 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -364,12 +364,12 @@ class NaiveBayes private (
val nb = new NewNaiveBayes()
.setModelType(modelType)
.setSmoothing(lambda)
- .setIsML(false)
val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
.toDF("label", "features")
- val newModel = nb.fit(dataset)
+ // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false.
+ val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false)
val pi = newModel.pi.toArray
val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0)
@@ -378,7 +378,7 @@ class NaiveBayes private (
theta(i)(j) = v
}
- require(newModel.oldLabels != null,
+ assert(newModel.oldLabels != null,
"The underlying ML NaiveBayes training does not produce labels.")
new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
}