about summary refs log tree commit diff
path: root/R/pkg/inst/tests/testthat/test_mllib.R
diff options
context:
space:
mode:
authorFelix Cheung <felixcheung_m@hotmail.com>2016-11-08 16:00:45 -0800
committerFelix Cheung <felixcheung@apache.org>2016-11-08 16:00:45 -0800
commit55964c15a7b639f920dfe6c104ae4fdcd673705c (patch)
tree1e551bd8c155145135acc161f711e0464b053f8c /R/pkg/inst/tests/testthat/test_mllib.R
parent6f7ecb0f2975d24a71e4240cf623f5bd8992bbeb (diff)
downloadspark-55964c15a7b639f920dfe6c104ae4fdcd673705c.tar.gz
spark-55964c15a7b639f920dfe6c104ae4fdcd673705c.tar.bz2
spark-55964c15a7b639f920dfe6c104ae4fdcd673705c.zip
[SPARK-18239][SPARKR] Gradient Boosted Tree for R
## What changes were proposed in this pull request?

Gradient Boosted Tree in R. With a few minor improvements to RandomForest in R.

Since this is relatively isolated I'd like to target this for branch-2.1

## How was this patch tested?

manual tests, unit tests

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #15746 from felixcheung/rgbt.
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r-- R/pkg/inst/tests/testthat/test_mllib.R | 68
1 file changed, 68 insertions(+), 0 deletions(-)
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 5f742d9045..33e9d0d267 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -949,4 +949,72 @@ test_that("spark.randomForest Classification", {
unlink(modelPath)
})
+test_that("spark.gbt", {
+ # regression
+ data <- suppressWarnings(createDataFrame(longley))
+ model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123)
+ predictions <- collect(predict(model, data))
+ expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
+ 63.221, 63.639, 64.989, 63.761,
+ 66.019, 67.857, 68.169, 66.513,
+ 68.655, 69.564, 69.331, 70.551),
+ tolerance = 1e-4)
+ stats <- summary(model)
+ expect_equal(stats$numTrees, 20)
+ expect_equal(stats$formula, "Employed ~ .")
+ expect_equal(stats$numFeatures, 6)
+ expect_equal(length(stats$treeWeights), 20)
+
+ modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$formula, stats2$formula)
+ expect_equal(stats$numFeatures, stats2$numFeatures)
+ expect_equal(stats$features, stats2$features)
+ expect_equal(stats$featureImportances, stats2$featureImportances)
+ expect_equal(stats$numTrees, stats2$numTrees)
+ expect_equal(stats$treeWeights, stats2$treeWeights)
+
+ unlink(modelPath)
+
+ # classification
+ # label must be binary - GBTClassifier currently only supports binary classification.
+ iris2 <- iris[iris$Species != "virginica", ]
+ data <- suppressWarnings(createDataFrame(iris2))
+ model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
+ stats <- summary(model)
+ expect_equal(stats$numFeatures, 2)
+ expect_equal(stats$numTrees, 20)
+ expect_error(capture.output(stats), NA)
+ expect_true(length(capture.output(stats)) > 6)
+ predictions <- collect(predict(model, data))$prediction
+ # test string prediction values
+ expect_equal(length(grep("setosa", predictions)), 50)
+ expect_equal(length(grep("versicolor", predictions)), 50)
+
+ modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$depth, stats2$depth)
+ expect_equal(stats$numNodes, stats2$numNodes)
+ expect_equal(stats$numClasses, stats2$numClasses)
+
+ unlink(modelPath)
+
+ iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
+ df <- suppressWarnings(createDataFrame(iris2))
+ m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
+ s <- summary(m)
+ # test numeric prediction values
+ expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
+ expect_equal(s$numFeatures, 5)
+ expect_equal(s$numTrees, 20)
+})
+
sparkR.session.stop()