aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst
diff options
context:
space:
mode:
authorXin Ren <iamshrek@126.com>2016-08-24 11:18:10 -0700
committerFelix Cheung <felixcheung@apache.org>2016-08-24 11:18:10 -0700
commit2fbdb606392631b1dff88ec86f388cc2559c28f5 (patch)
tree002050c92864378c0c65a5d6c449420c8d604170 /R/pkg/inst
parentd2932a0e987132c694ed59515b7c77adaad052e6 (diff)
downloadspark-2fbdb606392631b1dff88ec86f388cc2559c28f5.tar.gz
spark-2fbdb606392631b1dff88ec86f388cc2559c28f5.tar.bz2
spark-2fbdb606392631b1dff88ec86f388cc2559c28f5.zip
[SPARK-16445][MLLIB][SPARKR] Multilayer Perceptron Classifier wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-16445 ## What changes were proposed in this pull request? Create Multilayer Perceptron Classifier wrapper in SparkR ## How was this patch tested? Tested manually on local machine Author: Xin Ren <iamshrek@126.com> Closes #14447 from keypointt/SPARK-16445.
Diffstat (limited to 'R/pkg/inst')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R32
1 files changed, 32 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index de9bd48662..1e6da650d1 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -347,6 +347,38 @@ test_that("spark.kmeans", {
unlink(modelPath)
})
+test_that("spark.mlp", {
+ df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
+ model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100,
+ tol = 0.5, stepSize = 1, seed = 1)
+
+ # Test summary method
+ summary <- summary(model)
+ expect_equal(summary$labelCount, 3)
+ expect_equal(summary$layers, c(4, 5, 4, 3))
+ expect_equal(length(summary$weights), 64)
+
+ # Test predict method
+ mlpTestDF <- df
+ mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
+ expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ summary2 <- summary(model2)
+
+ expect_equal(summary2$labelCount, 3)
+ expect_equal(summary2$layers, c(4, 5, 4, 3))
+ expect_equal(length(summary2$weights), 64)
+
+ unlink(modelPath)
+
+})
+
test_that("spark.naiveBayes", {
# R code to reproduce the result.
# We do not support instance weights yet. So we ignore the frequencies.