diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-10-06 21:10:17 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-10-06 21:10:17 -0700 |
commit | 3713bb199142c5e06e2e527c99650f02f41f47b1 (patch) | |
tree | 228624cca24862f0e8bf8c1b1d76a5b6647e4c74 /mllib/src/main | |
parent | 49d11d49983fbe270f4df4fb1e34b5fbe854c5ec (diff) | |
download | spark-3713bb199142c5e06e2e527c99650f02f41f47b1.tar.gz spark-3713bb199142c5e06e2e527c99650f02f41f47b1.tar.bz2 spark-3713bb199142c5e06e2e527c99650f02f41f47b1.zip |
[SPARK-17792][ML] L-BFGS solver for linear regression does not accept general numeric label column types
## What changes were proposed in this pull request?
Before, we computed `instances` in LinearRegression in two spots, even though they did the same thing. One of them did not cast the label column to `DoubleType`. This patch consolidates the computation and always casts the label column to `DoubleType`.
## How was this patch tested?
Added a unit test to check all solvers. This test failed before this patch.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #15364 from sethah/linreg_numeric_type.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 536c58f998..025ed20c75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -188,17 +188,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select( + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. (SPARK-10668) - val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), $(standardization), true) @@ -221,12 +222,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) |