diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-08-18 05:33:52 -0700 |
---|---|---|
committer | Felix Cheung <felixcheung@apache.org> | 2016-08-18 05:33:52 -0700 |
commit | b72bb62d421840f82d663c6b8e3922bd14383fbb (patch) | |
tree | 1445a4e605794d84a606661dcfbd68decb3df657 /R/pkg/inst/tests/testthat/test_mllib.R | |
parent | 68f5087d2107d6afec5d5745f0cb0e9e3bdd6a0b (diff) | |
download | spark-b72bb62d421840f82d663c6b8e3922bd14383fbb.tar.gz spark-b72bb62d421840f82d663c6b8e3922bd14383fbb.tar.bz2 spark-b72bb62d421840f82d663c6b8e3922bd14383fbb.zip |
[SPARK-16447][ML][SPARKR] LDA wrapper in SparkR
## What changes were proposed in this pull request?
Add LDA Wrapper in SparkR with the following interfaces:
- spark.lda(data, ...)
- spark.posterior(object, newData, ...)
- spark.perplexity(object, ...)
- summary(object)
- write.ml(object)
- read.ml(path)
## How was this patch tested?
Tested with SparkR unit tests.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #14229 from yinxusen/SPARK-16447.
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib.R')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R | 87 |
1 file changed, 87 insertions, 0 deletions
```diff
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 96179864a8..8c380fbf15 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
 
   unlink(modelPath)
 })
+test_that("spark.lda with libsvm", {
+  text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
+  model <- spark.lda(text, optimizer = "em")
+
+  stats <- summary(model, 10)
+  isDistributed <- stats$isDistributed
+  logLikelihood <- stats$logLikelihood
+  logPerplexity <- stats$logPerplexity
+  vocabSize <- stats$vocabSize
+  topics <- stats$topicTopTerms
+  weights <- stats$topicTopTermsWeights
+  vocabulary <- stats$vocabulary
+
+  expect_false(isDistributed)
+  expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+  expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+  expect_equal(vocabSize, 11)
+  expect_true(is.null(vocabulary))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+
+  expect_false(stats2$isDistributed)
+  expect_equal(logLikelihood, stats2$logLikelihood)
+  expect_equal(logPerplexity, stats2$logPerplexity)
+  expect_equal(vocabSize, stats2$vocabSize)
+  expect_equal(vocabulary, stats2$vocabulary)
+
+  unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+  text <- read.text("data/mllib/sample_lda_data.txt")
+  model <- spark.lda(text, optimizer = "online", features = "value")
+
+  stats <- summary(model)
+  isDistributed <- stats$isDistributed
+  logLikelihood <- stats$logLikelihood
+  logPerplexity <- stats$logPerplexity
+  vocabSize <- stats$vocabSize
+  topics <- stats$topicTopTerms
+  weights <- stats$topicTopTermsWeights
+  vocabulary <- stats$vocabulary
+
+  expect_false(isDistributed)
+  expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+  expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+  expect_equal(vocabSize, 10)
+  expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+
+  expect_false(stats2$isDistributed)
+  expect_equal(logLikelihood, stats2$logLikelihood)
+  expect_equal(logPerplexity, stats2$logPerplexity)
+  expect_equal(vocabSize, stats2$vocabSize)
+  expect_true(all.equal(vocabulary, stats2$vocabulary))
+
+  unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+  text <- read.text("data/mllib/sample_lda_data.txt")
+  model <- spark.lda(text, features = "value", k = 3)
+
+  # Assert perplexities are equal
+  stats <- summary(model)
+  logPerplexity <- spark.perplexity(model, text)
+  expect_equal(logPerplexity, stats$logPerplexity)
+
+  # Assert the sum of every topic distribution is equal to 1
+  posterior <- spark.posterior(model, text)
+  local.posterior <- collect(posterior)$topicDistribution
+  expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
 sparkR.session.stop()
```