From 1fad5596885aab8b32d2307c0edecbae50d5bd7a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Thu, 29 Sep 2016 23:55:42 -0700 Subject: [SPARK-14077][ML] Refactor NaiveBayes to support weighted instances ## What changes were proposed in this pull request? 1,support weighted data 2,use dataset/dataframe instead of rdd 3,make mllib as a wrapper to call ml ## How was this patch tested? local manual tests in spark-shell unit tests Author: Zheng RuiFeng Closes #12819 from zhengruifeng/weighted_nb. --- .../spark/ml/classification/NaiveBayes.scala | 154 ++++++++++++++++----- .../spark/mllib/classification/NaiveBayes.scala | 99 ++++--------- .../spark/ml/classification/NaiveBayesSuite.scala | 50 ++++++- 3 files changed, 191 insertions(+), 112 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index f939a1c680..0d652aa4c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -19,23 +19,20 @@ package org.apache.spark.ml.classification import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} -import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} -import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType /** * Params for Naive Bayes Classifiers. */ -private[ml] trait NaiveBayesParams extends PredictorParams { +private[ml] trait NaiveBayesParams extends PredictorParams with HasWeightCol { /** * The smoothing parameter. @@ -56,7 +53,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { */ final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " + "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.", - ParamValidators.inArray[String](OldNaiveBayes.supportedModelTypes.toArray)) + ParamValidators.inArray[String](NaiveBayes.supportedModelTypes.toArray)) /** @group getParam */ final def getModelType: String = $(modelType) @@ -64,7 +61,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { /** * Naive Bayes Classifiers. - * It supports both Multinomial NB + * It supports Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) * which can handle finitely supported discrete data. For example, by converting documents into * TF-IDF vectors, it can be used for document classification. By making every vector a @@ -78,6 +75,8 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { + import NaiveBayes.{Bernoulli, Multinomial} + @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -98,7 +97,17 @@ class NaiveBayes @Since("1.5.0") ( */ @Since("1.5.0") def setModelType(value: String): this.type = set(modelType, value) - setDefault(modelType -> OldNaiveBayes.Multinomial) + setDefault(modelType -> NaiveBayes.Multinomial) + + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. + * + * @group setParam + */ + @Since("2.1.0") + def setWeightCol(value: String): this.type = set(weightCol, value) override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val numClasses = getNumClasses(dataset) @@ -109,10 +118,89 @@ class NaiveBayes @Since("1.5.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val oldDataset: RDD[OldLabeledPoint] = - extractLabeledPoints(dataset).map(OldLabeledPoint.fromML) - val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) - NaiveBayesModel.fromOld(oldModel, this) + val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size + + val requireNonnegativeValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } + + val requireValues: Vector => Unit = { + $(modelType) match { + case Multinomial => + requireNonnegativeValues + case Bernoulli => + requireZeroOneBernoulliValues + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + } + + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + + // Aggregates term frequencies per label. + // TODO: Calling aggregateByKey and collect creates two stages, we can implement something + // TODO: similar to reduceByKeyLocally to save one stage. + val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) + }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( + seqOp = { + case ((weightSum: Double, featureSum: DenseVector), (weight, features)) => + requireValues(features) + BLAS.axpy(weight, features, featureSum) + (weightSum + weight, featureSum) + }, + combOp = { + case ((weightSum1, featureSum1), (weightSum2, featureSum2)) => + BLAS.axpy(1.0, featureSum2, featureSum1) + (weightSum1 + weightSum2, featureSum1) + }).collect().sortBy(_._1) + + val numLabels = aggregated.length + val numDocuments = aggregated.map(_._2._1).sum + + val piArray = Array.fill[Double](numLabels)(0.0) + val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0) + + val lambda = $(smoothing) + val piLogDenom = math.log(numDocuments + numLabels * lambda) + var i = 0 + aggregated.foreach { case (label, (n, sumTermFreqs)) => + piArray(i) = math.log(n + lambda) - piLogDenom + val thetaLogDenom = $(modelType) match { + case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") + } + var j = 0 + while (j < numFeatures) { + thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + j += 1 + } + i += 1 + } + + val pi = Vectors.dense(piArray) + val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true) + new NaiveBayesModel(uid, pi, theta) } @Since("1.5.0") @@ -121,6 +209,14 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { + /** String name for multinomial model type. */ + private[spark] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[spark] val Bernoulli: String = "bernoulli" + + /* Set of modelTypes that NaiveBayes supports */ + private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) @@ -140,7 +236,7 @@ class NaiveBayesModel private[ml] ( extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { - import OldNaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes.{Bernoulli, Multinomial} /** * Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. @@ -175,10 +271,8 @@ class NaiveBayesModel private[ml] ( private def bernoulliCalculation(features: Vector) = { features.foreachActive((_, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") - } + require(value == 0.0 || value == 1.0, + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") ) val prob = thetaMinusNegTheta.get.multiply(features) BLAS.axpy(1.0, pi, prob) @@ -238,18 +332,6 @@ class NaiveBayesModel private[ml] ( @Since("1.6.0") object NaiveBayesModel extends MLReadable[NaiveBayesModel] { - /** Convert a model from the old API */ - private[ml] def fromOld( - oldModel: OldNaiveBayesModel, - parent: NaiveBayes): NaiveBayesModel = { - val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb") - val labels = Vectors.dense(oldModel.labels) - val pi = Vectors.dense(oldModel.pi) - val theta = new DenseMatrix(oldModel.labels.length, oldModel.theta(0).length, - oldModel.theta.flatten, true) - new NaiveBayesModel(uid, pi, theta) - } - @Since("1.6.0") override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader @@ -280,11 +362,9 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") - val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") - .select("pi", "theta") - .head() + val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() + val pi = data.getAs[Vector](0) + val theta = data.getAs[Matrix](1) val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 593a86f69a..32d6968a4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -27,7 +27,8 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} +import org.apache.spark.ml.classification.{NaiveBayes => NewNaiveBayes} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -311,8 +312,6 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { - import NaiveBayes.{Bernoulli, Multinomial} - @Since("1.4.0") def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) @@ -355,79 +354,33 @@ class NaiveBayes private ( */ @Since("0.9.0") def run(data: RDD[LabeledPoint]): NaiveBayesModel = { - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - if (!values.forall(_ >= 0.0)) { - throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") - } - } + val spark = SparkSession + .builder() + .sparkContext(data.context) + .getOrCreate() - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - if (!values.forall(v => v == 0.0 || v == 1.0)) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - } + import spark.implicits._ - // Aggregates term frequencies per label. - // TODO: Calling combineByKey and collect creates two stages, we can implement something - // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( - createCombiner = (v: Vector) => { - if (modelType == Bernoulli) { - requireZeroOneBernoulliValues(v) - } else { - requireNonnegativeValues(v) - } - (1L, v.copy.toDense) - }, - mergeValue = (c: (Long, DenseVector), v: Vector) => { - requireNonnegativeValues(v) - BLAS.axpy(1.0, v, c._2) - (c._1 + 1L, c._2) - }, - mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { - BLAS.axpy(1.0, c2._2, c1._2) - (c1._1 + c2._1, c1._2) - } - ).collect().sortBy(_._1) + val nb = new NewNaiveBayes() + .setModelType(modelType) + .setSmoothing(lambda) - val numLabels = aggregated.length - var numDocuments = 0L - aggregated.foreach { case (_, (n, _)) => - numDocuments += n - } - val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } - - val labels = new Array[Double](numLabels) - val pi = new Array[Double](numLabels) - val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) - - val piLogDenom = math.log(numDocuments + numLabels * lambda) - var i = 0 - aggregated.foreach { case (label, (n, sumTermFreqs)) => - labels(i) = label - pi(i) = math.log(n + lambda) - piLogDenom - val thetaLogDenom = modelType match { - case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) - case Bernoulli => math.log(n + 2.0 * lambda) - case _ => - // This should never happen. - throw new UnknownError(s"Invalid modelType: $modelType.") - } - var j = 0 - while (j < numFeatures) { - theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom - j += 1 - } - i += 1 + val labels = data.map(_.label).distinct().collect().sorted + + // Input labels for [[org.apache.spark.ml.classification.NaiveBayes]] must be + // in range [0, numClasses). + val dataset = data.map { + case LabeledPoint(label, features) => + (labels.indexOf(label).toDouble, features.asML) + }.toDF("label", "features") + + val newModel = nb.fit(dataset) + + val pi = newModel.pi.toArray + val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) + newModel.theta.foreachActive { + case (i, j, v) => + theta(i)(j) = v } new NaiveBayesModel(labels, pi, theta, modelType) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 9909932428..597428d036 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -23,13 +23,13 @@ import breeze.linalg.{DenseVector => BDV, Vector => BV} import breeze.stats.distributions.{Multinomial => BrzMultinomial} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.ml.classification.NaiveBayesSuite._ -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -152,6 +152,52 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateProbabilities(featureAndProbabilities, model, "multinomial") } + test("Naive Bayes Multinomial with weighted samples") { + val nPoints = 1000 + val piArray = Array(0.5, 0.1, 0.4).map(math.log) + val thetaArray = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() + val (overSampledData, weightedData) = + MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, + "label", "features", 42L) + val nb = new NaiveBayes().setModelType("multinomial") + val unweightedModel = nb.fit(weightedData) + val overSampledModel = nb.fit(overSampledData) + val weightedModel = nb.setWeightCol("weight").fit(weightedData) + assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) + assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) + assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) + assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + } + + test("Naive Bayes Bernoulli with weighted samples") { + val nPoints = 10000 + val piArray = Array(0.5, 0.3, 0.2).map(math.log) + val thetaArray = Array( + Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 + Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + ).map(_.map(math.log)) + + val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "bernoulli").toDF() + val (overSampledData, weightedData) = + MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, + "label", "features", 42L) + val nb = new NaiveBayes().setModelType("bernoulli") + val unweightedModel = nb.fit(weightedData) + val overSampledModel = nb.fit(overSampledData) + val weightedModel = nb.setWeightCol("weight").fit(weightedData) + assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) + assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) + assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) + assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + } + test("Naive Bayes Bernoulli") { val nPoints = 10000 val piArray = Array(0.5, 0.3, 0.2).map(math.log) -- cgit v1.2.3