diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-11-10 17:13:10 -0800 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-11-10 17:13:10 -0800 |
commit | 5ddf69470b93c0b8a28bb4ac905e7670d9c50a95 (patch) | |
tree | a6f7eff240d2f1f299138bce167e2599634aad83 /R/pkg/inst/tests | |
parent | a3356343cbf58b930326f45721fb4ecade6f8029 (diff) | |
download | spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.tar.gz spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.tar.bz2 spark-5ddf69470b93c0b8a28bb4ac905e7670d9c50a95.zip |
[SPARK-18401][SPARKR][ML] SparkR random forest should support output original label.
## What changes were proposed in this pull request?
SparkR `spark.randomForest` classification prediction should output the original label rather than the indexed label. This issue is very similar to [SPARK-18291](https://issues.apache.org/jira/browse/SPARK-18291).
## How was this patch tested?
Add unit tests.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #15842 from yanboliang/spark-18401.
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 33e9d0d267..b76f75dbdc 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -935,6 +935,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -947,6 +951,26 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) }) test_that("spark.gbt", { |