aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat/test_mllib.R
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R87
1 files changed, 87 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 96179864a8..8c380fbf15 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
unlink(modelPath)
})
+test_that("spark.lda with libsvm", {
+ text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
+ model <- spark.lda(text, optimizer = "em")
+
+ stats <- summary(model, 10)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 11)
+ expect_true(is.null(vocabulary))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_equal(vocabulary, stats2$vocabulary)
+
+ unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+ text <- read.text("data/mllib/sample_lda_data.txt")
+ model <- spark.lda(text, optimizer = "online", features = "value")
+
+ stats <- summary(model)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 10)
+ expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_true(all.equal(vocabulary, stats2$vocabulary))
+
+ unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+ text <- read.text("data/mllib/sample_lda_data.txt")
+ model <- spark.lda(text, features = "value", k = 3)
+
+ # Assert perplexities are equal
+ stats <- summary(model)
+ logPerplexity <- spark.perplexity(model, text)
+ expect_equal(logPerplexity, stats$logPerplexity)
+
+ # Assert the sum of every topic distribution is equal to 1
+ posterior <- spark.posterior(model, text)
+ local.posterior <- collect(posterior)$topicDistribution
+ expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
sparkR.session.stop()