author     wm624@hotmail.com <wm624@hotmail.com>    2017-02-22 11:50:24 -0800
committer  Felix Cheung <felixcheung@apache.org>    2017-02-22 11:50:24 -0800
commit     1f86e795b87ba93640062f29e87a032924d94b2a (patch)
tree       5489ebff9dd4106fd5f508365d19ff60dccf162b /R/pkg/R/mllib_classification.R
parent     e4065376d2b4eec178a119476fa95b26f440c076 (diff)
[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs
## What changes were proposed in this pull request?

This is a follow-up PR of #16800. While working on SPARK-19456, we found that "" should be considered a NULL column name and should not be set as the weight column. aggregationDepth should be exposed as an expert parameter.

## How was this patch tested?

Existing tests.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #16945 from wangmiao1981/svc.
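For illustration only (not part of the commit): a minimal SparkR sketch of how the patched spark.logit is meant to be called. It assumes a running SparkSession; the iris data frame, its column names, and the chosen parameter values are stand-ins.

library(SparkR)
sparkR.session()
# Illustrative data; createDataFrame replaces "." in column names with "_".
training <- suppressWarnings(createDataFrame(iris))
# weightCol left as NULL (or, after this change, "") means no weight column is set.
model <- spark.logit(training, Species ~ Sepal_Length + Sepal_Width, regParam = 0.5)
# aggregationDepth is the new expert parameter; raise it when there are many
# features or partitions.
weighted <- spark.logit(training, Species ~ Sepal_Length + Sepal_Width,
                        weightCol = "Petal_Width", aggregationDepth = 3)
summary(weighted)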
Diffstat (limited to 'R/pkg/R/mllib_classification.R')
-rw-r--r--    R/pkg/R/mllib_classification.R    13
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795faa..05bb952661 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
 #'                  excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
 #'                  is the original probability of that class and t is the class's threshold.
 #' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
 #' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model.
#' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
tol = 1E-6, family = "auto", standardization = TRUE,
- thresholds = 0.5, weightCol = NULL) {
+ thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
- if (is.null(weightCol)) {
- weightCol <- ""
+ if (!is.null(weightCol) && weightCol == "") {
+ weightCol <- NULL
+ } else if (!is.null(weightCol)) {
+ weightCol <- as.character(weightCol)
}
jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 as.numeric(elasticNetParam), as.integer(maxIter),
                                 as.numeric(tol), as.character(family),
                                 as.logical(standardization), as.array(thresholds),
-                                as.character(weightCol))
+                                weightCol, as.integer(aggregationDepth))
             new("LogisticRegressionModel", jobj = jobj)
           })
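As a reading aid only, the weightCol normalization added above can be restated as a standalone R function; the helper name normalizeWeightCol is hypothetical and does not exist in Spark.

# Hypothetical helper restating the patched logic for readability.
normalizeWeightCol <- function(weightCol) {
  if (!is.null(weightCol) && weightCol == "") {
    weightCol <- NULL        # "" is treated as "no weight column", same as NULL
  } else if (!is.null(weightCol)) {
    weightCol <- as.character(weightCol)  # anything else becomes a column name string
  }
  weightCol
}

normalizeWeightCol("")         # NULL, so the wrapper never sees an empty column name
normalizeWeightCol(NULL)       # NULL
normalizeWeightCol("weights")  # "weights"

With this change only a real column name, already coerced to character, is passed to LogisticRegressionWrapper.fit, which is why the as.character(weightCol) cast was removed from the callJStatic argument list.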