author    wm624@hotmail.com <wm624@hotmail.com>    2017-02-22 11:50:24 -0800
committer Felix Cheung <felixcheung@apache.org>   2017-02-22 11:50:24 -0800
commit    1f86e795b87ba93640062f29e87a032924d94b2a (patch)
tree      5489ebff9dd4106fd5f508365d19ff60dccf162b /R/pkg/R/mllib_classification.R
parent    e4065376d2b4eec178a119476fa95b26f440c076 (diff)
[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs
## What changes were proposed in this pull request?
This is a follow-up PR of #16800
While working on SPARK-19456, we found that "" should be considered a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter.
## How was this patch tested?
Existing tests.
Author: wm624@hotmail.com <wm624@hotmail.com>
Closes #16945 from wangmiao1981/svc.
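The new weightCol handling from the diff below can be illustrated as a standalone sketch. The helper name `normalizeWeightCol` is hypothetical (the real change inlines this logic in `spark.logit`), but the body mirrors the patched code: an empty string is now treated like NULL (no weight column), and any other value is coerced to character before being passed to the JVM wrapper.

    # Hypothetical helper mirroring the normalization added in this patch.
    normalizeWeightCol <- function(weightCol) {
      if (!is.null(weightCol) && weightCol == "") {
        # "" is considered a NULL column name and should not be set
        weightCol <- NULL
      } else if (!is.null(weightCol)) {
        weightCol <- as.character(weightCol)
      }
      weightCol
    }

    normalizeWeightCol("")    # NULL: empty string means "no weight column"
    normalizeWeightCol(NULL)  # NULL
    normalizeWeightCol("wt")  # "wt"

Previously the code went the other way (NULL was rewritten to "" and always passed with as.character), so the wrapper could not distinguish "no weight column" from an empty name.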
Diffstat (limited to 'R/pkg/R/mllib_classification.R')
 R/pkg/R/mllib_classification.R | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795faa..05bb952661 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
 #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
 #' is the original probability of that class and t is the class's threshold.
 #' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
                    tol = 1E-6, family = "auto", standardization = TRUE,
-                   thresholds = 0.5, weightCol = NULL) {
+                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }

             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol),
                                 as.character(family), as.logical(standardization), as.array(thresholds),
-                                as.character(weightCol))
+                                weightCol, as.integer(aggregationDepth))

             new("LogisticRegressionModel", jobj = jobj)
           })
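After this patch, callers can pass the expert aggregationDepth parameter to spark.logit. A hedged usage sketch (assumes a running SparkSession; the data frame `training` and its columns `label`, `f1`, `f2`, and `wt` are hypothetical):

    library(SparkR)
    df <- createDataFrame(training)
    model <- spark.logit(df, label ~ f1 + f2,
                         weightCol = "wt",      # "" would now be treated as NULL
                         aggregationDepth = 3)  # expert param; default 2 is fine for most cases
    summary(model)

Raising aggregationDepth only matters when the feature dimension or partition count is large, as the roxygen doc in the diff notes.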