aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R15
1 files changed, 15 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index c99315726a..33cc069f14 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -410,6 +410,21 @@ test_that("spark.mlp", {
model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10)
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1))
+
+ # test initialWeights
+ model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights =
+ c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9))
+ mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
+ expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1))
+
+ model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights =
+ c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0))
+ mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
+ expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1))
+
+ model <- spark.mlp(df, layers = c(4, 3), maxIter = 2)
+ mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
+ expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1))
})
test_that("spark.naiveBayes", {