path: root/R/pkg
author     Felix Cheung <felixcheung_m@hotmail.com>  2016-10-30 16:19:19 -0700
committer  Felix Cheung <felixcheung@apache.org>     2016-10-30 16:19:19 -0700
commit     b6879b8b3518c71c23262554fcb0fdad60287011 (patch)
tree       feaa0cdde6aee163533cd083f4fcbc518b4f3b20 /R/pkg
parent     2881a2d1d1a650a91df2c6a01275eba14a43b42a (diff)
download   spark-b6879b8b3518c71c23262554fcb0fdad60287011.tar.gz
           spark-b6879b8b3518c71c23262554fcb0fdad60287011.tar.bz2
           spark-b6879b8b3518c71c23262554fcb0fdad60287011.zip
[SPARK-16137][SPARKR] randomForest for R
## What changes were proposed in this pull request?

Random Forest Regression and Classification for R. Clean-up/reordering of generics.R.

## How was this patch tested?

manual tests, unit tests

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #15607 from felixcheung/rrandomforest.
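For orientation, a minimal usage sketch of the API this patch adds, adapted from the
roxygen examples in mllib.R below (assumes a running SparkR session; not part of the patch):

library(SparkR)
sparkR.session()

# Regression: fit on the longley dataset, then summarize and predict
df <- createDataFrame(longley)
model <- spark.randomForest(df, Employed ~ ., type = "regression",
                            maxDepth = 5, maxBins = 16)
summary(model)
predictions <- predict(model, df)

# Classification: iris with underscored column names as produced by createDataFrame
df2 <- createDataFrame(iris)
model2 <- spark.randomForest(df2, Species ~ Petal_Length + Petal_Width, "classification")

sparkR.session.stop()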
Diffstat (limited to 'R/pkg')
-rw-r--r--  R/pkg/NAMESPACE                         |   9
-rw-r--r--  R/pkg/R/generics.R                      |  66
-rw-r--r--  R/pkg/R/mllib.R                         | 252
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib.R  |  68
4 files changed, 361 insertions(+), 34 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7a89c01fee..9cd6269f9a 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -44,7 +44,8 @@ exportMethods("glm",
"spark.gaussianMixture",
"spark.als",
"spark.kstest",
- "spark.logit")
+ "spark.logit",
+ "spark.randomForest")
# Job group lifecycle management methods
export("setJobGroup",
@@ -350,7 +351,9 @@ export("as.DataFrame",
"uncacheTable",
"print.summary.GeneralizedLinearRegressionModel",
"read.ml",
- "print.summary.KSTest")
+ "print.summary.KSTest",
+ "print.summary.RandomForestRegressionModel",
+ "print.summary.RandomForestClassificationModel")
export("structField",
"structField.jobj",
@@ -375,6 +378,8 @@ S3method(print, structField)
S3method(print, structType)
S3method(print, summary.GeneralizedLinearRegressionModel)
S3method(print, summary.KSTest)
+S3method(print, summary.RandomForestRegressionModel)
+S3method(print, summary.RandomForestClassificationModel)
S3method(structField, character)
S3method(structField, jobj)
S3method(structType, jobj)
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 107e1c638b..0271b26a10 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1310,9 +1310,11 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
#' @export
setGeneric("year", function(x) { standardGeneric("year") })
-#' @rdname spark.glm
+###################### Spark.ML Methods ##########################
+
+#' @rdname fitted
#' @export
-setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
+setGeneric("fitted")
#' @param x,y For \code{glm}: logical values indicating whether the response vector
#' and model matrix used in the fitting process should be returned as
@@ -1332,13 +1334,38 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
#' @export
setGeneric("rbind", signature = "...")
+#' @rdname spark.als
+#' @export
+setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
+
+#' @rdname spark.gaussianMixture
+#' @export
+setGeneric("spark.gaussianMixture",
+ function(data, formula, ...) { standardGeneric("spark.gaussianMixture") })
+
+#' @rdname spark.glm
+#' @export
+setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
+
+#' @rdname spark.isoreg
+#' @export
+setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
+
#' @rdname spark.kmeans
#' @export
setGeneric("spark.kmeans", function(data, formula, ...) { standardGeneric("spark.kmeans") })
-#' @rdname fitted
+#' @rdname spark.kstest
#' @export
-setGeneric("fitted")
+setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+
+#' @rdname spark.logit
+#' @export
+setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") })
#' @rdname spark.mlp
#' @export
@@ -1348,13 +1375,14 @@ setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") })
#' @export
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
-#' @rdname spark.survreg
+#' @rdname spark.randomForest
#' @export
-setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
+setGeneric("spark.randomForest",
+ function(data, formula, ...) { standardGeneric("spark.randomForest") })
-#' @rdname spark.lda
+#' @rdname spark.survreg
#' @export
-setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
#' @rdname spark.lda
#' @export
@@ -1364,20 +1392,6 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark
#' @export
setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
-#' @rdname spark.isoreg
-#' @export
-setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
-
-#' @rdname spark.gaussianMixture
-#' @export
-setGeneric("spark.gaussianMixture",
- function(data, formula, ...) {
- standardGeneric("spark.gaussianMixture")
- })
-
-#' @rdname spark.logit
-#' @export
-setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.logit") })
#' @param object a fitted ML model object.
#' @param path the directory where the model is saved.
@@ -1385,11 +1399,3 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark.
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
-
-#' @rdname spark.als
-#' @export
-setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
-
-#' @rdname spark.kstest
-#' @export
-setGeneric("spark.kstest", function(data, ...) { standardGeneric("spark.kstest") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 629f284b79..7a220b8d53 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -102,6 +102,20 @@ setClass("KSTest", representation(jobj = "jobj"))
#' @note LogisticRegressionModel since 2.1.0
setClass("LogisticRegressionModel", representation(jobj = "jobj"))
+#' S4 class that represents a RandomForestRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel
+#' @export
+#' @note RandomForestRegressionModel since 2.1.0
+setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a RandomForestClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel
+#' @export
+#' @note RandomForestClassificationModel since 2.1.0
+setClass("RandomForestClassificationModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -112,7 +126,7 @@ setClass("LogisticRegressionModel", representation(jobj = "jobj"))
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
-#' @seealso \link{spark.survreg}
+#' @seealso \link{spark.randomForest}, \link{spark.survreg},
#' @seealso \link{read.ml}
NULL
@@ -125,7 +139,8 @@ NULL
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
-#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}
NULL
write_internal <- function(object, path, overwrite = FALSE) {
@@ -1122,6 +1137,10 @@ read.ml <- function(path) {
new("ALSModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) {
new("LogisticRegressionModel", jobj = jobj)
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) {
+ new("RandomForestRegressionModel", jobj = jobj)
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
+ new("RandomForestClassificationModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}
@@ -1617,3 +1636,232 @@ print.summary.KSTest <- function(x, ...) {
cat(summaryStr, "\n")
invisible(x)
}
+
+#' Random Forest Model for Regression and Classification
+#'
+#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on
+#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest
+#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
+#' save/load fitted models.
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest}
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', ':', '+', and '-'.
+#' @param type type of model to fit, one of "regression" or "classification".
+#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5)
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
+#' how to split on features at each node. More bins give higher granularity. Must be
+#' >= 2 and >= number of categories in any categorical feature. (default = 32)
+#' @param numTrees Number of trees to train (>= 1).
+#' @param impurity Criterion used for information gain calculation.
+#' For regression, must be "variance". For classification, must be one of
+#' "entropy" and "gini". (default = gini)
+#' @param minInstancesPerNode Minimum number of instances each child must have after split.
+#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
+#' @param checkpointInterval checkpoint interval (>= 1), or -1 to disable checkpointing.
+#' @param featureSubsetStrategy The number of features to consider for splits at each tree node.
+#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n].
+#' @param seed integer seed for random number generation.
+#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in
+#' range (0, 1]. (default = 1.0)
+#' @param probabilityCol column name for predicted class conditional probabilities, only for
+#' classification. (default = "probability")
+#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
+#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances
+#'                     with nodes; if TRUE, it will cache node IDs for each instance, which can
+#'                     speed up training of deeper trees.
+#' @param ... additional arguments passed to the method.
+#' @aliases spark.randomForest,SparkDataFrame,formula-method
+#' @return \code{spark.randomForest} returns a fitted Random Forest model.
+#' @rdname spark.randomForest
+#' @name spark.randomForest
+#' @export
+#' @examples
+#' \dontrun{
+#' # fit a Random Forest Regression Model
+#' df <- createDataFrame(longley)
+#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # fit a Random Forest Classification Model
+#' df <- createDataFrame(iris)
+#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification")
+#' }
+#' @note spark.randomForest since 2.1.0
+setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, type = c("regression", "classification"),
+ maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
+ minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
+ featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
+ probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ type <- match.arg(type)
+ formula <- paste(deparse(formula), collapse = "")
+ if (!is.null(seed)) {
+ seed <- as.character(as.integer(seed))
+ }
+ switch(type,
+ regression = {
+ if (is.null(impurity)) impurity <- "variance"
+ impurity <- match.arg(impurity, "variance")
+ jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper",
+ "fit", data@sdf, formula, as.integer(maxDepth),
+ as.integer(maxBins), as.integer(numTrees),
+ impurity, as.integer(minInstancesPerNode),
+ as.numeric(minInfoGain), as.integer(checkpointInterval),
+ as.character(featureSubsetStrategy), seed,
+ as.numeric(subsamplingRate),
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ new("RandomForestRegressionModel", jobj = jobj)
+ },
+ classification = {
+ if (is.null(impurity)) impurity <- "gini"
+ impurity <- match.arg(impurity, c("gini", "entropy"))
+ jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
+ "fit", data@sdf, formula, as.integer(maxDepth),
+ as.integer(maxBins), as.integer(numTrees),
+ impurity, as.integer(minInstancesPerNode),
+ as.numeric(minInfoGain), as.integer(checkpointInterval),
+ as.character(featureSubsetStrategy), seed,
+ as.numeric(subsamplingRate), as.character(probabilityCol),
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ new("RandomForestClassificationModel", jobj = jobj)
+ }
+ )
+ })
+
+# Makes predictions from a Random Forest Regression model or Classification model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction".
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestRegressionModel-method
+#' @export
+#' @note predict(randomForestRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestRegressionModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+#' @rdname spark.randomForest
+#' @aliases predict,RandomForestClassificationModel-method
+#' @export
+#' @note predict(randomForestClassificationModel) since 2.1.0
+setMethod("predict", signature(object = "RandomForestClassificationModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+# Save the Random Forest Regression or Classification model to the input path.
+
+#' @param object A fitted Random Forest regression model or classification model.
+#' @param path The directory where the model is saved.
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#'        which means an exception is thrown if the output path exists.
+#'
+#' @aliases write.ml,RandomForestRegressionModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ write_internal(object, path, overwrite)
+ })
+
+#' @aliases write.ml,RandomForestClassificationModel,character-method
+#' @rdname spark.randomForest
+#' @export
+#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ write_internal(object, path, overwrite)
+ })
+
+# Get the summary of a RandomForestRegressionModel
+summary.randomForest <- function(model) {
+ jobj <- model@jobj
+ formula <- callJMethod(jobj, "formula")
+ numFeatures <- callJMethod(jobj, "numFeatures")
+ features <- callJMethod(jobj, "features")
+ featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+ numTrees <- callJMethod(jobj, "numTrees")
+ treeWeights <- callJMethod(jobj, "treeWeights")
+ list(formula = formula,
+ numFeatures = numFeatures,
+ features = features,
+ featureImportances = featureImportances,
+ numTrees = numTrees,
+ treeWeights = treeWeights,
+ jobj = jobj)
+}
+
+#' @return \code{summary} returns a list containing the model's \code{formula}, \code{numFeatures},
+#'         \code{features}, \code{featureImportances}, \code{numTrees}, and \code{treeWeights}.
+#' @rdname spark.randomForest
+#' @aliases summary,RandomForestRegressionModel-method
+#' @export
+#' @note summary(RandomForestRegressionModel) since 2.1.0
+setMethod("summary", signature(object = "RandomForestRegressionModel"),
+ function(object) {
+ ans <- summary.randomForest(object)
+ class(ans) <- "summary.RandomForestRegressionModel"
+ ans
+ })
+
+# Get the summary of a RandomForestClassificationModel
+
+#' @rdname spark.randomForest
+#' @aliases summary,RandomForestClassificationModel-method
+#' @export
+#' @note summary(RandomForestClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "RandomForestClassificationModel"),
+ function(object) {
+ ans <- summary.randomForest(object)
+ class(ans) <- "summary.RandomForestClassificationModel"
+ ans
+ })
+
+# Prints the summary of a Random Forest model (shared by the regression and classification wrappers)
+print.summary.randomForest <- function(x) {
+ jobj <- x$jobj
+ cat("Formula: ", x$formula)
+ cat("\nNumber of features: ", x$numFeatures)
+ cat("\nFeatures: ", unlist(x$features))
+ cat("\nFeature importances: ", x$featureImportances)
+ cat("\nNumber of trees: ", x$numTrees)
+ cat("\nTree weights: ", unlist(x$treeWeights))
+
+ summaryStr <- callJMethod(jobj, "summary")
+ cat("\n", summaryStr, "\n")
+ invisible(x)
+}
+
+#' @param x a summary object of a Random Forest regression model or classification model
+#'          returned by \code{summary}.
+#' @rdname spark.randomForest
+#' @export
+#' @note print.summary.RandomForestRegressionModel since 2.1.0
+print.summary.RandomForestRegressionModel <- function(x, ...) {
+ print.summary.randomForest(x)
+}
+
+# Prints the summary of Random Forest Classification Model
+
+#' @rdname spark.randomForest
+#' @export
+#' @note print.summary.RandomForestClassificationModel since 2.1.0
+print.summary.RandomForestClassificationModel <- function(x, ...) {
+ print.summary.randomForest(x)
+}
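A short sketch of what the new summary path returns, based on the fields assembled by
summary.randomForest above (values are illustrative, not from a real run):

model <- spark.randomForest(createDataFrame(longley), Employed ~ ., "regression")
stats <- summary(model)
stats$formula             # e.g. "Employed ~ ."
stats$numFeatures         # number of input features
stats$features            # feature names as a list
stats$featureImportances  # string rendering of the importance vector
stats$numTrees            # size of the ensemble (default 20)
stats$treeWeights         # per-tree weights
print(stats)              # dispatches to print.summary.RandomForestRegressionModel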
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 6d1fccc7c0..db98d0e455 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -871,4 +871,72 @@ test_that("spark.kstest", {
expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
})
+test_that("spark.randomForest Regression", {
+ data <- suppressWarnings(createDataFrame(longley))
+ model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
+ numTrees = 1)
+
+ predictions <- collect(predict(model, data))
+ expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+ 63.221, 63.639, 64.989, 63.761,
+ 66.019, 67.857, 68.169, 66.513,
+ 68.655, 69.564, 69.331, 70.551),
+ tolerance = 1e-4)
+
+ stats <- summary(model)
+ expect_equal(stats$numTrees, 1)
+ expect_error(capture.output(stats), NA)
+ expect_true(length(capture.output(stats)) > 6)
+
+ model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
+ numTrees = 20, seed = 123)
+ predictions <- collect(predict(model, data))
+ expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
+ 63.736, 64.296, 64.868, 64.300,
+ 66.709, 67.697, 67.966, 67.252,
+ 68.866, 69.593, 69.195, 69.658),
+ tolerance = 1e-4)
+ stats <- summary(model)
+ expect_equal(stats$numTrees, 20)
+
+ modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$formula, stats2$formula)
+ expect_equal(stats$numFeatures, stats2$numFeatures)
+ expect_equal(stats$features, stats2$features)
+ expect_equal(stats$featureImportances, stats2$featureImportances)
+ expect_equal(stats$numTrees, stats2$numTrees)
+ expect_equal(stats$treeWeights, stats2$treeWeights)
+
+ unlink(modelPath)
+})
+
+test_that("spark.randomForest Classification", {
+ data <- suppressWarnings(createDataFrame(iris))
+ model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
+ maxDepth = 5, maxBins = 16)
+
+ stats <- summary(model)
+ expect_equal(stats$numFeatures, 2)
+ expect_equal(stats$numTrees, 20)
+ expect_error(capture.output(stats), NA)
+ expect_true(length(capture.output(stats)) > 6)
+
+ modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$depth, stats2$depth)
+ expect_equal(stats$numNodes, stats2$numNodes)
+ expect_equal(stats$numClasses, stats2$numClasses)
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()
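Assuming a local Spark build with SparkR installed, the new cases can also be exercised
directly with testthat (the project's own harness is R/run-tests.sh; the path below is
relative to the repo root, and the local master setting is an assumption):

library(testthat)
library(SparkR)
sparkR.session(master = "local[2]")
test_file("R/pkg/inst/tests/testthat/test_mllib.R")
sparkR.session.stop()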