aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
authorMeihua Wu <meihuawu@umich.edu>2015-09-21 12:09:00 -0700
committerDB Tsai <dbt@netflix.com>2015-09-21 12:09:00 -0700
commit331f0b10f78a37d96d3e573d211d74a0935265db (patch)
tree73fd4850825c1254df3967496ab4ee4625cc3454 /mllib/src/test/scala/org
parentca9fe540fe04e2e230d1e76526b5502bab152914 (diff)
downloadspark-331f0b10f78a37d96d3e573d211d74a0935265db.tar.gz
spark-331f0b10f78a37d96d3e573d211d74a0935265db.tar.bz2
spark-331f0b10f78a37d96d3e573d211d74a0935265db.zip
[SPARK-9642] [ML] LinearRegression should supported weighted data
In many modeling application, data points are not necessarily sampled with equal probabilities. Linear regression should support weighting which account the over or under sampling. work in progress. Author: Meihua Wu <meihuawu@umich.edu> Closes #8631 from rotationsymmetry/SPARK-9642.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala88
1 files changed, 88 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 2aaee71ecc..8428f4f00b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.ml.regression
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.zip(testSummary.residuals.select("residuals").collect())
.forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
}
+
+ test("linear regression with weighted samples"){
+ val (data, weightedData) = {
+ val activeData = LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+
+ val rnd = new Random(8392)
+ val signedData = activeData.map { case p: LabeledPoint =>
+ (rnd.nextGaussian() > 0.0, p)
+ }
+
+ val data1 = signedData.flatMap {
+ case (true, p) => Iterator(p, p)
+ case (false, p) => Iterator(p)
+ }
+
+ val weightedSignedData = signedData.flatMap {
+ case (true, LabeledPoint(label, features)) =>
+ Iterator(
+ Instance(label, weight = 1.2, features),
+ Instance(label, weight = 0.8, features)
+ )
+ case (false, LabeledPoint(label, features)) =>
+ Iterator(
+ Instance(label, weight = 0.3, features),
+ Instance(label, weight = 0.1, features),
+ Instance(label, weight = 0.6, features)
+ )
+ }
+
+ val noiseData = LinearDataGenerator.generateLinearInput(
+ 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+ val weightedNoiseData = noiseData.map {
+ case LabeledPoint(label, features) => Instance(label, weight = 0, features)
+ }
+ val data2 = weightedSignedData ++ weightedNoiseData
+
+ (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+ sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ }
+
+ val trainer1a = (new LinearRegression).setFitIntercept(true)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val model1a0 = trainer1a.fit(data)
+ val model1a1 = trainer1a.fit(weightedData)
+ val model1b = trainer1b.fit(weightedData)
+ assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+ assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+ assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+ assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+ val trainer2a = (new LinearRegression).setFitIntercept(true)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val model2a0 = trainer2a.fit(data)
+ val model2a1 = trainer2a.fit(weightedData)
+ val model2b = trainer2b.fit(weightedData)
+ assert(model2a0.weights !~= model2a1.weights absTol 1E-3)
+ assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3)
+ assert(model2a0.weights ~== model2b.weights absTol 1E-3)
+ assert(model2a0.intercept ~== model2b.intercept absTol 1E-3)
+
+ val trainer3a = (new LinearRegression).setFitIntercept(false)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val model3a0 = trainer3a.fit(data)
+ val model3a1 = trainer3a.fit(weightedData)
+ val model3b = trainer3b.fit(weightedData)
+ assert(model3a0.weights !~= model3a1.weights absTol 1E-3)
+ assert(model3a0.weights ~== model3b.weights absTol 1E-3)
+
+ val trainer4a = (new LinearRegression).setFitIntercept(false)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val model4a0 = trainer4a.fit(data)
+ val model4a1 = trainer4a.fit(weightedData)
+ val model4b = trainer4b.fit(weightedData)
+ assert(model4a0.weights !~= model4a1.weights absTol 1E-3)
+ assert(model4a0.weights ~== model4b.weights absTol 1E-3)
+ }
}