aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-11-10 17:13:10 -0800
committerYanbo Liang <ybliang8@gmail.com>2016-11-10 17:13:10 -0800
commit5ddf69470b93c0b8a28bb4ac905e7670d9c50a95 (patch)
treea6f7eff240d2f1f299138bce167e2599634aad83 /R/pkg/inst
parenta3356343cbf58b930326f45721fb4ecade6f8029 (diff)
downloadspark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.tar.gz
spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.tar.bz2
spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.zip
[SPARK-18401][SPARKR][ML] SparkR random forest should support output original label.
## What changes were proposed in this pull request? SparkR ```spark.randomForest``` classification prediction should output the original label rather than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291). ## How was this patch tested? Added unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15842 from yanboliang/spark-18401.
Diffstat (limited to 'R/pkg/inst')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R24
1 files changed, 24 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 33e9d0d267..b76f75dbdc 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", {
expect_equal(stats$numTrees, 20)
expect_error(capture.output(stats), NA)
expect_true(length(capture.output(stats)) > 6)
+ # Test string prediction values
+ predictions <- collect(predict(model, data))$prediction
+ expect_equal(length(grep("setosa", predictions)), 50)
+ expect_equal(length(grep("versicolor", predictions)), 50)
modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp")
write.ml(model, modelPath)
@@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", {
expect_equal(stats$numClasses, stats2$numClasses)
unlink(modelPath)
+
+ # Test numeric response variable
+ labelToIndex <- function(species) {
+   # Translate an iris species label into its numeric class index.
+   # as.character() lets switch() match on the label text when given a factor.
+   # There is no default branch, so an unrecognised label yields NULL.
+   speciesName <- as.character(species)
+   switch(speciesName, "setosa" = 0, "versicolor" = 1, "virginica" = 2)
+ }
+ iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
+ data <- suppressWarnings(createDataFrame(iris[-5]))
+ model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
+ maxDepth = 5, maxBins = 16)
+ stats <- summary(model)
+ expect_equal(stats$numFeatures, 2)
+ expect_equal(stats$numTrees, 20)
+ # Test numeric prediction values
+ predictions <- collect(predict(model, data))$prediction
+ expect_equal(length(grep("1.0", predictions)), 50)
+ expect_equal(length(grep("2.0", predictions)), 50)
})
test_that("spark.gbt", {