diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-08-10 10:53:48 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2016-08-10 10:53:48 -0700 |
commit | d4a9122430d6c3aeaaee32aa09d314016ff6ddc7 (patch) | |
tree | 19e191c481ca385c3fa93b62b5b573c44e5b637c /mllib/src | |
parent | 19af298bb6d264adcf02f6f84c8dc1542b408507 (diff) | |
download | spark-d4a9122430d6c3aeaaee32aa09d314016ff6ddc7.tar.gz spark-d4a9122430d6c3aeaaee32aa09d314016ff6ddc7.tar.bz2 spark-d4a9122430d6c3aeaaee32aa09d314016ff6ddc7.zip |
[SPARK-16710][SPARKR][ML] spark.glm should support weightCol
## What changes were proposed in this pull request?
Training GLMs on weighted dataset is very important use cases, but it is not supported by SparkR currently. Users can pass argument ```weights``` to specify the weights vector in native R. For ```spark.glm```, we can pass in the ```weightCol``` which is consistent with MLlib.
## How was this patch tested?
Unit test.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #14346 from yanboliang/spark-16710.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 5642abc645..0d3181d0ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -68,7 +68,8 @@ private[r] object GeneralizedLinearRegressionWrapper family: String, link: String, tol: Double, - maxIter: Int): GeneralizedLinearRegressionWrapper = { + maxIter: Int, + weightCol: String): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula() .setFormula(formula) val rFormulaModel = rFormula.fit(data) @@ -84,6 +85,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) + .setWeightCol(weightCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) |