aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2017-02-22 11:50:24 -0800
committerFelix Cheung <felixcheung@apache.org>2017-02-22 11:50:24 -0800
commit1f86e795b87ba93640062f29e87a032924d94b2a (patch)
tree5489ebff9dd4106fd5f508365d19ff60dccf162b /mllib
parente4065376d2b4eec178a119476fa95b26f440c076 (diff)
downloadspark-1f86e795b87ba93640062f29e87a032924d94b2a.tar.gz
spark-1f86e795b87ba93640062f29e87a032924d94b2a.tar.bz2
spark-1f86e795b87ba93640062f29e87a032924d94b2a.zip
[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs
## What changes were proposed in this pull request? This is a follow-up PR of #16800 When doing SPARK-19456, we found that "" should be consider a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #16945 from wangmiao1981/svc.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala7
4 files changed, 15 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index bd965acf56..0bf543d888 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
}
- def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+ def fit(
+ formula: String,
+ data: DataFrame,
+ aggregationDepth: Int): AFTSurvivalRegressionWrapper = {
val (rewritedFormula, censorCol) = formulaRewrite(formula)
@@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
.setFeaturesCol(rFormula.getFeaturesCol)
+ .setAggregationDepth(aggregationDepth)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 78f401f29b..cbd6cd1c79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper
.setFitIntercept(rFormula.hasIntercept)
.setTol(tol)
.setMaxIter(maxIter)
- .setWeightCol(weightCol)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)
+
+ if (weightCol != null) glr.setWeightCol(weightCol)
+
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index 48632316f3..d31ebb46af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper
val isotonicRegression = new IsotonicRegression()
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
- .setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
+ if (weightCol != null) isotonicRegression.setWeightCol(weightCol)
+
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
.fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 645bc7247f..c96f99cb83 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper
family: String,
standardization: Boolean,
thresholds: Array[Double],
- weightCol: String
+ weightCol: String,
+ aggregationDepth: Int
): LogisticRegressionWrapper = {
val rFormula = new RFormula()
@@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper
.setFitIntercept(fitIntercept)
.setFamily(family)
.setStandardization(standardization)
- .setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+ .setAggregationDepth(aggregationDepth)
if (thresholds.length > 1) {
lr.setThresholds(thresholds)
@@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper
lr.setThreshold(thresholds(0))
}
+ if (weightCol != null) lr.setWeightCol(weightCol)
+
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)