aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorXin Ren <iamshrek@126.com>2017-03-12 12:15:19 -0700
committerFelix Cheung <felixcheung@apache.org>2017-03-12 12:15:19 -0700
commit9f8ce4825e378b6a856ce65cb9986a5a0f0b624e (patch)
treef36dcc381c02cbfc86dab0e207699eddd9bc87bc /R/pkg
parent2f5187bde1544c452fe5116a2bd243653332a079 (diff)
downloadspark-9f8ce4825e378b6a856ce65cb9986a5a0f0b624e.tar.gz
spark-9f8ce4825e378b6a856ce65cb9986a5a0f0b624e.tar.bz2
spark-9f8ce4825e378b6a856ce65cb9986a5a0f0b624e.zip
[SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT Wrapper return param "maxDepth" to R models
## What changes were proposed in this pull request? RandomForest R Wrapper and GBT R Wrapper return param `maxDepth` to R models. Below 4 R wrappers are changed: * `RandomForestClassificationWrapper` * `RandomForestRegressionWrapper` * `GBTClassificationWrapper` * `GBTRegressionWrapper` ## How was this patch tested? Test manually on my local machine. Author: Xin Ren <iamshrek@126.com> Closes #17207 from keypointt/SPARK-19282.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/R/mllib_tree.R11
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_tree.R10
2 files changed, 17 insertions, 4 deletions
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 40a806c41b..82279be6fb 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) {
numFeatures <- callJMethod(jobj, "numFeatures")
features <- callJMethod(jobj, "features")
featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+ maxDepth <- callJMethod(jobj, "maxDepth")
numTrees <- callJMethod(jobj, "numTrees")
treeWeights <- callJMethod(jobj, "treeWeights")
list(formula = formula,
numFeatures = numFeatures,
features = features,
featureImportances = featureImportances,
+ maxDepth = maxDepth,
numTrees = numTrees,
treeWeights = treeWeights,
jobj = jobj)
@@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) {
cat("\nNumber of features: ", x$numFeatures)
cat("\nFeatures: ", unlist(x$features))
cat("\nFeature importances: ", x$featureImportances)
+ cat("\nMax Depth: ", x$maxDepth)
cat("\nNumber of trees: ", x$numTrees)
cat("\nTree weights: ", unlist(x$treeWeights))
@@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' The list of components includes \code{formula} (formula),
#' \code{numFeatures} (number of features), \code{features} (list of features),
-#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#' and \code{treeWeights} (tree weights).
+#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
#' @rdname spark.gbt
#' @aliases summary,GBTRegressionModel-method
#' @export
@@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
#' @return \code{summary} returns summary information of the fitted model, which is a list.
#' The list of components includes \code{formula} (formula),
#' \code{numFeatures} (number of features), \code{features} (list of features),
-#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#' and \code{treeWeights} (tree weights).
+#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
#' @rdname spark.randomForest
#' @aliases summary,RandomForestRegressionModel-method
#' @export
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R
index e6fda251eb..e0802a9b02 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -39,6 +39,7 @@ test_that("spark.gbt", {
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
+ expect_equal(stats$maxDepth, 5)
expect_equal(stats$formula, "Employed ~ .")
expect_equal(stats$numFeatures, 6)
expect_equal(length(stats$treeWeights), 20)
@@ -53,6 +54,7 @@ test_that("spark.gbt", {
expect_equal(stats$numFeatures, stats2$numFeatures)
expect_equal(stats$features, stats2$features)
expect_equal(stats$featureImportances, stats2$featureImportances)
+ expect_equal(stats$maxDepth, stats2$maxDepth)
expect_equal(stats$numTrees, stats2$numTrees)
expect_equal(stats$treeWeights, stats2$treeWeights)
@@ -66,6 +68,7 @@ test_that("spark.gbt", {
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$numTrees, 20)
+ expect_equal(stats$maxDepth, 5)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
predictions <- collect(predict(model, data))$prediction
@@ -93,6 +96,7 @@ test_that("spark.gbt", {
expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
expect_equal(s$numFeatures, 5)
expect_equal(s$numTrees, 20)
+ expect_equal(stats$maxDepth, 5)
# spark.gbt classification can work on libsvm data
data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
@@ -116,6 +120,7 @@ test_that("spark.randomForest", {
stats <- summary(model)
expect_equal(stats$numTrees, 1)
+ expect_equal(stats$maxDepth, 5)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
@@ -129,6 +134,7 @@ test_that("spark.randomForest", {
tolerance = 1e-4)
stats <- summary(model)
expect_equal(stats$numTrees, 20)
+ expect_equal(stats$maxDepth, 5)
modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
write.ml(model, modelPath)
@@ -141,6 +147,7 @@ test_that("spark.randomForest", {
expect_equal(stats$features, stats2$features)
expect_equal(stats$featureImportances, stats2$featureImportances)
expect_equal(stats$numTrees, stats2$numTrees)
+ expect_equal(stats$maxDepth, stats2$maxDepth)
expect_equal(stats$treeWeights, stats2$treeWeights)
unlink(modelPath)
@@ -153,6 +160,7 @@ test_that("spark.randomForest", {
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$numTrees, 20)
+ expect_equal(stats$maxDepth, 5)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
# Test string prediction values
@@ -187,6 +195,8 @@ test_that("spark.randomForest", {
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$numTrees, 20)
+ expect_equal(stats$maxDepth, 5)
+
# Test numeric prediction values
predictions <- collect(predict(model, data))$prediction
expect_equal(length(grep("1.0", predictions)), 50)