aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-10-11 12:41:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-10-11 12:41:35 -0700
commit23405f324a8089f86ebcbede9bb32944137508e8 (patch)
tree4f413126dada2746cbdfd75457eb1a8a765f12f8 /R
parent07508bd01d16f3331be167ff92770d19c8b1f46a (diff)
downloadspark-23405f324a8089f86ebcbede9bb32944137508e8.tar.gz
spark-23405f324a8089f86ebcbede9bb32944137508e8.tar.bz2
spark-23405f324a8089f86ebcbede9bb32944137508e8.zip
[SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type
## What changes were proposed in this pull request?

Fix the SparkR `spark.naiveBayes` error that occurs when the response variable of the dataset is of numeric type. See the details and how to reproduce this bug at [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153).

## How was this patch tested?

Added a unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15431 from yanboliang/spark-15153-2.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R10
1 files changed, 10 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index a1eaaf2091..c99315726a 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -481,6 +481,16 @@ test_that("spark.naiveBayes", {
expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
}
+
+ # Test numeric response variable
+ t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1)
+ t2 <- t1[-4]
+ df <- suppressWarnings(createDataFrame(t2))
+ m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0)
+ s <- summary(m)
+ expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
+ expect_equal(sum(s$apriori), 1)
+ expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
})
test_that("spark.survreg", {