diff options
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 96179864a8..8c380fbf15 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", { unlink(modelPath) }) +test_that("spark.lda with libsvm", { + text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") + model <- spark.lda(text, optimizer = "em") + + stats <- summary(model, 10) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 11) + expect_true(is.null(vocabulary)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + + unlink(modelPath) +}) + +test_that("spark.lda with text input", { + text <- read.text("data/mllib/sample_lda_data.txt") + model <- spark.lda(text, optimizer = "online", features = "value") + + stats <- summary(model) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 10) + expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"))) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_true(all.equal(vocabulary, stats2$vocabulary)) + + unlink(modelPath) +}) + +test_that("spark.posterior and spark.perplexity", { + text <- read.text("data/mllib/sample_lda_data.txt") + model <- spark.lda(text, features = "value", k = 3) + + # Assert perplexities are equal + stats <- summary(model) + logPerplexity <- spark.perplexity(model, text) + expect_equal(logPerplexity, stats$logPerplexity) + + # Assert the sum of every topic distribution is equal to 1 + posterior <- spark.posterior(model, text) + local.posterior <- collect(posterior)$topicDistribution + expect_equal(length(local.posterior), sum(unlist(local.posterior))) +}) + sparkR.session.stop() |