Diffstat (limited to 'R/pkg/inst/tests/testthat/test_mllib_clustering.R')
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib_clustering.R | 224
1 file changed, 224 insertions, 0 deletions
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
new file mode 100644
index 0000000000..1980fffd80
--- /dev/null
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -0,0 +1,224 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("MLlib clustering algorithms")
+
+# Tests for MLlib clustering algorithms in SparkR
+sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+
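+# Resolve a path under the Spark home of the current session; used below to
+# locate the sample datasets shipped with the Spark distribution.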
+absoluteSparkPath <- function(x) {
+ sparkHome <- sparkR.conf("spark.home")
+ file.path(sparkHome, x)
+}
+
+test_that("spark.gaussianMixture", {
+ # R code to reproduce the result.
+ # nolint start
+  #' library(mvtnorm)
+  #' library(mixtools)  # mvnormalmixEM() comes from mixtools, not mvtnorm
+ #' set.seed(1)
+ #' a <- rmvnorm(7, c(0, 0))
+ #' b <- rmvnorm(8, c(10, 10))
+ #' data <- rbind(a, b)
+ #' model <- mvnormalmixEM(data, k = 2)
+ #' model$lambda
+ #
+ # [1] 0.4666667 0.5333333
+ #
+ #' model$mu
+ #
+ # [1] 0.11731091 -0.06192351
+ # [1] 10.363673 9.897081
+ #
+ #' model$sigma
+ #
+ # [[1]]
+ # [,1] [,2]
+ # [1,] 0.62049934 0.06880802
+ # [2,] 0.06880802 1.27431874
+ #
+ # [[2]]
+ # [,1] [,2]
+ # [1,] 0.2961543 0.160783
+ # [2,] 0.1607830 1.008878
+ # nolint end
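+  # The points below are rbind(a, b) from the script above, hardcoded to avoid
+  # a dependency on mvtnorm at test time.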
+ data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808),
+ list(0.3295078, -0.8204684), list(0.4874291, 0.7383247),
+ list(0.5757814, -0.3053884), list(1.5117812, 0.3898432),
+ list(-0.6212406, -2.2146999), list(11.1249309, 9.9550664),
+ list(9.9838097, 10.9438362), list(10.8212212, 10.5939013),
+ list(10.9189774, 10.7821363), list(10.0745650, 8.0106483),
+ list(10.6198257, 9.9438713), list(9.8442045, 8.5292476),
+ list(9.5218499, 10.4179416))
+ df <- createDataFrame(data, c("x1", "x2"))
+ model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
+ stats <- summary(model)
+ rLambda <- c(0.4666667, 0.5333333)
+ rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081)
+ rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874,
+ 0.2961543, 0.160783, 0.1607830, 1.008878)
+ expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
+ expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
+ expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+ p <- collect(select(predict(model, df), "prediction"))
+ expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
+
+ # Test model save/load
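+  # (a second write to the same path must fail unless overwrite = TRUE)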
+ modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats$lambda, stats2$lambda)
+ expect_equal(unlist(stats$mu), unlist(stats2$mu))
+ expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+
+ unlink(modelPath)
+})
+
+test_that("spark.kmeans", {
+ newIris <- iris
+ newIris$Species <- NULL
+ training <- suppressWarnings(createDataFrame(newIris))
+
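+  # Materialize one row so that problems creating the DataFrame surface early.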
+ take(training, 1)
+
+ model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random")
+ sample <- take(select(predict(model, training), "prediction"), 1)
+ expect_equal(typeof(sample$prediction), "integer")
+ expect_equal(sample$prediction, 1)
+
+ # Test stats::kmeans is working
+ statsModel <- kmeans(x = newIris, centers = 2)
+ expect_equal(sort(unique(statsModel$cluster)), c(1, 2))
+
+ # Test fitted works on KMeans
+ fitted.model <- fitted(model)
+ expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+ # Test summary works on KMeans
+ summary.model <- summary(model)
+ cluster <- summary.model$cluster
+ k <- summary.model$k
+ expect_equal(k, 2)
+ expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ summary2 <- summary(model2)
+ expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+ expect_equal(summary.model$coefficients, summary2$coefficients)
+  expect_false(summary.model$is.loaded)
+ expect_true(summary2$is.loaded)
+
+ unlink(modelPath)
+})
+
+test_that("spark.lda with libsvm", {
+ text <- read.df(absoluteSparkPath("data/mllib/sample_lda_libsvm_data.txt"), source = "libsvm")
+ model <- spark.lda(text, optimizer = "em")
+
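+  # The second argument to summary() caps the number of top terms per topic.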
+ stats <- summary(model, 10)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 11)
+  expect_null(vocabulary)
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_equal(vocabulary, stats2$vocabulary)
+
+ unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+ text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt"))
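+  # read.text() puts each line in a single "value" column, hence features = "value".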
+ model <- spark.lda(text, optimizer = "online", features = "value")
+
+ stats <- summary(model)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 10)
+ expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+  expect_equal(vocabulary, stats2$vocabulary)
+
+ unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+ text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt"))
+ model <- spark.lda(text, features = "value", k = 3)
+
+ # Assert perplexities are equal
+ stats <- summary(model)
+ logPerplexity <- spark.perplexity(model, text)
+ expect_equal(logPerplexity, stats$logPerplexity)
+
+  # Each document's topic distribution sums to 1, so the grand total over all
+  # documents must equal the number of documents.
+ posterior <- spark.posterior(model, text)
+ local.posterior <- collect(posterior)$topicDistribution
+ expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
+sparkR.session.stop()