author     wm624@hotmail.com <wm624@hotmail.com>    2017-02-22 11:50:24 -0800
committer  Felix Cheung <felixcheung@apache.org>    2017-02-22 11:50:24 -0800
commit     1f86e795b87ba93640062f29e87a032924d94b2a (patch)
tree       5489ebff9dd4106fd5f508365d19ff60dccf162b /R
parent     e4065376d2b4eec178a119476fa95b26f440c076 (diff)
download   spark-1f86e795b87ba93640062f29e87a032924d94b2a.tar.gz
           spark-1f86e795b87ba93640062f29e87a032924d94b2a.tar.bz2
           spark-1f86e795b87ba93640062f29e87a032924d94b2a.zip
[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs
## What changes were proposed in this pull request?

This is a follow-up PR of #16800. While working on SPARK-19456, we found that "" should be considered a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter.

## How was this patch tested?

Existing tests.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #16945 from wangmiao1981/svc.
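The weightCol handling this patch applies in each wrapper, written out as a minimal standalone sketch (the helper name is hypothetical; in the actual diff below the logic is inlined in each setMethod):

# Hypothetical helper illustrating the new weightCol convention:
# "" and NULL both mean "no weight column is set"; any other value is
# coerced to a character column name before the call into the JVM wrapper.
normalizeWeightCol <- function(weightCol) {
  if (!is.null(weightCol) && weightCol == "") {
    NULL
  } else if (!is.null(weightCol)) {
    as.character(weightCol)
  } else {
    NULL
  }
}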
Diffstat (limited to 'R')
-rw-r--r--  R/pkg/R/generics.R                                     |  2
-rw-r--r--  R/pkg/R/mllib_classification.R                         | 13
-rw-r--r--  R/pkg/R/mllib_regression.R                             | 24
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib_classification.R  | 10
4 files changed, 35 insertions(+), 14 deletions(-)
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 11940d3560..647cbbdd82 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",
#' @rdname spark.survreg
#' @export
-setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
#' @rdname spark.svmLinear
#' @export
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795faa..05bb952661 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
#' is the original probability of that class and t is the class's threshold.
#' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#' or the number of partitions are large, this param could be adjusted to a larger size.
+#' This is an expert parameter. Default value should be good for most cases.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model.
#' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
tol = 1E-6, family = "auto", standardization = TRUE,
- thresholds = 0.5, weightCol = NULL) {
+ thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
as.numeric(elasticNetParam), as.integer(maxIter),
as.numeric(tol), as.character(family),
as.logical(standardization), as.array(thresholds),
- as.character(weightCol))
+ weightCol, as.integer(aggregationDepth))
new("LogisticRegressionModel", jobj = jobj)
})
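A hedged usage sketch of the updated spark.logit signature; the session setup, toy data, and column names (label, feature, weight) are illustrative and not part of this patch:

library(SparkR)
sparkR.session()

# Hypothetical toy data with a per-row weight column.
training <- createDataFrame(data.frame(label   = c(0, 0, 0, 1, 1),
                                       feature = c(1.1, 0.9, 1.0, 5.0, 4.8),
                                       weight  = c(2, 2, 2, 1, 1)))

# weightCol = NULL (the default) now means "no weight column" rather than "".
# aggregationDepth is the newly exposed expert parameter and defaults to 2.
model <- spark.logit(training, label ~ feature,
                     weightCol = "weight", aggregationDepth = 2)
head(select(predict(model, training), "prediction"))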
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index 96ee220bc4..ac0578c4ab 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
}
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, tolower(family$family), family$link,
- tol, as.integer(maxIter), as.character(weightCol), regParam)
+ tol, as.integer(maxIter), weightCol, regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
@@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
- as.character(weightCol))
+ weightCol)
new("IsotonicRegressionModel", jobj = jobj)
})
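The regression wrappers follow the same convention; a brief sketch, reusing the hypothetical `training` frame from the spark.logit example above:

# Passing NULL (the default) or "" for weightCol now means no weight column;
# a non-empty name is forwarded to the JVM wrapper as a character string.
glmModel    <- spark.glm(training, label ~ feature, family = "gaussian",
                         weightCol = "weight")
isoregModel <- spark.isoreg(training, label ~ feature, weightCol = NULL)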
@@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' Note that operator '.' is not supported currently.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#' or the number of partitions are large, this param could be adjusted to a larger size.
+#' This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/package=survival}
@@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula) {
+ function(data, formula, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
- "fit", formula, data@sdf)
+ "fit", formula, data@sdf, as.integer(aggregationDepth))
new("AFTSurvivalRegressionModel", jobj = jobj)
})
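For the updated spark.survreg, a hedged sketch modeled on the existing SparkR documentation example with the survival package's ovarian dataset:

library(survival)
# suppressWarnings: createDataFrame warns when it rewrites dotted column
# names (e.g. ecog.ps becomes ecog_ps).
ovarianDF <- suppressWarnings(createDataFrame(ovarian))

# aggregationDepth is optional and defaults to 2; raising it may help when the
# feature dimension or the number of partitions is large.
aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx,
                          aggregationDepth = 2)
summary(aftModel)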
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 620f528f2e..459254d271 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -211,7 +211,15 @@ test_that("spark.logit", {
df <- createDataFrame(data)
model <- spark.logit(df, label ~ feature)
prediction <- collect(select(predict(model, df), "prediction"))
- expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+ expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+ # Test prediction with weightCol
+ weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
+ data2 <- as.data.frame(cbind(label, feature, weight))
+ df2 <- createDataFrame(data2)
+ model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
+ prediction2 <- collect(select(predict(model2, df2), "prediction"))
+ expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
})
test_that("spark.mlp", {