aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-03 08:29:07 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-03 08:29:07 -0800
commitd6f10aa7ea2806c0fbcfc31d7dee91d28319fab7 (patch)
treec806dd442ff3b8c45294ad8d0ffc60882578615e /mllib/src/test/scala/org
parentd6035d97c91fe78b1336ade48134252915263ea6 (diff)
downloadspark-d6f10aa7ea2806c0fbcfc31d7dee91d28319fab7.tar.gz
spark-d6f10aa7ea2806c0fbcfc31d7dee91d28319fab7.tar.bz2
spark-d6f10aa7ea2806c0fbcfc31d7dee91d28319fab7.zip
[SPARK-9836][ML] Provide R-like summary statistics for OLS via normal equation solver
https://issues.apache.org/jira/browse/SPARK-9836 Author: Yanbo Liang <ybliang8@gmail.com> Closes #9413 from yanboliang/spark-9836.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala129
1 files changed, 129 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 235c796d78..fbf83e8922 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -35,6 +35,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var datasetWithDenseFeature: DataFrame = _
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _
@transient var datasetWithSparseFeature: DataFrame = _
+ @transient var datasetWithWeight: DataFrame = _
/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -73,6 +74,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200,
seed, eps = 0.1, sparsity = 0.7), 2))
+
+ /*
+ R code:
+
+ A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+ b <- c(17, 19, 23, 29)
+ w <- c(1, 2, 3, 4)
+ df <- as.data.frame(cbind(A, b))
+ */
+ datasetWithWeight = sqlContext.createDataFrame(
+ sc.parallelize(Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
+ Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
+ Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
+ ), 2))
}
test("params") {
@@ -603,6 +620,16 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
// To clalify that the normal solver is used here.
assert(model.summary.objectiveHistory.length == 1)
assert(model.summary.objectiveHistory(0) == 0.0)
+ val devianceResidualsR = Array(-0.35566, 0.34504)
+ val seCoefR = Array(0.0011756, 0.0009032)
+ val tValsR = Array(3998, 7971)
+ val pValsR = Array(0, 0)
+ model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.tValues.map(_.round).zip(tValsR).foreach{ x => assert(x._1 === x._2) }
+ model.summary.pValues.map(_.round).zip(pValsR).foreach{ x => assert(x._1 === x._2) }
}
}
}
@@ -725,4 +752,106 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.sliding(2)
.forall(x => x(0) >= x(1)))
}
+
+ test("linear regression summary with weighted samples and intercept by normal solver") {
+ /*
+ R code:
+
+ model <- glm(formula = "b ~ .", data = df, weights = w)
+ summary(model)
+
+ Call:
+ glm(formula = "b ~ .", data = df, weights = w)
+
+ Deviance Residuals:
+ 1 2 3 4
+ 1.920 -1.358 -1.109 0.960
+
+ Coefficients:
+ Estimate Std. Error t value Pr(>|t|)
+ (Intercept) 18.080 9.608 1.882 0.311
+ V1 6.080 5.556 1.094 0.471
+ V2 -0.600 1.960 -0.306 0.811
+
+ (Dispersion parameter for gaussian family taken to be 7.68)
+
+ Null deviance: 202.00 on 3 degrees of freedom
+ Residual deviance: 7.68 on 1 degrees of freedom
+ AIC: 18.783
+
+ Number of Fisher Scoring iterations: 2
+ */
+
+ val model = new LinearRegression()
+ .setWeightCol("weight")
+ .setSolver("normal")
+ .fit(datasetWithWeight)
+ val coefficientsR = Vectors.dense(Array(6.080, -0.600))
+ val interceptR = 18.080
+ val devianceResidualsR = Array(-1.358, 1.920)
+ val seCoefR = Array(5.556, 1.960)
+ val tValsR = Array(1.094, -0.306)
+ val pValsR = Array(0.471, 0.811)
+
+ assert(model.coefficients ~== coefficientsR absTol 1E-3)
+ assert(model.intercept ~== interceptR absTol 1E-3)
+ model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ }
+
+ test("linear regression summary with weighted samples and w/o intercept by normal solver") {
+ /*
+ R code:
+
+ model <- glm(formula = "b ~ . -1", data = df, weights = w)
+ summary(model)
+
+ Call:
+ glm(formula = "b ~ . -1", data = df, weights = w)
+
+ Deviance Residuals:
+ 1 2 3 4
+ 1.950 2.344 -4.600 2.103
+
+ Coefficients:
+ Estimate Std. Error t value Pr(>|t|)
+ V1 -3.7271 2.9032 -1.284 0.3279
+ V2 3.0100 0.6022 4.998 0.0378 *
+ ---
+ Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
+
+ (Dispersion parameter for gaussian family taken to be 17.4376)
+
+ Null deviance: 5962.000 on 4 degrees of freedom
+ Residual deviance: 34.875 on 2 degrees of freedom
+ AIC: 22.835
+
+ Number of Fisher Scoring iterations: 2
+ */
+
+ val model = new LinearRegression()
+ .setWeightCol("weight")
+ .setSolver("normal")
+ .setFitIntercept(false)
+ .fit(datasetWithWeight)
+ val coefficientsR = Vectors.dense(Array(-3.7271, 3.0100))
+ val interceptR = 0.0
+ val devianceResidualsR = Array(-4.600, 2.344)
+ val seCoefR = Array(2.9032, 0.6022)
+ val tValsR = Array(-1.284, 4.998)
+ val pValsR = Array(0.3279, 0.0378)
+
+ assert(model.coefficients ~== coefficientsR absTol 1E-3)
+ assert(model.intercept === interceptR)
+ model.summary.devianceResiduals.zip(devianceResidualsR).foreach { x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.coefficientStandardErrors.zip(seCoefR).foreach{ x =>
+ assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
+ }
}