aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala88
1 files changed, 88 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 2aaee71ecc..8428f4f00b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.ml.regression
+import scala.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
@@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.zip(testSummary.residuals.select("residuals").collect())
.forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
}
+
+ test("linear regression with weighted samples"){
+ val (data, weightedData) = {
+ val activeData = LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+
+ val rnd = new Random(8392)
+ val signedData = activeData.map { case p: LabeledPoint =>
+ (rnd.nextGaussian() > 0.0, p)
+ }
+
+ val data1 = signedData.flatMap {
+ case (true, p) => Iterator(p, p)
+ case (false, p) => Iterator(p)
+ }
+
+ val weightedSignedData = signedData.flatMap {
+ case (true, LabeledPoint(label, features)) =>
+ Iterator(
+ Instance(label, weight = 1.2, features),
+ Instance(label, weight = 0.8, features)
+ )
+ case (false, LabeledPoint(label, features)) =>
+ Iterator(
+ Instance(label, weight = 0.3, features),
+ Instance(label, weight = 0.1, features),
+ Instance(label, weight = 0.6, features)
+ )
+ }
+
+ val noiseData = LinearDataGenerator.generateLinearInput(
+ 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+ val weightedNoiseData = noiseData.map {
+ case LabeledPoint(label, features) => Instance(label, weight = 0, features)
+ }
+ val data2 = weightedSignedData ++ weightedNoiseData
+
+ (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+ sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+ }
+
+ val trainer1a = (new LinearRegression).setFitIntercept(true)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val model1a0 = trainer1a.fit(data)
+ val model1a1 = trainer1a.fit(weightedData)
+ val model1b = trainer1b.fit(weightedData)
+ assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+ assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+ assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+ assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+ val trainer2a = (new LinearRegression).setFitIntercept(true)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val model2a0 = trainer2a.fit(data)
+ val model2a1 = trainer2a.fit(weightedData)
+ val model2b = trainer2b.fit(weightedData)
+ assert(model2a0.weights !~= model2a1.weights absTol 1E-3)
+ assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3)
+ assert(model2a0.weights ~== model2b.weights absTol 1E-3)
+ assert(model2a0.intercept ~== model2b.intercept absTol 1E-3)
+
+ val trainer3a = (new LinearRegression).setFitIntercept(false)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+ val model3a0 = trainer3a.fit(data)
+ val model3a1 = trainer3a.fit(weightedData)
+ val model3b = trainer3b.fit(weightedData)
+ assert(model3a0.weights !~= model3a1.weights absTol 1E-3)
+ assert(model3a0.weights ~== model3b.weights absTol 1E-3)
+
+ val trainer4a = (new LinearRegression).setFitIntercept(false)
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+ .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+ val model4a0 = trainer4a.fit(data)
+ val model4a1 = trainer4a.fit(weightedData)
+ val model4b = trainer4b.fit(weightedData)
+ assert(model4a0.weights !~= model4a1.weights absTol 1E-3)
+ assert(model4a0.weights ~== model4b.weights absTol 1E-3)
+ }
}