From 1f86e795b87ba93640062f29e87a032924d94b2a Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 22 Feb 2017 11:50:24 -0800 Subject: [SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs ## What changes were proposed in this pull request? This is a follow-up PR of #16800 When doing SPARK-19456, we found that "" should be consider a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com Closes #16945 from wangmiao1981/svc. --- .../scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala | 6 +++++- .../org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 4 +++- .../scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala | 3 ++- .../scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala | 7 +++++-- 4 files changed, 15 insertions(+), 5 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index bd965acf56..0bf543d888 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg } - def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = { + def fit( + formula: String, + data: DataFrame, + aggregationDepth: Int): AFTSurvivalRegressionWrapper = { val (rewritedFormula, censorCol) = formulaRewrite(formula) @@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg .setCensorCol(censorCol) .setFitIntercept(rFormula.hasIntercept) .setFeaturesCol(rFormula.getFeaturesCol) + .setAggregationDepth(aggregationDepth) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, aft)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 78f401f29b..cbd6cd1c79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) - .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + + if (weightCol != null) glr.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index 48632316f3..d31ebb46af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper val isotonicRegression = new IsotonicRegression() .setIsotonic(isotonic) .setFeatureIndex(featureIndex) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) + if (weightCol != null) isotonicRegression.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, isotonicRegression)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 645bc7247f..c96f99cb83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper family: String, standardization: Boolean, thresholds: Array[Double], - weightCol: String + weightCol: String, + aggregationDepth: Int ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper .setFitIntercept(fitIntercept) .setFamily(family) .setStandardization(standardization) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + .setAggregationDepth(aggregationDepth) if (thresholds.length > 1) { lr.setThresholds(thresholds) @@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper lr.setThreshold(thresholds(0)) } + if (weightCol != null) lr.setWeightCol(weightCol) + val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) .setOutputCol(PREDICTED_LABEL_COL) -- cgit v1.2.3