From 1f86e795b87ba93640062f29e87a032924d94b2a Mon Sep 17 00:00:00 2001
From: "wm624@hotmail.com" <wm624@hotmail.com>
Date: Wed, 22 Feb 2017 11:50:24 -0800
Subject: [SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs

## What changes were proposed in this pull request?

This is a follow-up PR to #16800.

While working on SPARK-19456, we found that an empty string ("") should be treated as a NULL column name and should not be set.

aggregationDepth should be exposed as an expert parameter.

## How was this patch tested?

Existing tests.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #16945 from wangmiao1981/svc.
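A minimal usage sketch of the API after this change (the toy data frame, the column names label, feature and weight, and the aggregationDepth value of 4 are illustrative only, not taken from the patch):

```r
library(SparkR)
sparkR.session()

# Toy data with an optional per-row weight column (illustrative values).
df <- createDataFrame(data.frame(label   = c(0, 0, 1, 1, 0),
                                 feature = c(1.1, 2.3, 5.4, 6.2, 0.7),
                                 weight  = c(2, 2, 1, 1, 2)))

# These three calls are now equivalent: an empty string is normalized to NULL,
# i.e. "no weight column", instead of being passed through to the JVM wrapper.
m1 <- spark.logit(df, label ~ feature)
m2 <- spark.logit(df, label ~ feature, weightCol = NULL)
m3 <- spark.logit(df, label ~ feature, weightCol = "")

# Weighted fit, raising the expert aggregationDepth parameter above its default of 2.
m4 <- spark.logit(df, label ~ feature, weightCol = "weight", aggregationDepth = 4)
head(predict(m4, df))
```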
---
 R/pkg/R/generics.R                             |  2 +-
 R/pkg/R/mllib_classification.R                 | 13 ++++++++----
 R/pkg/R/mllib_regression.R                     | 24 ++++++++++++++--------
 .../tests/testthat/test_mllib_classification.R | 10 ++++++++-
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 11940d3560..647cbbdd82 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",
 
 #' @rdname spark.survreg
 #' @export
-setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
 
 #' @rdname spark.svmLinear
 #' @export
diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R
index fa0d795faa..05bb952661 100644
--- a/R/pkg/R/mllib_classification.R
+++ b/R/pkg/R/mllib_classification.R
@@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
 #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
 #' is the original probability of that class and t is the class's threshold.
 #' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
 #' @param ... additional arguments passed to the method.
 #' @return \code{spark.logit} returns a fitted logistic regression model.
 #' @rdname spark.logit
@@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) {
 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
           function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
                    tol = 1E-6, family = "auto", standardization = TRUE,
-                   thresholds = 0.5, weightCol = NULL) {
+                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
 
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
@@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
                                 as.numeric(elasticNetParam), as.integer(maxIter),
                                 as.numeric(tol), as.character(family),
                                 as.logical(standardization), as.array(thresholds),
-                                as.character(weightCol))
+                                weightCol, as.integer(aggregationDepth))
             new("LogisticRegressionModel", jobj = jobj)
           })
 
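The mllib_regression.R changes below apply the same weightCol normalization to spark.glm and spark.isoreg. As a rough usage sketch of that side of the change (the data frame and the column names y, x and w are invented for illustration, not part of the patch):

```r
# Gaussian GLM and isotonic regression with per-row weights; as with spark.logit,
# omitting weightCol, passing NULL, or passing "" all now mean "no weight column".
df <- createDataFrame(data.frame(y = c(1.0, 2.1, 2.9, 4.2),
                                 x = c(1, 2, 3, 4),
                                 w = c(1, 1, 2, 2)))

glmModel <- spark.glm(df, y ~ x, family = "gaussian", weightCol = "w")
summary(glmModel)

isoModel <- spark.isoreg(df, y ~ x, weightCol = "w")
head(predict(isoModel, df))
```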
diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R
index 96ee220bc4..ac0578c4ab 100644
--- a/R/pkg/R/mllib_regression.R
+++ b/R/pkg/R/mllib_regression.R
@@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
             }
 
             formula <- paste(deparse(formula), collapse = "")
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             # For known families, Gamma is upper-cased
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                 "fit", formula, data@sdf, tolower(family$family), family$link,
-                                tol, as.integer(maxIter), as.character(weightCol), regParam)
+                                tol, as.integer(maxIter), weightCol, regParam)
             new("GeneralizedLinearRegressionModel", jobj = jobj)
           })
 
@@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
           function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
             formula <- paste(deparse(formula), collapse = "")
 
-            if (is.null(weightCol)) {
-              weightCol <- ""
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
             }
 
             jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
                                 data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
-                                as.character(weightCol))
+                                weightCol)
 
             new("IsotonicRegressionModel", jobj = jobj)
           })
@@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
 #' @param formula a symbolic description of the model to be fitted. Currently only a few formula
 #'                operators are supported, including '~', ':', '+', and '-'.
 #'                Note that operator '.' is not supported currently.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#'                         or the number of partitions are large, this param could be adjusted to a larger size.
+#'                         This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
 #' @return \code{spark.survreg} returns a fitted AFT survival regression model.
 #' @rdname spark.survreg
 #' @seealso survival: \url{https://cran.r-project.org/package=survival}
@@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
 #' }
 #' @note spark.survreg since 2.0.0
 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula) {
+          function(data, formula, aggregationDepth = 2) {
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
-                                "fit", formula, data@sdf)
+                                "fit", formula, data@sdf, as.integer(aggregationDepth))
             new("AFTSurvivalRegressionModel", jobj = jobj)
           })
 
diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R
index 620f528f2e..459254d271 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_classification.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R
@@ -211,7 +211,15 @@ test_that("spark.logit", {
   df <- createDataFrame(data)
   model <- spark.logit(df, label ~ feature)
   prediction <- collect(select(predict(model, df), "prediction"))
-  expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
+  expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))
+
+  # Test prediction with weightCol
+  weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
+  data2 <- as.data.frame(cbind(label, feature, weight))
+  df2 <- createDataFrame(data2)
+  model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
+  prediction2 <- collect(select(predict(model2, df2), "prediction"))
+  expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
 })
 
 test_that("spark.mlp", {
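The new test above exercises the weightCol path of spark.logit: up-weighting the zero-labelled rows makes every prediction come out as "0.0". On the regression side, a hedged sketch of spark.survreg with the newly exposed expert parameter, loosely following the ovarian-data example in the SparkR documentation (the aggregationDepth value of 3 is arbitrary and only shows that the parameter is now accepted):

```r
library(survival)  # provides the ovarian data set and Surv()

# createDataFrame rewrites "." in column names to "_", hence ecog_ps below.
ovarianDF <- suppressWarnings(createDataFrame(ovarian))
aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx,
                          aggregationDepth = 3)
summary(aftModel)
```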