diff options
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc..8428f4f00b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .zip(testSummary.residuals.select("residuals").collect()) .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } } + + test("linear regression with weighted samples"){ + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) + } } |