aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Imran Younus <iyounus@us.ibm.com> 2016-01-05 11:48:45 +0000
committer Sean Owen <sowen@cloudera.com> 2016-01-05 11:48:45 +0000
commit 1cdc42d2b99edfec01066699a7620cca02b61f0e (patch)
tree 5c9ed178231a34527cdabbd62cad7254228b7243 /mllib
parent 8eb2dc7133b4d2143adffc2bdbb61d96bd41a0ac (diff)
download spark-1cdc42d2b99edfec01066699a7620cca02b61f0e.tar.gz
spark-1cdc42d2b99edfec01066699a7620cca02b61f0e.tar.bz2
spark-1cdc42d2b99edfec01066699a7620cca02b61f0e.zip
[SPARK-12331][ML] R^2 for regression through the origin.
Modified the definition of R^2 for regression through the origin. Added modified tests for regression metrics. Author: Imran Younus <iyounus@us.ibm.com> Author: Imran Younus <imranyounus@gmail.com> Closes #10384 from iyounus/SPARK_12331_R2_for_regression_through_origin.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 3
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala 24
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala 156
3 files changed, 112 insertions, 71 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index dee26337dc..c54e08b2ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -534,7 +534,8 @@ class LinearRegressionSummary private[regression] (
@transient private val metrics = new RegressionMetrics(
predictions
.select(predictionCol, labelCol)
- .map { case Row(pred: Double, label: Double) => (pred, label) } )
+ .map { case Row(pred: Double, label: Double) => (pred, label) },
+ !model.getFitIntercept)
/**
* Returns the explained variance regression score.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 34883f2f39..18c90b204a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -27,11 +27,18 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for regression.
*
- * @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ * @param predictionAndObservations an RDD of (prediction, observation) pairs
+ * @param throughOrigin True if the regression is through the origin. For example, in linear
+ * regression, it is true when the model is fit without an intercept.
*/
@Since("1.2.0")
-class RegressionMetrics @Since("1.2.0") (
- predictionAndObservations: RDD[(Double, Double)]) extends Logging {
+class RegressionMetrics @Since("2.0.0") (
+ predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
+ extends Logging {
+
+ @Since("1.2.0")
+ def this(predictionAndObservations: RDD[(Double, Double)]) =
+ this(predictionAndObservations, false)
/**
* An auxiliary constructor taking a DataFrame.
@@ -53,6 +60,8 @@ class RegressionMetrics @Since("1.2.0") (
)
summary
}
+
+ private lazy val SSy = math.pow(summary.normL2(0), 2)
private lazy val SSerr = math.pow(summary.normL2(1), 2)
private lazy val SStot = summary.variance(0) * (summary.count - 1)
private lazy val SSreg = {
@@ -102,9 +111,16 @@ class RegressionMetrics @Since("1.2.0") (
/**
* Returns R^2^, the unadjusted coefficient of determination.
* @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ * In case of regression through the origin, R^2^ is computed as 1 - SSerr / SSy instead.
+ * @see J. G. Eisenhauer, Regression through the Origin. Teaching Statistics 25, 76-80 (2003)
+ * [[https://online.stat.psu.edu/~ajw13/stat501/SpecialTopics/Reg_thru_origin.pdf]]
*/
@Since("1.2.0")
def r2: Double = {
- 1 - SSerr / SStot
+ if (throughOrigin) {
+ 1 - SSerr / SSy
+ } else {
+ 1 - SSerr / SStot
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 4b7f1be58f..f1d5173836 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -22,91 +22,115 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
+ val obs = List[Double](77, 85, 62, 55, 63, 88, 57, 81, 51)
+ val eps = 1E-5
test("regression metrics for unbiased (includes intercept term) predictor") {
/* Verify results in R:
- preds = c(2.25, -0.25, 1.75, 7.75)
- obs = c(3.0, -0.5, 2.0, 7.0)
-
- SStot = sum((obs - mean(obs))^2)
- SSreg = sum((preds - mean(obs))^2)
- SSerr = sum((obs - preds)^2)
-
- explainedVariance = SSreg / length(obs)
- explainedVariance
- > [1] 8.796875
- meanAbsoluteError = mean(abs(preds - obs))
- meanAbsoluteError
- > [1] 0.5
- meanSquaredError = mean((preds - obs)^2)
- meanSquaredError
- > [1] 0.3125
- rmse = sqrt(meanSquaredError)
- rmse
- > [1] 0.559017
- r2 = 1 - SSerr / SStot
- r2
- > [1] 0.9571734
+ y = c(77, 85, 62, 55, 63, 88, 57, 81, 51)
+ x = c(16, 22, 14, 10, 13, 19, 12, 18, 11)
+ df <- as.data.frame(cbind(x, y))
+ model <- lm(y ~ x, data=df)
+ preds = signif(predict(model), digits = 4)
+ preds
+ 1 2 3 4 5 6 7 8 9
+ 72.08 91.88 65.48 52.28 62.18 81.98 58.88 78.68 55.58
+ options(digits=8)
+ explainedVariance = mean((preds - mean(y))^2)
+ [1] 157.3
+ meanAbsoluteError = mean(abs(preds - y))
+ meanAbsoluteError
+ [1] 3.7355556
+ meanSquaredError = mean((preds - y)^2)
+ meanSquaredError
+ [1] 17.539511
+ rmse = sqrt(meanSquaredError)
+ rmse
+ [1] 4.18802
+ r2 = summary(model)$r.squared
+ r2
+ [1] 0.89968225
*/
- val predictionAndObservations = sc.parallelize(
- Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
+ val preds = List(72.08, 91.88, 65.48, 52.28, 62.18, 81.98, 58.88, 78.68, 55.58)
+ val predictionAndObservations = sc.parallelize(preds.zip(obs), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
+ assert(metrics.explainedVariance ~== 157.3 absTol eps,
"explained variance regression score mismatch")
- assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
- assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
- assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
+ assert(metrics.meanAbsoluteError ~== 3.7355556 absTol eps, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 17.539511 absTol eps, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 4.18802 absTol eps,
"root mean squared error mismatch")
- assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
+ assert(metrics.r2 ~== 0.89968225 absTol eps, "r2 score mismatch")
}
test("regression metrics for biased (no intercept term) predictor") {
/* Verify results in R:
- preds = c(2.5, 0.0, 2.0, 8.0)
- obs = c(3.0, -0.5, 2.0, 7.0)
-
- SStot = sum((obs - mean(obs))^2)
- SSreg = sum((preds - mean(obs))^2)
- SSerr = sum((obs - preds)^2)
-
- explainedVariance = SSreg / length(obs)
- explainedVariance
- > [1] 8.859375
- meanAbsoluteError = mean(abs(preds - obs))
- meanAbsoluteError
- > [1] 0.5
- meanSquaredError = mean((preds - obs)^2)
- meanSquaredError
- > [1] 0.375
- rmse = sqrt(meanSquaredError)
- rmse
- > [1] 0.6123724
- r2 = 1 - SSerr / SStot
- r2
- > [1] 0.9486081
+ y = c(77, 85, 62, 55, 63, 88, 57, 81, 51)
+ x = c(16, 22, 14, 10, 13, 19, 12, 18, 11)
+ df <- as.data.frame(cbind(x, y))
+ model <- lm(y ~ 0 + x, data=df)
+ preds = signif(predict(model), digits = 4)
+ preds
+ 1 2 3 4 5 6 7 8 9
+ 72.12 99.17 63.11 45.08 58.60 85.65 54.09 81.14 49.58
+ options(digits=8)
+ explainedVariance = mean((preds - mean(y))^2)
+ explainedVariance
+ [1] 294.88167
+ meanAbsoluteError = mean(abs(preds - y))
+ meanAbsoluteError
+ [1] 4.5888889
+ meanSquaredError = mean((preds - y)^2)
+ meanSquaredError
+ [1] 39.958711
+ rmse = sqrt(meanSquaredError)
+ rmse
+ [1] 6.3212903
+ r2 = summary(model)$r.squared
+ r2
+ [1] 0.99185395
*/
- val predictionAndObservations = sc.parallelize(
- Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
- val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
+ val preds = List(72.12, 99.17, 63.11, 45.08, 58.6, 85.65, 54.09, 81.14, 49.58)
+ val predictionAndObservations = sc.parallelize(preds.zip(obs), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations, true)
+ assert(metrics.explainedVariance ~== 294.88167 absTol eps,
"explained variance regression score mismatch")
- assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
- assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
- assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
+ assert(metrics.meanAbsoluteError ~== 4.5888889 absTol eps, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 39.958711 absTol eps, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 6.3212903 absTol eps,
"root mean squared error mismatch")
- assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
+ assert(metrics.r2 ~== 0.99185395 absTol eps, "r2 score mismatch")
}
test("regression metrics with complete fitting") {
- val predictionAndObservations = sc.parallelize(
- Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
+ /* Verify results in R:
+ y = c(77, 85, 62, 55, 63, 88, 57, 81, 51)
+ preds = y
+ explainedVariance = mean((preds - mean(y))^2)
+ explainedVariance
+ [1] 174.8395
+ meanAbsoluteError = mean(abs(preds - y))
+ meanAbsoluteError
+ [1] 0
+ meanSquaredError = mean((preds - y)^2)
+ meanSquaredError
+ [1] 0
+ rmse = sqrt(meanSquaredError)
+ rmse
+ [1] 0
+ r2 = 1 - sum((preds - y)^2)/sum((y - mean(y))^2)
+ r2
+ [1] 1
+ */
+ val preds = obs
+ val predictionAndObservations = sc.parallelize(preds.zip(obs), 2)
val metrics = new RegressionMetrics(predictionAndObservations)
- assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
+ assert(metrics.explainedVariance ~== 174.83951 absTol eps,
"explained variance regression score mismatch")
- assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
- assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
- assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5,
+ assert(metrics.meanAbsoluteError ~== 0.0 absTol eps, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.0 absTol eps, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.0 absTol eps,
"root mean squared error mismatch")
- assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
+ assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
}
}