about summary refs log tree commit diff
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2017-02-22 11:50:24 -0800
committerFelix Cheung <felixcheung@apache.org>2017-02-22 11:50:24 -0800
commit1f86e795b87ba93640062f29e87a032924d94b2a (patch)
tree5489ebff9dd4106fd5f508365d19ff60dccf162b
parente4065376d2b4eec178a119476fa95b26f440c076 (diff)
downloadspark-1f86e795b87ba93640062f29e87a032924d94b2a.tar.gz
spark-1f86e795b87ba93640062f29e87a032924d94b2a.tar.bz2
spark-1f86e795b87ba93640062f29e87a032924d94b2a.zip
[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs
## What changes were proposed in this pull request? This is a follow-up PR of #16800. When doing SPARK-19456, we found that "" should be considered a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #16945 from wangmiao1981/svc.
-rw-r--r--R/pkg/R/generics.R2
-rw-r--r--R/pkg/R/mllib_classification.R13
-rw-r--r--R/pkg/R/mllib_regression.R24
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_classification.R10
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala7
8 files changed, 50 insertions, 19 deletions
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 11940d3560..647cbbdd82 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",
#' @rdname spark.survreg
#' @export
-setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
#' @rdname spark.svmLinear
#' @export
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795faa..05bb952661 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
#' is the original probability of that class and t is the class's threshold.
#' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#' or the number of partitions are large, this param could be adjusted to a larger size.
+#' This is an expert parameter. Default value should be good for most cases.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model.
#' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
tol = 1E-6, family = "auto", standardization = TRUE,
- thresholds = 0.5, weightCol = NULL) {
+ thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
as.numeric(elasticNetParam), as.integer(maxIter),
as.numeric(tol), as.character(family),
as.logical(standardization), as.array(thresholds),
- as.character(weightCol))
+ weightCol, as.integer(aggregationDepth))
new("LogisticRegressionModel", jobj = jobj)
})
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index 96ee220bc4..ac0578c4ab 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
}
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, tolower(family$family), family$link,
- tol, as.integer(maxIter), as.character(weightCol), regParam)
+ tol, as.integer(maxIter), weightCol, regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
@@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
- as.character(weightCol))
+ weightCol)
new("IsotonicRegressionModel", jobj = jobj)
})
@@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' Note that operator '.' is not supported currently.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#' or the number of partitions are large, this param could be adjusted to a larger size.
+#' This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/package=survival}
@@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula) {
+ function(data, formula, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
- "fit", formula, data@sdf)
+ "fit", formula, data@sdf, as.integer(aggregationDepth))
new("AFTSurvivalRegressionModel", jobj = jobj)
})
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 620f528f2e..459254d271 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -211,7 +211,15 @@ test_that("spark.logit", {
df <- createDataFrame(data)
model <- spark.logit(df, label ~ feature)
prediction <- collect(select(predict(model, df), "prediction"))
- expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+ expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+ # Test prediction with weightCol
+ weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
+ data2 <- as.data.frame(cbind(label, feature, weight))
+ df2 <- createDataFrame(data2)
+ model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
+ prediction2 <- collect(select(predict(model2, df2), "prediction"))
+ expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
})
test_that("spark.mlp", {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index bd965acf56..0bf543d888 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
}
- def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
+ def fit(
+ formula: String,
+ data: DataFrame,
+ aggregationDepth: Int): AFTSurvivalRegressionWrapper = {
val (rewritedFormula, censorCol) = formulaRewrite(formula)
@@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
.setFeaturesCol(rFormula.getFeaturesCol)
+ .setAggregationDepth(aggregationDepth)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 78f401f29b..cbd6cd1c79 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper
.setFitIntercept(rFormula.hasIntercept)
.setTol(tol)
.setMaxIter(maxIter)
- .setWeightCol(weightCol)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)
+
+ if (weightCol != null) glr.setWeightCol(weightCol)
+
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index 48632316f3..d31ebb46af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper
val isotonicRegression = new IsotonicRegression()
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
- .setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
+ if (weightCol != null) isotonicRegression.setWeightCol(weightCol)
+
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
.fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 645bc7247f..c96f99cb83 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper
family: String,
standardization: Boolean,
thresholds: Array[Double],
- weightCol: String
+ weightCol: String,
+ aggregationDepth: Int
): LogisticRegressionWrapper = {
val rFormula = new RFormula()
@@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper
.setFitIntercept(fitIntercept)
.setFamily(family)
.setStandardization(standardization)
- .setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+ .setAggregationDepth(aggregationDepth)
if (thresholds.length > 1) {
lr.setThresholds(thresholds)
@@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper
lr.setThreshold(thresholds(0))
}
+ if (weightCol != null) lr.setWeightCol(weightCol)
+
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)