aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2017-01-21 21:26:14 -0800
committerYanbo Liang <ybliang8@gmail.com>2017-01-21 21:26:14 -0800
commit0c589e3713655f25547d6945a40786da900ec2fc (patch)
treefbf39946e63f951d815c4b8b423c305624985d96
parent3dcad9fab17297f9966026f29fefb5c726965a13 (diff)
downloadspark-0c589e3713655f25547d6945a40786da900ec2fc.tar.gz
spark-0c589e3713655f25547d6945a40786da900ec2fc.tar.bz2
spark-0c589e3713655f25547d6945a40786da900ec2fc.zip
[SPARK-19291][SPARKR][ML] spark.gaussianMixture supports output log-likelihood.
## What changes were proposed in this pull request?

```spark.gaussianMixture``` supports outputting the total log-likelihood for the model, like R's ```mvnormalmixEM```.

## How was this patch tested?

R unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #16646 from yanboliang/spark-19291.
-rw-r--r--R/pkg/R/mllib_clustering.R5
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_clustering.R7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala12
3 files changed, 19 insertions, 5 deletions
diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R
index fb8d9e75ad..fa40f9d0bf 100644
--- a/R/pkg/R/mllib_clustering.R
+++ b/R/pkg/R/mllib_clustering.R
@@ -98,7 +98,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula =
#' @param object a fitted gaussian mixture model.
#' @return \code{summary} returns summary of the fitted model, which is a list.
#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu),
-#' \code{sigma} (sigma), and \code{posterior} (posterior).
+#' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior).
#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
#' @rdname spark.gaussianMixture
#' @export
@@ -112,6 +112,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
sigmaList <- callJMethod(jobj, "sigma")
k <- callJMethod(jobj, "k")
dim <- callJMethod(jobj, "dim")
+ loglik <- callJMethod(jobj, "logLikelihood")
mu <- c()
for (i in 1 : k) {
start <- (i - 1) * dim + 1
@@ -129,7 +130,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
} else {
dataFrame(callJMethod(jobj, "posterior"))
}
- list(lambda = lambda, mu = mu, sigma = sigma,
+ list(lambda = lambda, mu = mu, sigma = sigma, loglik = loglik,
posterior = posterior, is.loaded = is.loaded)
})
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index cfbdea5c04..9de8362cde 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -56,6 +56,10 @@ test_that("spark.gaussianMixture", {
# [,1] [,2]
# [1,] 0.2961543 0.160783
# [2,] 0.1607830 1.008878
+ #
+ #' model$loglik
+ #
+ # [1] -46.89499
# nolint end
data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808),
list(0.3295078, -0.8204684), list(0.4874291, 0.7383247),
@@ -72,9 +76,11 @@ test_that("spark.gaussianMixture", {
rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081)
rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874,
0.2961543, 0.160783, 0.1607830, 1.008878)
+ rLoglik <- -46.89499
expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+ expect_equal(unlist(stats$loglik), rLoglik, tolerance = 1e-3)
p <- collect(select(predict(model, df), "prediction"))
expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
@@ -88,6 +94,7 @@ test_that("spark.gaussianMixture", {
expect_equal(stats$lambda, stats2$lambda)
expect_equal(unlist(stats$mu), unlist(stats2$mu))
expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+ expect_equal(unlist(stats$loglik), unlist(stats2$loglik))
unlink(modelPath)
})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
index b708702959..9a98a8b18b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions._
private[r] class GaussianMixtureWrapper private (
val pipeline: PipelineModel,
val dim: Int,
+ val logLikelihood: Double,
val isLoaded: Boolean = false) extends MLWritable {
private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
@@ -91,7 +92,10 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
.setStages(Array(rFormulaModel, gm))
.fit(data)
- new GaussianMixtureWrapper(pipeline, dim)
+ val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
+ val logLikelihood: Double = gmm.summary.logLikelihood
+
+ new GaussianMixtureWrapper(pipeline, dim, logLikelihood)
}
override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader
@@ -105,7 +109,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
val pipelinePath = new Path(path, "pipeline").toString
val rMetadata = ("class" -> instance.getClass.getName) ~
- ("dim" -> instance.dim)
+ ("dim" -> instance.dim) ~
+ ("logLikelihood" -> instance.logLikelihood)
val rMetadataJson: String = compact(render(rMetadata))
sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
@@ -124,7 +129,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
val rMetadata = parse(rMetadataStr)
val dim = (rMetadata \ "dim").extract[Int]
- new GaussianMixtureWrapper(pipeline, dim, isLoaded = true)
+ val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+ new GaussianMixtureWrapper(pipeline, dim, logLikelihood, isLoaded = true)
}
}
}