aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2016-12-07 20:23:28 -0800
committer Yanbo Liang <ybliang8@gmail.com> 2016-12-07 20:23:28 -0800
commit 97255497d885f0f8ccfc808e868bc8aa5e4d1063 (patch)
tree 5af53d4a575e9f073e6d22f681d0b370eacd2e84
parent 82253617f5b3cdbd418c48f94e748651ee80077e (diff)
download spark-97255497d885f0f8ccfc808e868bc8aa5e4d1063.tar.gz
spark-97255497d885f0f8ccfc808e868bc8aa5e4d1063.tar.bz2
spark-97255497d885f0f8ccfc808e868bc8aa5e4d1063.zip
[SPARK-18326][SPARKR][ML] Review SparkR ML wrappers API for 2.1
## What changes were proposed in this pull request? Review the SparkR ML wrappers API for the 2.1 release, addressing mainly two issues: * Remove ```probabilityCol``` from the argument list of ```spark.logit``` and ```spark.randomForest```. It is only used when making predictions, so it should be an argument of ```predict``` instead; we will work on this at [SPARK-18618](https://issues.apache.org/jira/browse/SPARK-18618) in the next release cycle. * Fix the ```spark.als``` params to make them consistent with MLlib. ## How was this patch tested? Existing tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #16169 from yanboliang/spark-18326.
-rw-r--r-- R/pkg/R/mllib.R 23
-rw-r--r-- R/pkg/inst/tests/testthat/test_mllib.R 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala 2
4 files changed, 13 insertions, 20 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 074e9cbebe..632e4add64 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -733,7 +733,6 @@ setMethod("predict", signature(object = "KMeansModel"),
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
#' is the original probability of that class and t is the class's threshold.
#' @param weightCol The weight column name.
-#' @param probabilityCol column name for predicted class conditional probabilities.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model
#' @rdname spark.logit
@@ -772,7 +771,7 @@ setMethod("predict", signature(object = "KMeansModel"),
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
tol = 1E-6, family = "auto", standardization = TRUE,
- thresholds = 0.5, weightCol = NULL, probabilityCol = "probability") {
+ thresholds = 0.5, weightCol = NULL) {
formula <- paste(deparse(formula), collapse = "")
if (is.null(weightCol)) {
@@ -784,7 +783,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula")
as.numeric(elasticNetParam), as.integer(maxIter),
as.numeric(tol), as.character(family),
as.logical(standardization), as.array(thresholds),
- as.character(weightCol), as.character(probabilityCol))
+ as.character(weightCol))
new("LogisticRegressionModel", jobj = jobj)
})
@@ -1425,7 +1424,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
#' @param rank rank of the matrix factorization (> 0).
-#' @param reg regularization parameter (>= 0).
+#' @param regParam regularization parameter (>= 0).
#' @param maxIter maximum number of iterations (>= 0).
#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
#' @param implicitPrefs logical value indicating whether to use implicit preference.
@@ -1464,21 +1463,21 @@ setMethod("predict", signature(object = "GaussianMixtureModel"),
#'
#' # set other arguments
#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
-#' reg = 0.1, nonnegative = TRUE)
+#' regParam = 0.1, nonnegative = TRUE)
#' statsS <- summary(modelS)
#' }
#' @note spark.als since 2.1.0
setMethod("spark.als", signature(data = "SparkDataFrame"),
function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
- rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE,
+ rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE,
implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
checkpointInterval = 10, seed = 0) {
if (!is.numeric(rank) || rank <= 0) {
stop("rank should be a positive number.")
}
- if (!is.numeric(reg) || reg < 0) {
- stop("reg should be a nonnegative number.")
+ if (!is.numeric(regParam) || regParam < 0) {
+ stop("regParam should be a nonnegative number.")
}
if (!is.numeric(maxIter) || maxIter <= 0) {
stop("maxIter should be a positive number.")
@@ -1486,7 +1485,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
"fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
- reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
+ regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
as.integer(numUserBlocks), as.integer(numItemBlocks),
as.integer(checkpointInterval), as.integer(seed))
new("ALSModel", jobj = jobj)
@@ -1684,8 +1683,6 @@ print.summary.KSTest <- function(x, ...) {
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#' can speed up training of deeper trees. Users can set how often should the
#' cache be checkpointed or disable it by setting checkpointInterval.
-#' @param probabilityCol column name for predicted class conditional probabilities, only for
-#' classification.
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -1720,7 +1717,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
- maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") {
+ maxMemoryInMB = 256, cacheNodeIds = FALSE) {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
@@ -1749,7 +1746,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
impurity, as.integer(minInstancesPerNode),
as.numeric(minInfoGain), as.integer(checkpointInterval),
as.character(featureSubsetStrategy), seed,
- as.numeric(subsamplingRate), as.character(probabilityCol),
+ as.numeric(subsamplingRate),
as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
new("RandomForestClassificationModel", jobj = jobj)
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 4758e40e41..53833ee2f3 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -926,10 +926,10 @@ test_that("spark.posterior and spark.perplexity", {
test_that("spark.als", {
data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
- list(2, 1, 1.0), list(2, 2, 5.0))
+ list(2, 1, 1.0), list(2, 2, 5.0))
df <- createDataFrame(data, c("user", "item", "score"))
model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
- rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+ rank = 10, maxIter = 5, seed = 0, regParam = 0.1)
stats <- summary(model)
expect_equal(stats$rank, 10)
test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 7f0f3cea21..645bc7247f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -96,8 +96,7 @@ private[r] object LogisticRegressionWrapper
family: String,
standardization: Boolean,
thresholds: Array[Double],
- weightCol: String,
- probabilityCol: String
+ weightCol: String
): LogisticRegressionWrapper = {
val rFormula = new RFormula()
@@ -123,7 +122,6 @@ private[r] object LogisticRegressionWrapper
.setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
- .setProbabilityCol(probabilityCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
if (thresholds.length > 1) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 0b860e5af9..366f375b58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -76,7 +76,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
featureSubsetStrategy: String,
seed: String,
subsamplingRate: Double,
- probabilityCol: String,
maxMemoryInMB: Int,
cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
@@ -102,7 +101,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
.setSubsamplingRate(subsamplingRate)
.setMaxMemoryInMB(maxMemoryInMB)
.setCacheNodeIds(cacheNodeIds)
- .setProbabilityCol(probabilityCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)