diff options
author | Junyang Qian <junyangq@databricks.com> | 2016-08-19 14:24:09 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-08-19 14:24:09 -0700 |
commit | acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8 (patch) | |
tree | bf01165da59ed904073196844195484318459d81 /R/pkg/inst/tests/testthat | |
parent | cf0cce90364d17afe780ff9a5426dfcefa298535 (diff) | |
download | spark-acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8.tar.gz spark-acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8.tar.bz2 spark-acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8.zip |
[SPARK-16443][SPARKR] Alternating Least Squares (ALS) wrapper
## What changes were proposed in this pull request?
Add Alternating Least Squares wrapper in SparkR. Unit tests have been updated.
## How was this patch tested?
SparkR unit tests.
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
![screen shot 2016-07-27 at 3 50 31 pm](https://cloud.githubusercontent.com/assets/15318264/17195347/f7a6352a-5411-11e6-8e21-61a48070192a.png)
![screen shot 2016-07-27 at 3 50 46 pm](https://cloud.githubusercontent.com/assets/15318264/17195348/f7a7d452-5411-11e6-845f-6d292283bc28.png)
Author: Junyang Qian <junyangq@databricks.com>
Closes #14384 from junyangq/SPARK-16443.
Diffstat (limited to 'R/pkg/inst/tests/testthat')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 40 |
1 file changed, 40 insertions, 0 deletions
test_that("spark.als", {
  # Small explicit-feedback ratings set: (user, item, score) triples.
  data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
               list(2, 1, 1.0), list(2, 2, 5.0))
  df <- createDataFrame(data, c("user", "item", "score"))

  # Fit ALS with a fixed seed so the learned factors (and the predictions
  # checked below) are reproducible across runs.
  model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
                     rank = 10, maxIter = 5, seed = 0, reg = 0.1)
  stats <- summary(model)
  expect_equal(stats$rank, 10)

  # Predict scores for (user, item) pairs not in the training data; the
  # expected values were produced by the deterministic fit above (seed = 0).
  test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)),
                          c("user", "item"))
  predictions <- collect(predict(model, test))
  expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409),
               tolerance = 1e-4)

  # Save/load round trip: writing to an existing path must fail unless
  # overwrite = TRUE, and the reloaded model must match the original.
  modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
  write.ml(model, modelPath)
  expect_error(write.ml(model, modelPath))
  write.ml(model, modelPath, overwrite = TRUE)
  model2 <- read.ml(modelPath)
  stats2 <- summary(model2)
  expect_equal(stats2$rating, "score")

  # Compare latent factor matrices after sorting by id, since the row order
  # returned by collect() is not guaranteed to be stable.
  userFactors <- collect(stats$userFactors)
  itemFactors <- collect(stats$itemFactors)
  userFactors2 <- collect(stats2$userFactors)
  itemFactors2 <- collect(stats2$itemFactors)

  orderUser <- order(userFactors$id)
  orderUser2 <- order(userFactors2$id)
  expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
  expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])

  orderItem <- order(itemFactors$id)
  orderItem2 <- order(itemFactors2$id)
  expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
  expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])

  # write.ml persists the model as a directory (metadata/ + data/); unlink()
  # on a directory is a no-op without recursive = TRUE and would leak the
  # temp files created by this test.
  unlink(modelPath, recursive = TRUE)
})