 mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala         | 199
 mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala          |   6
 mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala                 |  12
 mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala      |  75
 mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala    | 102
 mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala |  27
 project/MimaExcludes.scala                                                               |  10
 7 files changed, 303 insertions(+), 128 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a460262b87..bd96e8d000 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -29,12 +29,12 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.storage.StorageLevel
/**
@@ -42,7 +42,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasStandardization with HasThreshold {
+ with HasStandardization with HasWeightCol with HasThreshold {
/**
* Set threshold in binary classification, in range [0, 1].
@@ -147,6 +147,17 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
/**
+ * Case class representing a single weighted data point, with its label, weight, and features.
+ *
+ * TODO: Refactor this class to a proper place.
+ *
+ * @param label Label for this data point.
+ * @param weight The weight of this instance.
+ * @param features The vector of features for this data point.
+ */
+private[classification] case class Instance(label: Double, weight: Double, features: Vector)
+
+/**
* :: Experimental ::
* Logistic regression.
* Currently, this class only supports binary classification. It will support multiclass
@@ -218,31 +229,42 @@ class LogisticRegression(override val uid: String)
override def getThreshold: Double = super.getThreshold
+ /**
+   * Whether to over-/under-sample training instances according to the given weights in
+   * weightCol. If not set or empty (the default), all instances are treated equally with
+   * weight 1.0.
+ * @group setParam
+ */
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+ setDefault(weightCol -> "")
+
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
override def getThresholds: Array[Double] = super.getThresholds
override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
- val instances = extractLabeledPoints(dataset).map {
- case LabeledPoint(label: Double, features: Vector) => (label, features)
+ val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
}
+
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
- val (summarizer, labelSummarizer) = instances.treeAggregate(
- (new MultivariateOnlineSummarizer, new MultiClassSummarizer))(
- seqOp = (c, v) => (c, v) match {
- case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer),
- (label: Double, features: Vector)) =>
- (summarizer.add(features), labelSummarizer.add(label))
- },
- combOp = (c1, c2) => (c1, c2) match {
- case ((summarizer1: MultivariateOnlineSummarizer,
- classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer,
- classSummarizer2: MultiClassSummarizer)) =>
- (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2))
- })
+ val (summarizer, labelSummarizer) = {
+ val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+ instance: Instance) =>
+ (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
+
+ val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+ c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
+ (c1._1.merge(c2._1), c1._2.merge(c2._2))
+
+      instances.treeAggregate(
+        (new MultivariateOnlineSummarizer, new MultiClassSummarizer))(seqOp, combOp)
+ }
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
@@ -295,7 +317,7 @@ class LogisticRegression(override val uid: String)
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
}
- val initialWeightsWithIntercept =
+ val initialCoefficientsWithIntercept =
Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
if ($(fitIntercept)) {
@@ -312,14 +334,14 @@ class LogisticRegression(override val uid: String)
b = \log{P(1) / P(0)} = \log{count_1 / count_0}
}}}
*/
- initialWeightsWithIntercept.toArray(numFeatures)
- = math.log(histogram(1).toDouble / histogram(0).toDouble)
+ initialCoefficientsWithIntercept.toArray(numFeatures)
+ = math.log(histogram(1) / histogram(0))
}
val states = optimizer.iterations(new CachedDiffFunction(costFun),
- initialWeightsWithIntercept.toBreeze.toDenseVector)
+ initialCoefficientsWithIntercept.toBreeze.toDenseVector)
- val (weights, intercept, objectiveHistory) = {
+ val (coefficients, intercept, objectiveHistory) = {
/*
Note that in Logistic Regression, the objective history (loss + regularization)
is log-likelihood which is invariance under feature standardization. As a result,
@@ -339,28 +361,29 @@ class LogisticRegression(override val uid: String)
}
/*
- The weights are trained in the scaled space; we're converting them back to
+ The coefficients are trained in the scaled space; we're converting them back to
the original space.
Note that the intercept in scaled space and original space is the same;
as a result, no scaling is needed.
*/
- val rawWeights = state.x.toArray.clone()
+ val rawCoefficients = state.x.toArray.clone()
var i = 0
while (i < numFeatures) {
- rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
i += 1
}
if ($(fitIntercept)) {
- (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result())
+ (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
+ arrayBuilder.result())
} else {
- (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result())
+ (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
}
}
if (handlePersistence) instances.unpersist()
- val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
model.transform(dataset),
$(probabilityCol),
@@ -501,22 +524,29 @@ class LogisticRegressionModel private[ml] (
* corresponding joint dataset.
*/
private[classification] class MultiClassSummarizer extends Serializable {
- private val distinctMap = new mutable.HashMap[Int, Long]
+  // The first element of the value in distinctMap is the actual number of instances,
+  // and the second element is the sum of their weights.
+ private val distinctMap = new mutable.HashMap[Int, (Long, Double)]
private var totalInvalidCnt: Long = 0L
/**
* Add a new label into this MultilabelSummarizer, and update the distinct map.
* @param label The label for this data point.
+ * @param weight The weight of this instance.
* @return This MultilabelSummarizer
*/
- def add(label: Double): this.type = {
+ def add(label: Double, weight: Double = 1.0): this.type = {
+ require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+
+ if (weight == 0.0) return this
+
if (label - label.toInt != 0.0 || label < 0) {
totalInvalidCnt += 1
this
}
else {
- val counts: Long = distinctMap.getOrElse(label.toInt, 0L)
- distinctMap.put(label.toInt, counts + 1)
+ val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0))
+ distinctMap.put(label.toInt, (counts + 1L, weightSum + weight))
this
}
}
@@ -537,8 +567,8 @@ private[classification] class MultiClassSummarizer extends Serializable {
}
smallMap.distinctMap.foreach {
case (key, value) =>
- val counts = largeMap.distinctMap.getOrElse(key, 0L)
- largeMap.distinctMap.put(key, counts + value)
+ val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0))
+ largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2))
}
largeMap.totalInvalidCnt += smallMap.totalInvalidCnt
largeMap
@@ -550,13 +580,13 @@ private[classification] class MultiClassSummarizer extends Serializable {
/** @return The number of distinct labels in the input dataset. */
def numClasses: Int = distinctMap.keySet.max + 1
- /** @return The counts of each label in the input dataset. */
- def histogram: Array[Long] = {
- val result = Array.ofDim[Long](numClasses)
+ /** @return The sum of weights for each label in the input dataset. */
+ def histogram: Array[Double] = {
+ val result = Array.ofDim[Double](numClasses)
var i = 0
val len = result.length
while (i < len) {
- result(i) = distinctMap.getOrElse(i, 0L)
+ result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2
i += 1
}
result
@@ -565,6 +595,8 @@ private[classification] class MultiClassSummarizer extends Serializable {
/**
* Abstraction for multinomial Logistic Regression Training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
*/
sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
@@ -584,10 +616,10 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Dataframe outputted by the model's `transform` method. */
def predictions: DataFrame
- /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
+ /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the the true label of each sample. */
+ /** Field in "predictions" which gives the the true label of each instance. */
def labelCol: String
}
@@ -597,8 +629,8 @@ sealed trait LogisticRegressionSummary extends Serializable {
* Logistic regression training results.
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each sample as a vector.
- * @param labelCol field in "predictions" which gives the true label of each sample.
+ * each instance as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
@@ -617,8 +649,8 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* Binary Logistic regression results for a given model.
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each sample.
- * @param labelCol field in "predictions" which gives the true label of each sample.
+ * each instance.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
*/
@Experimental
class BinaryLogisticRegressionSummary private[classification] (
@@ -687,14 +719,14 @@ class BinaryLogisticRegressionSummary private[classification] (
/**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
- * in binary classification for samples in sparse or dense vector in a online fashion.
+ * in binary classification for instances in sparse or dense vector format in an online fashion.
*
* Note that multinomial logistic loss is not supported yet!
*
* Two LogisticAggregator can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
*
- * @param weights The weights/coefficients corresponding to the features.
+ * @param coefficients The coefficients corresponding to the features.
* @param numClasses the number of possible outcomes for k classes classification problem in
* Multinomial Logistic Regression.
* @param fitIntercept Whether to fit an intercept term.
@@ -702,25 +734,25 @@ class BinaryLogisticRegressionSummary private[classification] (
* @param featuresMean The mean values of the features.
*/
private class LogisticAggregator(
- weights: Vector,
+ coefficients: Vector,
numClasses: Int,
fitIntercept: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double]) extends Serializable {
- private var totalCnt: Long = 0L
+ private var weightSum = 0.0
private var lossSum = 0.0
- private val weightsArray = weights match {
+ private val coefficientsArray = coefficients match {
case dv: DenseVector => dv.values
case _ =>
throw new IllegalArgumentException(
- s"weights only supports dense vector but got type ${weights.getClass}.")
+ s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
}
- private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length
+ private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length
- private val gradientSumArray = Array.ofDim[Double](weightsArray.length)
+ private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length)
/**
* Add a new training data to this LogisticAggregator, and update the loss and gradient
@@ -729,13 +761,17 @@ private class LogisticAggregator(
* @param label The label for this data point.
* @param data The features for one data point in dense/sparse vector format to be added
* into this aggregator.
+ * @param weight The instance weight, used to over-/under-sample this training instance. Default is 1.0.
* @return This LogisticAggregator object.
*/
- def add(label: Double, data: Vector): this.type = {
- require(dim == data.size, s"Dimensions mismatch when adding new sample." +
+ def add(label: Double, data: Vector, weight: Double = 1.0): this.type = {
+ require(dim == data.size, s"Dimensions mismatch when adding new instance." +
s" Expecting $dim but got ${data.size}.")
+ require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
- val localWeightsArray = weightsArray
+ if (weight == 0.0) return this
+
+ val localCoefficientsArray = coefficientsArray
val localGradientSumArray = gradientSumArray
numClasses match {
@@ -745,13 +781,13 @@ private class LogisticAggregator(
var sum = 0.0
data.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) {
- sum += localWeightsArray(index) * (value / featuresStd(index))
+ sum += localCoefficientsArray(index) * (value / featuresStd(index))
}
}
- sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 }
+ sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 }
}
- val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
+ val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
data.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) {
@@ -765,15 +801,15 @@ private class LogisticAggregator(
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
- lossSum += MLUtils.log1pExp(margin)
+ lossSum += weight * MLUtils.log1pExp(margin)
} else {
- lossSum += MLUtils.log1pExp(margin) - margin
+ lossSum += weight * (MLUtils.log1pExp(margin) - margin)
}
case _ =>
new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " +
"binary classification for now.")
}
- totalCnt += 1
+ weightSum += weight
this
}
@@ -789,8 +825,8 @@ private class LogisticAggregator(
require(dim == other.dim, s"Dimensions mismatch when merging with another " +
s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
- if (other.totalCnt != 0) {
- totalCnt += other.totalCnt
+ if (other.weightSum != 0.0) {
+ weightSum += other.weightSum
lossSum += other.lossSum
var i = 0
@@ -805,13 +841,17 @@ private class LogisticAggregator(
this
}
- def count: Long = totalCnt
-
- def loss: Double = lossSum / totalCnt
+ def loss: Double = {
+    require(weightSum > 0.0, s"The effective number of instances should be " +
+      s"greater than 0.0, but got $weightSum.")
+ lossSum / weightSum
+ }
def gradient: Vector = {
+    require(weightSum > 0.0, s"The effective number of instances should be " +
+      s"greater than 0.0, but got $weightSum.")
val result = Vectors.dense(gradientSumArray.clone())
- scal(1.0 / totalCnt, result)
+ scal(1.0 / weightSum, result)
result
}
}
@@ -823,7 +863,7 @@ private class LogisticAggregator(
* It's used in Breeze's convex optimization routines.
*/
private class LogisticCostFun(
- data: RDD[(Double, Vector)],
+ data: RDD[Instance],
numClasses: Int,
fitIntercept: Boolean,
standardization: Boolean,
@@ -831,22 +871,23 @@ private class LogisticCostFun(
featuresMean: Array[Double],
regParamL2: Double) extends DiffFunction[BDV[Double]] {
- override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val numFeatures = featuresStd.length
- val w = Vectors.fromBreeze(weights)
+ val w = Vectors.fromBreeze(coefficients)
- val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept,
- featuresStd, featuresMean))(
- seqOp = (c, v) => (c, v) match {
- case (aggregator, (label, features)) => aggregator.add(label, features)
- },
- combOp = (c1, c2) => (c1, c2) match {
- case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
- })
+ val logisticAggregator = {
+ val seqOp = (c: LogisticAggregator, instance: Instance) =>
+ c.add(instance.label, instance.features, instance.weight)
+ val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
+
+ data.treeAggregate(
+ new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean)
+ )(seqOp, combOp)
+ }
val totalGradientArray = logisticAggregator.gradient.toArray
- // regVal is the sum of weight squares excluding intercept for L2 regularization.
+ // regVal is the sum of squared coefficients, excluding the intercept, for L2 regularization.
val regVal = if (regParamL2 == 0.0) {
0.0
} else {
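
With these changes, instance weights flow from a DataFrame column into both the summarizers and the cost function. A minimal usage sketch (not part of the patch; the column names "label", "weight", "features" and the DataFrame `df` are assumptions for illustration):

    import org.apache.spark.ml.classification.LogisticRegression

    // Train with per-instance weights taken from the "weight" column.
    val lr = new LogisticRegression()
      .setWeightCol("weight")  // empty string (the default) means weight 1.0 everywhere
      .setFitIntercept(true)
      .setRegParam(0.0)
    val model = lr.fit(df)     // df: DataFrame with label/weight/features columns

If weightCol is left unset, the lit(1.0) branch in train() supplies a constant weight of 1.0, so existing unweighted pipelines behave exactly as before.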
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index e9e99ed1db..8049d51fee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -42,7 +42,7 @@ private[shared] object SharedParamsCodeGen {
Some("\"rawPrediction\"")),
ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" +
" probabilities. Note: Not all models output well-calibrated probability estimates!" +
- " These probabilities should be treated as confidences, not precise probabilities.",
+ " These probabilities should be treated as confidences, not precise probabilities",
Some("\"probability\"")),
ParamDesc[Double]("threshold",
"threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
@@ -65,10 +65,10 @@ private[shared] object SharedParamsCodeGen {
"options may be added later.",
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
- " before fitting the model.", Some("true")),
+ " before fitting the model", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
- " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
+ " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty",
isValid = "ParamValidators.inRange(0, 1)"),
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 3009217086..aff47fc326 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params {
private[ml] trait HasProbabilityCol extends Params {
/**
- * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities..
+ * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
* @group param
*/
- final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
+ final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities")
setDefault(probabilityCol, "probability")
@@ -270,10 +270,10 @@ private[ml] trait HasHandleInvalid extends Params {
private[ml] trait HasStandardization extends Params {
/**
- * Param for whether to standardize the training features before fitting the model..
+ * Param for whether to standardize the training features before fitting the model.
* @group param
*/
- final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.")
+ final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model")
setDefault(standardization, true)
@@ -304,10 +304,10 @@ private[ml] trait HasSeed extends Params {
private[ml] trait HasElasticNetParam extends Params {
/**
- * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty..
+ * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* @group param
*/
- final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
+ final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1))
/** @group getParam */
final def getElasticNetParam: Double = $(elasticNetParam)
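
The rendered hunks above do not include the new HasWeightCol trait that LogisticRegressionParams now mixes in. Following the pattern of the generated traits in this file, it plausibly looks like the sketch below (a reconstruction from that pattern, not the verbatim generated code):

    private[ml] trait HasWeightCol extends Params {

      /**
       * Param for weight column name. If this is not set or empty, all instance
       * weights are treated as 1.0.
       * @group param
       */
      final val weightCol: Param[String] = new Param[String](this, "weightCol",
        "weight column name. If this is not set or empty, we treat all instance weights as 1.0")

      /** @group getParam */
      final def getWeightCol: String = $(weightCol)
    }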
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 51b713e263..201333c369 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -23,16 +23,19 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* :: DeveloperApi ::
* MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
- * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
+ * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector
 * format in an online fashion.
*
* Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of
* the corresponding joint dataset.
*
- * A numerically stable algorithm is implemented to compute sample mean and variance:
+ * A numerically stable algorithm is implemented to compute the mean and variance of instances:
* Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
* Zero elements (including explicit zero values) are skipped when calling add(),
* to have time complexity O(nnz) instead of O(n) for each column.
+ *
+ * For weighted instances, the unbiased estimate of variance is computed using reliability
+ * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]].
*/
@Since("1.1.0")
@DeveloperApi
@@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currM2: Array[Double] = _
private var currL1: Array[Double] = _
private var totalCnt: Long = 0
+ private var weightSum: Double = 0.0
+ private var weightSquareSum: Double = 0.0
private var nnz: Array[Double] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
@@ -55,10 +60,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
* @return This MultivariateOnlineSummarizer object.
*/
@Since("1.1.0")
- def add(sample: Vector): this.type = {
+ def add(sample: Vector): this.type = add(sample, 1.0)
+
+ private[spark] def add(instance: Vector, weight: Double): this.type = {
+ require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0")
+ if (weight == 0.0) return this
+
if (n == 0) {
- require(sample.size > 0, s"Vector should have dimension larger than zero.")
- n = sample.size
+ require(instance.size > 0, s"Vector should have dimension larger than zero.")
+ n = instance.size
currMean = Array.ofDim[Double](n)
currM2n = Array.ofDim[Double](n)
@@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMin = Array.fill[Double](n)(Double.MaxValue)
}
- require(n == sample.size, s"Dimensions mismatch when adding new sample." +
- s" Expecting $n but got ${sample.size}.")
+ require(n == instance.size, s"Dimensions mismatch when adding new instance." +
+ s" Expecting $n but got ${instance.size}.")
val localCurrMean = currMean
val localCurrM2n = currM2n
@@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localNnz = nnz
val localCurrMax = currMax
val localCurrMin = currMin
- sample.foreachActive { (index, value) =>
+ instance.foreachActive { (index, value) =>
if (value != 0.0) {
if (localCurrMax(index) < value) {
localCurrMax(index) = value
@@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val prevMean = localCurrMean(index)
val diff = value - prevMean
- localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0)
- localCurrM2n(index) += (value - localCurrMean(index)) * diff
- localCurrM2(index) += value * value
- localCurrL1(index) += math.abs(value)
+ localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight)
+ localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
+ localCurrM2(index) += weight * value * value
+ localCurrL1(index) += weight * math.abs(value)
- localNnz(index) += 1.0
+ localNnz(index) += weight
}
}
+ weightSum += weight
+ weightSquareSum += weight * weight
totalCnt += 1
this
}
@@ -112,10 +124,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
def merge(other: MultivariateOnlineSummarizer): this.type = {
- if (this.totalCnt != 0 && other.totalCnt != 0) {
+ if (this.weightSum != 0.0 && other.weightSum != 0.0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
+ weightSum += other.weightSum
+ weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = nnz(i)
@@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
nnz(i) = totalNnz
i += 1
}
- } else if (totalCnt == 0 && other.totalCnt != 0) {
+ } else if (weightSum == 0.0 && other.weightSum != 0.0) {
this.n = other.n
this.currMean = other.currMean.clone()
this.currM2n = other.currM2n.clone()
this.currM2 = other.currM2.clone()
this.currL1 = other.currL1.clone()
this.totalCnt = other.totalCnt
+ this.weightSum = other.weightSum
+ this.weightSquareSum = other.weightSquareSum
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
@@ -158,28 +174,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def mean: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
- realMean(i) = currMean(i) * (nnz(i) / totalCnt)
+ realMean(i) = currMean(i) * (nnz(i) / weightSum)
i += 1
}
Vectors.dense(realMean)
}
/**
- * Sample variance of each dimension.
+ * Unbiased estimate of sample variance of each dimension.
*
*/
@Since("1.1.0")
override def variance: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
val realVariance = Array.ofDim[Double](n)
- val denominator = totalCnt - 1.0
+ val denominator = weightSum - (weightSquareSum / weightSum)
// Sample variance is computed; if the denominator is not greater than 0, the variance is just 0.
if (denominator > 0.0) {
@@ -187,9 +203,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
var i = 0
val len = currM2n.length
while (i < len) {
- realVariance(i) =
- currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
- realVariance(i) /= denominator
+ realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) *
+ (weightSum - nnz(i)) / weightSum) / denominator
i += 1
}
}
@@ -209,7 +224,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def numNonzeros: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(nnz)
}
@@ -220,11 +235,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def max: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
var i = 0
while (i < n) {
- if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
+ if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.dense(currMax)
@@ -236,11 +251,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def min: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
var i = 0
while (i < n) {
- if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
+ if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.dense(currMin)
@@ -252,7 +267,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.2.0")
override def normL2: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
val realMagnitude = Array.ofDim[Double](n)
@@ -271,7 +286,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.2.0")
override def normL1: Vector = {
- require(totalCnt > 0, s"Nothing has been added to this summarizer.")
+ require(weightSum > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(currL1)
}
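
The new denominator weightSum - (weightSquareSum / weightSum) is the reliability-weights correction: with W = sum of w_i and W2 = sum of w_i^2, the unbiased weighted variance is sum(w_i * (x_i - mean)^2) / (W - W2 / W). A self-contained single-dimension sketch of the same computation (an illustration, not code from the patch):

    // Unbiased weighted variance with reliability weights, mirroring the
    // denominator used in MultivariateOnlineSummarizer.variance above.
    def weightedVariance(values: Array[Double], weights: Array[Double]): Double = {
      val w = weights.sum                                // weightSum
      val w2 = weights.map(x => x * x).sum               // weightSquareSum
      val mean = values.zip(weights).map { case (v, wi) => wi * v }.sum / w
      val m2n = values.zip(weights)
        .map { case (v, wi) => wi * (v - mean) * (v - mean) }.sum
      m2n / (w - w2 / w)                                 // denominator W - W2/W
    }

With unit weights, w == n and w2 == n, so the denominator reduces to n - 1, matching the unweighted estimator.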
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index cce39f382f..f5219f9f57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,11 +17,14 @@
package org.apache.spark.ml.classification
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
- sqlContext.createDataFrame(
- generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42))
+ sqlContext.createDataFrame(sc.parallelize(testData, 4))
}
}
@@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol === "prediction")
assert(lr.getRawPredictionCol === "rawPrediction")
assert(lr.getProbabilityCol === "probability")
+ assert(lr.getWeightCol === "")
assert(lr.getFitIntercept)
assert(lr.getStandardization)
val model = lr.fit(dataset)
@@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("MultiClassSummarizer") {
val summarizer1 = (new MultiClassSummarizer)
.add(0.0).add(3.0).add(4.0).add(3.0).add(6.0)
- assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2))
+ assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1))
assert(summarizer1.countInvalid === 0)
assert(summarizer1.numClasses === 7)
val summarizer2 = (new MultiClassSummarizer)
.add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0)
- assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1))
assert(summarizer2.countInvalid === 0)
assert(summarizer2.numClasses === 6)
val summarizer3 = (new MultiClassSummarizer)
.add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0)
- assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2))
+ assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3))
assert(summarizer3.countInvalid === 3)
assert(summarizer3.numClasses === 5)
val summarizer4 = (new MultiClassSummarizer)
.add(3.1).add(4.3).add(2.0).add(1.0).add(3.0)
- assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizer4.histogram === Array[Double](0, 1, 1, 1))
assert(summarizer4.countInvalid === 2)
assert(summarizer4.numClasses === 4)
// small map merges large one
val summarizerA = summarizer1.merge(summarizer2)
assert(summarizerA.hashCode() === summarizer2.hashCode())
- assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1))
assert(summarizerA.countInvalid === 0)
assert(summarizerA.numClasses === 7)
// large map merges small one
val summarizerB = summarizer3.merge(summarizer4)
assert(summarizerB.hashCode() === summarizer3.hashCode())
- assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2))
+ assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3))
assert(summarizerB.countInvalid === 5)
assert(summarizerB.numClasses === 5)
}
+ test("MultiClassSummarizer with weighted samples") {
+ val summarizer1 = (new MultiClassSummarizer)
+ .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1)
+ assert(Vectors.dense(summarizer1.histogram) ~==
+ Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10)
+ assert(summarizer1.countInvalid === 0)
+ assert(summarizer1.numClasses === 7)
+
+ val summarizer2 = (new MultiClassSummarizer)
+ .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2.0, 0.0)
+ assert(Vectors.dense(summarizer2.histogram) ~==
+ Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10)
+ assert(summarizer2.countInvalid === 0)
+ assert(summarizer2.numClasses === 6)
+
+ val summarizer = summarizer1.merge(summarizer2)
+ assert(Vectors.dense(summarizer.histogram) ~==
+ Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10)
+ assert(summarizer.countInvalid === 0)
+ assert(summarizer.numClasses === 7)
+ }
+
test("binary logistic regression with intercept without regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
@@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
b = \log{P(1) / P(0)} = \log{count_1 / count_0}
}}}
*/
- val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
+ val interceptTheory = math.log(histogram(1) / histogram(0))
val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptTheory relTol 1E-5)
@@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.forall(x => x(0) >= x(1)))
}
+
+ test("binary logistic regression with weighted samples") {
+ val (dataset, weightedDataset) = {
+ val nPoints = 1000
+ val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
+
+ // Let's over-sample the positive samples twice.
+ val data1 = testData.flatMap { case labeledPoint: LabeledPoint =>
+ if (labeledPoint.label == 1.0) {
+ Iterator(labeledPoint, labeledPoint)
+ } else {
+ Iterator(labeledPoint)
+ }
+ }
+
+ val rnd = new Random(8392)
+ val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) =>
+ if (rnd.nextGaussian() > 0.0) {
+ if (label == 1.0) {
+ Iterator(
+ Instance(label, 1.2, features),
+ Instance(label, 0.8, features),
+ Instance(0.0, 0.0, features))
+ } else {
+ Iterator(
+ Instance(label, 0.3, features),
+ Instance(1.0, 0.0, features),
+ Instance(label, 0.1, features),
+ Instance(label, 0.6, features))
+ }
+ } else {
+ if (label == 1.0) {
+ Iterator(Instance(label, 2.0, features))
+ } else {
+ Iterator(Instance(label, 1.0, features))
+ }
+ }
+ }
+
+ (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+ sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ }
+
+ val trainer1a = (new LogisticRegression).setFitIntercept(true)
+ .setRegParam(0.0).setStandardization(true)
+ val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight")
+ .setRegParam(0.0).setStandardization(true)
+ val model1a0 = trainer1a.fit(dataset)
+ val model1a1 = trainer1a.fit(weightedDataset)
+ val model1b = trainer1b.fit(weightedDataset)
+ assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+ assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+ assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+ assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+ }
}
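
The weighted test above relies on the fact that weight enters the binary logistic loss as a plain multiplicative factor (lossSum += weight * MLUtils.log1pExp(margin)), so duplicating an instance is exactly equivalent to giving it weight 2.0, and a zero-weight instance contributes nothing. A toy check of that identity (illustration only; log1pExp is re-derived locally rather than taken from MLUtils):

    def log1pExp(x: Double): Double = math.log1p(math.exp(x))  // loss term for a positive label
    val margin = -1.3                                          // hypothetical margin value
    val addedTwice = log1pExp(margin) + log1pExp(margin)
    val weightedOnce = 2.0 * log1pExp(margin)
    assert(math.abs(addedTwice - weightedOnce) < 1e-12)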
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 07efde4f5e..b6d41db69b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
s0.merge(s1)
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
}
+
+ test("merging summarizer with weighted samples") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1)
+ .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge(
+ (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15)
+ .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05))
+
+ assert(summarizer.count === 4)
+
+    // The following values are hand-calculated using the formula:
+    // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
+    // which defines the reliability weights used to compute the unbiased estimate of
+    // variance for weighted instances.
+ assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44))
+ absTol 1E-10, "mean mismatch")
+ assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
+ absTol 1E-8, "variance mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4))
+ absTol 1E-10, "numNonzeros mismatch")
+ assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
+ assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")
+ assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192)
+ absTol 1E-8, "normL2 mismatch")
+ assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
+ }
}
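
The expected values in the test above can be reproduced by hand from the four weighted instances. For the first dimension (an illustration, not part of the suite):

    val weights = Array(0.1, 0.2, 0.15, 0.05)
    val dim0 = Array(-0.8, 0.0, -0.7, -0.5)
    val mean0 = weights.zip(dim0).map { case (w, v) => w * v }.sum / weights.sum
    // (-0.08 + 0.0 - 0.105 - 0.025) / 0.5 == -0.42, matching the asserted mean(0)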
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 87b141cd3b..46026c1e90 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -45,7 +45,15 @@ object MimaExcludes {
excludePackage("org.apache.spark.sql.execution")
) ++
MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++
- MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils")
+ MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++
+ Seq(
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.ml.classification.LogisticCostFun.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.ml.classification.LogisticAggregator.add"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.ml.classification.LogisticAggregator.count")
+ )
case v if v.startsWith("1.5") =>
Seq(
MimaBuild.excludeSparkPackage("network"),