aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R40
1 files changed, 40 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index dfb7a185cd..67a3099101 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -657,4 +657,44 @@ test_that("spark.posterior and spark.perplexity", {
expect_equal(length(local.posterior), sum(unlist(local.posterior)))
})
+test_that("spark.als", {
+ data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+ list(2, 1, 1.0), list(2, 2, 5.0))
+ df <- createDataFrame(data, c("user", "item", "score"))
+ model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
+ rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+ stats <- summary(model)
+ expect_equal(stats$rank, 10)
+ test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
+ predictions <- collect(predict(model, test))
+
+ expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409),
+ tolerance = 1e-4)
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats2$rating, "score")
+ userFactors <- collect(stats$userFactors)
+ itemFactors <- collect(stats$itemFactors)
+ userFactors2 <- collect(stats2$userFactors)
+ itemFactors2 <- collect(stats2$itemFactors)
+
+ orderUser <- order(userFactors$id)
+ orderUser2 <- order(userFactors2$id)
+ expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
+ expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
+
+ orderItem <- order(itemFactors$id)
+ orderItem2 <- order(itemFactors2$id)
+ expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
+ expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()