aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-11-13 20:25:12 -0800
committerYanbo Liang <ybliang8@gmail.com>2016-11-13 20:25:12 -0800
commit07be232ea12dfc8dc3701ca948814be7dbebf4ee (patch)
tree856a17d84397a6205c6cc5c83c25c8cdfe09d21b /R/pkg
parentb91a51bb231af321860415075a7f404bc46e0a74 (diff)
downloadspark-07be232ea12dfc8dc3701ca948814be7dbebf4ee.tar.gz
spark-07be232ea12dfc8dc3701ca948814be7dbebf4ee.tar.bz2
spark-07be232ea12dfc8dc3701ca948814be7dbebf4ee.zip
[SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data
## What changes were proposed in this pull request? * Fix the following exceptions, which were thrown when ```spark.randomForest``` (classification), ```spark.gbt``` (classification), ```spark.naiveBayes``` and ```spark.glm``` (binomial family) were fitted on libsvm data. ``` java.lang.IllegalArgumentException: requirement failed: If label column already exists, forceIndexLabel can not be set with true. ``` See [SPARK-18412](https://issues.apache.org/jira/browse/SPARK-18412) for more detail about how to reproduce this bug. * Refactor ```getFeaturesAndLabels``` out into RWrapperUtils, since many ML algorithm wrappers use this function. * Drop some unwanted columns when making predictions. ## How was this patch tested? Add unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15851 from yanboliang/spark-18412.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R18
1 file changed, 15 insertions, 3 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b76f75dbdc..07df4b6d6f 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -881,7 +881,8 @@ test_that("spark.kstest", {
expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
})
-test_that("spark.randomForest Regression", {
+test_that("spark.randomForest", {
+ # regression
data <- suppressWarnings(createDataFrame(longley))
model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
numTrees = 1)
@@ -923,9 +924,8 @@ test_that("spark.randomForest Regression", {
expect_equal(stats$treeWeights, stats2$treeWeights)
unlink(modelPath)
-})
-test_that("spark.randomForest Classification", {
+ # classification
data <- suppressWarnings(createDataFrame(iris))
model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
@@ -971,6 +971,12 @@ test_that("spark.randomForest Classification", {
predictions <- collect(predict(model, data))$prediction
expect_equal(length(grep("1.0", predictions)), 50)
expect_equal(length(grep("2.0", predictions)), 50)
+
+ # spark.randomForest classification can work on libsvm data
+ data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+ source = "libsvm")
+ model <- spark.randomForest(data, label ~ features, "classification")
+ expect_equal(summary(model)$numFeatures, 4)
})
test_that("spark.gbt", {
@@ -1039,6 +1045,12 @@ test_that("spark.gbt", {
expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
expect_equal(s$numFeatures, 5)
expect_equal(s$numTrees, 20)
+
+ # spark.gbt classification can work on libsvm data
+ data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
+ source = "libsvm")
+ model <- spark.gbt(data, label ~ features, "classification")
+ expect_equal(summary(model)$numFeatures, 692)
})
sparkR.session.stop()