Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala    102
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala  27
2 files changed, 120 insertions(+), 9 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index cce39f382f..f5219f9f57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,11 +17,14 @@
package org.apache.spark.ml.classification
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
- sqlContext.createDataFrame(
- generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42))
+ sqlContext.createDataFrame(sc.parallelize(testData, 4))
}
}
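Note on the hunk above: the DataFrame is now built from sc.parallelize(testData, 4), which reuses the testData generated on the previous line (the old code generated the same input a second time) and spreads it over four partitions, presumably so the distributed aggregation and merge code paths are exercised during the fit.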
@@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol === "prediction")
assert(lr.getRawPredictionCol === "rawPrediction")
assert(lr.getProbabilityCol === "probability")
+ assert(lr.getWeightCol === "")
assert(lr.getFitIntercept)
assert(lr.getStandardization)
val model = lr.fit(dataset)
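The new assertion documents that weightCol defaults to the empty string; when no weight column is set, every training instance is treated as having weight 1.0, so the pre-existing tests are unaffected by the new parameter.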
@@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("MultiClassSummarizer") {
val summarizer1 = (new MultiClassSummarizer)
.add(0.0).add(3.0).add(4.0).add(3.0).add(6.0)
- assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2))
+ assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1))
assert(summarizer1.countInvalid === 0)
assert(summarizer1.numClasses === 7)
val summarizer2 = (new MultiClassSummarizer)
.add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0)
- assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1))
assert(summarizer2.countInvalid === 0)
assert(summarizer2.numClasses === 6)
val summarizer3 = (new MultiClassSummarizer)
.add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0)
- assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2))
+ assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3))
assert(summarizer3.countInvalid === 3)
assert(summarizer3.numClasses === 5)
val summarizer4 = (new MultiClassSummarizer)
.add(3.1).add(4.3).add(2.0).add(1.0).add(3.0)
- assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizer4.histogram === Array[Double](0, 1, 1, 1))
assert(summarizer4.countInvalid === 2)
assert(summarizer4.numClasses === 4)
// small map merges large one
val summarizerA = summarizer1.merge(summarizer2)
assert(summarizerA.hashCode() === summarizer2.hashCode())
- assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2))
+ assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1))
assert(summarizerA.countInvalid === 0)
assert(summarizerA.numClasses === 7)
// large map merges small one
val summarizerB = summarizer3.merge(summarizer4)
assert(summarizerB.hashCode() === summarizer3.hashCode())
- assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2))
+ assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3))
assert(summarizerB.countInvalid === 5)
assert(summarizerB.numClasses === 5)
}
+ test("MultiClassSummarizer with weighted samples") {
+ val summarizer1 = (new MultiClassSummarizer)
+ .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1)
+ assert(Vectors.dense(summarizer1.histogram) ~==
+ Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10)
+ assert(summarizer1.countInvalid === 0)
+ assert(summarizer1.numClasses === 7)
+
+ val summarizer2 = (new MultiClassSummarizer)
+ .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2.0, 0.0)
+ assert(Vectors.dense(summarizer2.histogram) ~==
+ Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10)
+ assert(summarizer2.countInvalid === 0)
+ assert(summarizer2.numClasses === 6)
+
+ val summarizer = summarizer1.merge(summarizer2)
+ assert(Vectors.dense(summarizer.histogram) ~==
+ Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10)
+ assert(summarizer.countInvalid === 0)
+ assert(summarizer.numClasses === 7)
+ }
+
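A quick hand check of the merge above: class 0 accumulates 0.2 + 1.0 = 1.2, class 1 gets 0 + 2.1 = 2.1, class 3 gets 2.1 + 1.0 = 3.1, class 4 gets 3.2 + 1.0 = 4.2, and classes 5 and 6 carry 2.3 and 3.1 from a single side, matching Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1). The zero-weight add for class 2 contributes nothing, so that entry stays at 0.0.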
test("binary logistic regression with intercept without regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true)
val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false)
@@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
b = \log{P(1) / P(0)} = \log{count_1 / count_0}
}}}
*/
- val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble)
+ val interceptTheory = math.log(histogram(1) / histogram(0))
val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptTheory relTol 1E-5)
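For reference, the theory value used above falls out of fitting an intercept-only model: with all coefficients forced to zero (weightsTheory is the zero vector), P(1) = 1 / (1 + e^{-b}), and maximizing the likelihood sets P(1)/P(0) to the empirical label ratio count_1/count_0, so b = log(count_1/count_0). The .toDouble conversions could be dropped because histogram now returns weighted Double counts rather than Long counts.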
@@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.forall(x => x(0) >= x(1)))
}
+
+ test("binary logistic regression with weighted samples") {
+ val (dataset, weightedDataset) = {
+ val nPoints = 1000
+ val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191)
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)
+
+ // Let's over-sample the positive samples twice.
+ val data1 = testData.flatMap { case labeledPoint: LabeledPoint =>
+ if (labeledPoint.label == 1.0) {
+ Iterator(labeledPoint, labeledPoint)
+ } else {
+ Iterator(labeledPoint)
+ }
+ }
+
+ val rnd = new Random(8392)
+ val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) =>
+ if (rnd.nextGaussian() > 0.0) {
+ if (label == 1.0) {
+ Iterator(
+ Instance(label, 1.2, features),
+ Instance(label, 0.8, features),
+ Instance(0.0, 0.0, features))
+ } else {
+ Iterator(
+ Instance(label, 0.3, features),
+ Instance(1.0, 0.0, features),
+ Instance(label, 0.1, features),
+ Instance(label, 0.6, features))
+ }
+ } else {
+ if (label == 1.0) {
+ Iterator(Instance(label, 2.0, features))
+ } else {
+ Iterator(Instance(label, 1.0, features))
+ }
+ }
+ }
+
+ (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+ sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ }
+
+ val trainer1a = (new LogisticRegression).setFitIntercept(true)
+ .setRegParam(0.0).setStandardization(true)
+ val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight")
+ .setRegParam(0.0).setStandardization(true)
+ val model1a0 = trainer1a.fit(dataset)
+ val model1a1 = trainer1a.fit(weightedDataset)
+ val model1b = trainer1b.fit(weightedDataset)
+ assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+ assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+ assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+ assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+ }
}
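The weighted-samples test above encodes one invariant two ways: an instance with weight w must contribute to the objective exactly like w unweighted copies of that point, and weight 0.0 must contribute nothing. Checking the expansion by hand: each positive point in data2 carries total weight 1.2 + 0.8 = 2.0 in the first branch, or 2.0 directly in the second, matching its two unweighted copies in data1; each negative point carries 0.3 + 0.1 + 0.6 = 1.0, or 1.0 directly, matching its single copy; the Instance(..., 0.0, ...) rows are inert. Hence model1a0 (fit on the over-sampled data1) is expected to agree with model1b (fit on data2 with weightCol set) but not with model1a1 (fit on data2 with the weights ignored). A minimal sketch of the same equivalence, assuming the Instance(label, weight, features) case class introduced by this patch is in scope:

    import org.apache.spark.mllib.linalg.Vectors

    val features = Vectors.dense(5.8, 3.1, 3.7, 1.2)
    // Two unweighted copies of a positive point ...
    val duplicated = Seq(Instance(1.0, 1.0, features), Instance(1.0, 1.0, features))
    // ... should yield the same fit as a single copy carrying weight 2.0.
    val weighted = Seq(Instance(1.0, 2.0, features))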
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 07efde4f5e..b6d41db69b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
s0.merge(s1)
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
}
+
+ test("merging summarizer with weighted samples") {
+ val summarizer = (new MultivariateOnlineSummarizer)
+ .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1)
+ .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge(
+ (new MultivariateOnlineSummarizer)
+ .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15)
+ .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05))
+
+ assert(summarizer.count === 4)
+
+ // The following values are hand calculated using the formula:
+ // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
+ // which defines the reliability weight used for computing the unbiased estimation of variance
+ // for weighted instances.
+ assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44))
+ absTol 1E-10, "mean mismatch")
+ assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
+ absTol 1E-8, "variance mismatch")
+ assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4))
+ absTol 1E-10, "numNonzeros mismatch")
+ assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
+ assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")
+ assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192)
+ absTol 1E-8, "normL2 mismatch")
+ assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
+ }
}
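The expectations above can be reproduced by hand. The total weight is 0.1 + 0.2 + 0.15 + 0.05 = 0.5, so the first mean coordinate is (0.1 * -0.8 + 0.2 * 0.0 + 0.15 * -0.7 + 0.05 * -0.5) / 0.5 = -0.21 / 0.5 = -0.42, and the second and third follow the same pattern (-0.0535 / 0.5 = -0.107 and -0.22 / 0.5 = -0.44). normL1 is the weighted sum of absolute values, e.g. 0.1 * 0.8 + 0.15 * 0.7 + 0.05 * 0.5 = 0.21 for the first coordinate, while count reports the number of instances (4) regardless of their weights. The variance expectations use the reliability-weight correction from the linked article rather than the plain 1/(n - 1) factor.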