From 23405f324a8089f86ebcbede9bb32944137508e8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 11 Oct 2016 12:41:35 -0700 Subject: [SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type ## What changes were proposed in this pull request? Fix the SparkR `spark.naiveBayes` error that occurs when the response variable of the dataset is of numeric type. See details and how to reproduce this bug at [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153). ## How was this patch tested? Added a unit test that fits `spark.naiveBayes` with a numeric response variable. Author: Yanbo Liang Closes #15431 from yanboliang/spark-15153-2. --- R/pkg/inst/tests/testthat/test_mllib.R | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'R') diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index a1eaaf2091..c99315726a 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -481,6 +481,16 @@ test_that("spark.naiveBayes", { expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) expect_equal(as.character(predict(m, t1[1, ])), "Yes") } + + # Test numeric response variable + t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) + t2 <- t1[-4] + df <- suppressWarnings(createDataFrame(t2)) + m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) }) test_that("spark.survreg", { -- cgit v1.2.3