author Yanbo Liang <ybliang8@gmail.com> 2016-08-10 10:53:48 -0700
committer Shivaram Venkataraman <shivaram@cs.berkeley.edu> 2016-08-10 10:53:48 -0700
commit d4a9122430d6c3aeaaee32aa09d314016ff6ddc7
tree 19e191c481ca385c3fa93b62b5b573c44e5b637c /mllib
parent 19af298bb6d264adcf02f6f84c8dc1542b408507
[SPARK-16710][SPARKR][ML] spark.glm should support weightCol

## What changes were proposed in this pull request?

Training GLMs on weighted datasets is an important use case, but SparkR does not currently support it. In native R, users pass the `weights` argument to specify the weights vector. For `spark.glm`, users can now pass `weightCol`, which is consistent with MLlib.

## How was this patch tested?

Unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14346 from yanboliang/spark-16710.

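
For context, the commit message above says `weightCol` is consistent with MLlib. The following is a minimal sketch of what that looks like when calling the MLlib estimator directly from Scala, rather than through the `private[r]` wrapper this patch touches. The object name `WeightedGlmSketch`, the helper `fitToy`, the toy data, and the column names `y`, `x`, and `w` are all assumptions made for illustration; they do not appear in the patch.

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, GeneralizedLinearRegressionModel}
import org.apache.spark.sql.SparkSession

// Hypothetical example object; not part of the patch.
object WeightedGlmSketch {

  // Fits a Gaussian GLM with per-row weights on made-up data.
  def fitToy(spark: SparkSession): GeneralizedLinearRegressionModel = {
    import spark.implicits._

    // Toy rows: response y, one feature x, per-row weight w (all invented).
    val df = Seq(
      (1.0, 0.0, 1.0),
      (3.1, 1.0, 2.0),
      (4.9, 2.0, 1.0),
      (7.2, 3.0, 4.0)
    ).toDF("y", "x", "w")

    // RFormula produces the default "features" and "label" columns and
    // keeps the remaining columns, including the weight column "w".
    val prepared = new RFormula().setFormula("y ~ x").fit(df).transform(df)

    new GeneralizedLinearRegression()
      .setFamily("gaussian")
      .setLink("identity")
      .setWeightCol("w") // the setter the patch now calls from the R wrapper
      .fit(prepared)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("weighted-glm-sketch")
      .getOrCreate()
    val model = fitToy(spark)
    println(s"intercept=${model.intercept} coefficients=${model.coefficients}")
    spark.stop()
  }
}
```

Running this requires Spark MLlib on the classpath. The sketch mirrors the wrapper's own sequence (fit an `RFormula`, then a `GeneralizedLinearRegression`), but skips the `Pipeline` step for brevity.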
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 4
1 file changed, 3 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 5642abc645..0d3181d0ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -68,7 +68,8 @@ private[r] object GeneralizedLinearRegressionWrapper
       family: String,
       link: String,
       tol: Double,
-      maxIter: Int): GeneralizedLinearRegressionWrapper = {
+      maxIter: Int,
+      weightCol: String): GeneralizedLinearRegressionWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
     val rFormulaModel = rFormula.fit(data)
@@ -84,6 +85,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setFitIntercept(rFormula.hasIntercept)
       .setTol(tol)
       .setMaxIter(maxIter)
+      .setWeightCol(weightCol)
     val pipeline = new Pipeline()
       .setStages(Array(rFormulaModel, glr))
       .fit(data)