author    DB Tsai <dbt@netflix.com>    2015-09-15 15:46:47 -0700
committer Xiangrui Meng <meng@databricks.com>    2015-09-15 15:46:47 -0700
commit    be52faa7c72fb4b95829f09a7dc5eb5dccd03524 (patch)
tree      1fd30de5fdcf31c013774dca0ae06b834992900e /mllib/src/test/scala/org
parent    31a229aa739b6d05ec6d91b820fcca79b6b7d6fe (diff)
[SPARK-7685] [ML] Apply weights to different samples in Logistic Regression
In a fraud-detection dataset, almost all of the samples are negative and only a couple of them are positive. This kind of highly imbalanced data biases the model toward the negative class, resulting in poor performance. scikit-learn provides a correction that lets users over-/undersample the samples of each class according to given weights; in "auto" mode, it selects weights inversely proportional to the class frequencies in the training set. The same correction can be applied far more efficiently by multiplying the weights into the loss and gradient instead of actually over-/undersampling the training dataset, which is very expensive. http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html

On the other hand, some training samples may be more important than others, e.g. samples from tenured users versus samples from new users. We should be able to provide an additional "weight: Double" field along with the LabeledPoint so the learning algorithm can weight samples differently.

Author: DB Tsai <dbt@netflix.com>
Author: DB Tsai <dbt@dbs-mac-pro.corp.netflix.com>

Closes #7884 from dbtsai/SPARK-7685.
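The key identity is sketched below (a minimal illustration, not Spark's actual implementation; the names WeightedLogisticLoss and lossAndGradient are made up for this note): each sample's weight simply multiplies its contribution to the loss and gradient, so duplicating a sample k times is exactly equivalent to giving it weight k.

object WeightedLogisticLoss {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Weighted binary log-loss and its gradient for labels in {0, 1}:
  //   loss = -sum_i w_i * [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]
  //   grad =  sum_i w_i * (p_i - y_i) * x_i
  def lossAndGradient(
      data: Seq[(Double, Double, Array[Double])], // (label, weight, features)
      beta: Array[Double]): (Double, Array[Double]) = {
    val grad = Array.fill(beta.length)(0.0)
    var loss = 0.0
    data.foreach { case (label, weight, features) =>
      val margin = features.zip(beta).map { case (x, b) => x * b }.sum
      val prob = sigmoid(margin)
      // Each sample's loss term is scaled by its weight.
      loss -= weight * (label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))
      val multiplier = weight * (prob - label)
      var j = 0
      while (j < grad.length) {
        grad(j) += multiplier * features(j)
        j += 1
      }
    }
    (loss, grad)
  }
}

The scikit-learn-style "auto" correction then amounts to setting each sample's weight to numSamples / (numClasses * count(class)). On the API side, the new weight column is opted into via setWeightCol, as exercised in the test diff below:

val lr = new LogisticRegression().setFitIntercept(true).setWeightCol("weight")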
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 102
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala | 27
2 files changed, 120 insertions(+), 9 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index cce39f382f..f5219f9f57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,11 +17,14 @@
package org.apache.spark.ml.classification
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
- sqlContext.createDataFrame(
- generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42))
+ sqlContext.createDataFrame(sc.parallelize(testData, 4))
}
}
@@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol === "prediction")
assert(lr.getRawPredictionCol === "rawPrediction")
assert(lr.getProbabilityCol === "probability")
+ assert(lr.getWeightCol === "")
assert(lr.getFitIntercept)
assert(lr.getStandardization)
val model = lr.fit(dataset)
@@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("MultiClassSummarizer") {
val summarizer1 = (new MultiClassSummarizer)
.add(0.0).add(3.0).add(4.0).add(3.0).add(6.0)
- assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2))
+ assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1))
assert(summarizer1.countInvalid === 0)
assert(summarizer1.numClasses === 7)
val summarizer2 = (new MultiClassSummarizer)
.add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0)
- assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1))
assert(summarizer2.countInvalid === 0)
assert(summarizer2.numClasses === 6)
val summarizer3 = (new MultiClassSummarizer)
.add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0)
- assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2))
+ assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3))
assert(summarizer3.countInvalid === 3)
assert(summarizer3.numClasses === 5)
val summarizer4 = (new MultiClassSummarizer)
.add(3.1).add(4.3).add(2.0).add(1.0).add(3.0)
- assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizer4.histogram === Array[Double](0, 1, 1, 1))
assert(summarizer4.countInvalid === 2)
assert(summarizer4.numClasses === 4)
// small map merges large one
val summarizerA = summarizer1.merge(summarizer2)
assert(summarizerA.hashCode() === summarizer2.hashCode())
- assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1))
assert(summarizerA.countInvalid === 0)
assert(summarizerA.numClasses === 7)
// large map merges small one
val summarizerB = summarizer3.merge(summarizer4)
assert(summarizerB.hashCode() === summarizer3.hashCode())
- assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2))
+ assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3))
assert(summarizerB.countInvalid === 5)
assert(summarizerB.numClasses === 5)
}
+ test("MultiClassSummarizer with weighted samples") {
+ val summarizer1 = (new MultiClassSummarizer)
+ .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1)
+ assert(Vectors.dense(summarizer1.histogram) ~==
+ Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10)
+ assert(summarizer1.countInvalid === 0)
+ assert(summarizer1.numClasses === 7)
+
+ val summarizer2 = (new MultiClassSummarizer)
+ .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0)
+ assert(Vectors.dense(summarizer2.histogram) ~==
+ Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10)
+ assert(summarizer2.countInvalid === 0)
+ assert(summarizer2.numClasses === 6)
+
+ val summarizer = summarizer1.merge(summarizer2)
+ assert(Vectors.dense(summarizer.histogram) ~==
+ Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10)
+ assert(summarizer.countInvalid === 0)
+ assert(summarizer.numClasses === 7)
+ }
+
test("binary logistic regression with intercept without regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
@@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
b = \log{P(1) / P(0)} = \log{count_1 / count_0}
}}}
*/
- val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
+ val interceptTheory = math.log(histogram(1) / histogram(0))
val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptTheory relTol 1E-5)
@@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.forall(x => x(0) >= x(1)))
}
+
+ test("binary logistic regression with weighted samples") {
+ val (dataset, weightedDataset) = {
+ val nPoints = 1000
+ val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
+
+ // Let's over-sample the positive samples twice.
+ val data1 = testData.flatMap { case labeledPoint: LabeledPoint =>
+ if (labeledPoint.label == 1.0) {
+ Iterator(labeledPoint, labeledPoint)
+ } else {
+ Iterator(labeledPoint)
+ }
+ }
+
+ val rnd = new Random(8392)
+ val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) =>
+ if (rnd.nextGaussian() > 0.0) {
+ if (label == 1.0) {
+ Iterator(
+ Instance(label, 1.2, features),
+ Instance(label, 0.8, features),
+ Instance(0.0, 0.0, features))
+ } else {
+ Iterator(
+ Instance(label, 0.3, features),
+ Instance(1.0, 0.0, features),
+ Instance(label, 0.1, features),
+ Instance(label, 0.6, features))
+ }
+ } else {
+ if (label == 1.0) {
+ Iterator(Instance(label, 2.0, features))
+ } else {
+ Iterator(Instance(label, 1.0, features))
+ }
+ }
+ }
+
+ (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+ sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ }
+
+ val trainer1a = (new LogisticRegression).setFitIntercept(true)
+ .setRegParam(0.0).setStandardization(true)
+ val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight")
+ .setRegParam(0.0).setStandardization(true)
+ val model1a0 = trainer1a.fit(dataset)
+ val model1a1 = trainer1a.fit(weightedDataset)
+ val model1b = trainer1b.fit(weightedDataset)
+ assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+ assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+ assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+ assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+ }
}
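As a sanity check on the weighted-histogram arithmetic asserted in the MultiClassSummarizer tests above, a hypothetical stand-in for its bookkeeping (the class name WeightedHistogram is made up here, and invalid-label/countInvalid tracking is omitted) only needs to accumulate each sample's weight into its label's bin and sum bins element-wise on merge:

import scala.collection.mutable

class WeightedHistogram {
  private val bins = mutable.Map.empty[Int, Double].withDefaultValue(0.0)

  // add(label, weight) accumulates the weight into the label's bin;
  // an unweighted add counts as weight 1.0.
  def add(label: Double, weight: Double = 1.0): this.type = {
    bins(label.toInt) += weight
    this
  }

  // merge sums the two histograms element-wise.
  def merge(other: WeightedHistogram): this.type = {
    other.bins.foreach { case (k, w) => bins(k) += w }
    this
  }

  def histogram: Array[Double] = Array.tabulate(bins.keys.max + 1)(bins(_))
}

// Reproduces summarizer1 from the test above: bin 3 = 0.8 + 1.3 = 2.1.
val h1 = new WeightedHistogram()
  .add(0.0, 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1)
// h1.histogram == Array(0.2, 0.0, 0.0, 2.1, 3.2, 0.0, 3.1)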
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 07efde4f5e..b6d41db69b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
s0.merge(s1)
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
}
+
+ test("merging summarizer with weighted samples") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1)
+ .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge(
+ (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15)
+ .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05))
+
+ assert(summarizer.count === 4)
+
+ // The following values are hand calculated using the formula:
+ // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
+ // which defines the reliability weight used for computing the unbiased estimation of variance
+ // for weighted instances.
+ assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44))
+ absTol 1E-10, "mean mismatch")
+ assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
+ absTol 1E-8, "variance mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4))
+ absTol 1E-10, "numNonzeros mismatch")
+ assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
+ assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")
+ assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192)
+ absTol 1E-8, "normL2 mismatch")
+ assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
+ }
}
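For reference, the hand calculation cited in the test's comment uses the reliability-weights estimator from the linked page: with W = sum(w_i), the weighted mean is mu = sum(w_i * x_i) / W and the unbiased variance is sum(w_i * (x_i - mu)^2) / (W - sum(w_i^2) / W). A minimal sketch (the helper name weightedMeanAndVariance is made up here) reproduces the first dimension of the assertions above:

def weightedMeanAndVariance(xs: Seq[Double], ws: Seq[Double]): (Double, Double) = {
  val totalW = ws.sum
  val mean = xs.zip(ws).map { case (x, w) => x * w }.sum / totalW
  val sqDev = xs.zip(ws).map { case (x, w) => w * (x - mean) * (x - mean) }.sum
  // Reliability-weights denominator: W - sum(w_i^2) / W.
  val variance = sqDev / (totalW - ws.map(w => w * w).sum / totalW)
  (mean, variance)
}

// First dimension of the four instances above:
// values (-0.8, 0.0, -0.7, -0.5) with weights (0.1, 0.2, 0.15, 0.05).
val (mu, v) = weightedMeanAndVariance(
  Seq(-0.8, 0.0, -0.7, -0.5), Seq(0.1, 0.2, 0.15, 0.05))
// mu == -0.42 and v ~= 0.17657142857, matching the asserted values.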