author    Timothy Hunter <timhunter@databricks.com>  2016-04-29 23:13:03 -0700
committer Xiangrui Meng <meng@databricks.com>        2016-04-29 23:13:03 -0700
commit    bc36fe6e896ab0e64f6334b1e3fd6386d0c38238
tree      76f351715be2485233b45f0e676b533fbce9ed7b /R
parent    43b149fb885a27f9467aab28e5195f6f03aadcf0
[SPARK-14831][SPARKR] Make the SparkR MLlib API more consistent with Spark
## What changes were proposed in this pull request?

This PR splits the MLlib algorithms into two flavors:
- the R flavor, which tries to mimic the existing R API for these algorithms (and works as an S4 specialization for Spark dataframes)
- the Spark flavor, which follows the same API and naming conventions as the rest of the MLlib algorithms in the other languages

In practice, the former calls the latter.

## How was this patch tested?

The tests for the various algorithms were adapted to be run against both interfaces.

Author: Timothy Hunter <timhunter@databricks.com>

Closes #12789 from thunterdb/14831.
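As a quick illustration of the split (a sketch only, assuming an initialized SparkR session as in the roxygen examples below), the two flavors fit the same model:

```r
# Sketch: assumes a running SparkR session, e.g.
# sc <- sparkR.init(); sqlContext <- sparkRSQL.init(sc)
df <- createDataFrame(sqlContext, iris)

# Spark flavor: data first, then formula, MLlib-style argument names.
m1 <- spark.glm(df, Sepal_Width ~ Sepal_Length, family = "gaussian")

# R flavor: mirrors stats::glm(formula, family, data); after this patch
# it simply delegates to spark.glm.
m2 <- glm(Sepal_Width ~ Sepal_Length, family = "gaussian", data = df)

summary(m1)  # both return a GeneralizedLinearRegressionModel
```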
Diffstat (limited to 'R')
-rw-r--r--  R/pkg/NAMESPACE                            7
-rw-r--r--  R/pkg/R/generics.R                        16
-rw-r--r--  R/pkg/R/mllib.R                          155
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib.R   141
4 files changed, 247 insertions(+), 72 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 647db22747..d2aebb3c85 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,12 +12,13 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
+ "spark.glm",
"predict",
"summary",
- "kmeans",
+ "spark.kmeans",
"fitted",
- "naiveBayes",
- "survreg")
+ "spark.naiveBayes",
+ "spark.survreg")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 3db8925730..a37cdf23f5 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1181,6 +1181,10 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
#' @export
setGeneric("year", function(x) { standardGeneric("year") })
+#' @rdname spark.glm
+#' @export
+setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })
+
#' @rdname glm
#' @export
setGeneric("glm")
@@ -1193,21 +1197,21 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
#' @export
setGeneric("rbind", signature = "...")
-#' @rdname kmeans
+#' @rdname spark.kmeans
#' @export
-setGeneric("kmeans")
+setGeneric("spark.kmeans", function(data, k, ...) { standardGeneric("spark.kmeans") })
#' @rdname fitted
#' @export
setGeneric("fitted")
-#' @rdname naiveBayes
+#' @rdname spark.naiveBayes
#' @export
-setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBayes") })
+setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })
-#' @rdname survreg
+#' @rdname spark.survreg
#' @export
-setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
#' @rdname ml.save
#' @export
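The renamed generics above drive standard S4 dispatch. A minimal standalone sketch of the setGeneric/setMethod pattern, using a hypothetical toy class rather than SparkDataFrame so it runs without Spark:

```r
library(methods)

# Hypothetical toy class, not part of the patch; stands in for SparkDataFrame.
setClass("ToyFrame", representation(rows = "numeric"))

# Declare the generic, then register a method for the class;
# this mirrors the spark.* generics declared in generics.R.
setGeneric("spark.toy", function(data, k, ...) { standardGeneric("spark.toy") })

setMethod("spark.toy", signature(data = "ToyFrame"),
          function(data, k, ...) {
            paste("fitting with", k, "centers on", data@rows, "rows")
          })

spark.toy(new("ToyFrame", rows = 150), k = 2)
# [1] "fitting with 2 centers on 150 rows"
```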
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index c2326ea116..4f62d7ce1b 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -17,6 +17,14 @@
# mllib.R: Provides methods for MLlib integration
+# Integration with R's standard functions.
+# Most of MLlib's algorithms are provided in two flavors:
+# - a specialization of the default R methods (glm). These methods try to respect
+# the inputs and the outputs of R's methods as closely as possible, but some small differences
+# may exist.
+# - a set of methods that reflect the arguments of the other languages supported by Spark. These
+# methods carry the `spark.` prefix: spark.glm, spark.kmeans, etc.
+
#' @title S4 class that represents a generalized linear model
#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
#' @export
@@ -39,6 +47,54 @@ setClass("KMeansModel", representation(jobj = "jobj"))
#' Fits a generalized linear model
#'
+#' Fits a generalized linear model against a Spark DataFrame.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param family A description of the error distribution and link function to be used in the model.
+#' This can be a character string naming a family function, a family function or
+#' the result of a call to a family function. See the documentation of R's family functions at
+#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#' @param epsilon Positive convergence tolerance of iterations.
+#' @param maxit Integer giving the maximal number of IRLS iterations.
+#' @return a fitted generalized linear model
+#' @rdname spark.glm
+#' @export
+#' @examples
+#' \dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' data(iris)
+#' df <- createDataFrame(sqlContext, iris)
+#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")
+#' summary(model)
+#' }
+setMethod(
+ "spark.glm",
+ signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, family = gaussian, epsilon = 1e-06, maxit = 25) {
+ if (is.character(family)) {
+ family <- get(family, mode = "function", envir = parent.frame())
+ }
+ if (is.function(family)) {
+ family <- family()
+ }
+ if (is.null(family$family)) {
+ print(family)
+ stop("'family' not recognized")
+ }
+
+ formula <- paste(deparse(formula), collapse = "")
+
+ jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
+ "fit", formula, data@sdf, family$family, family$link,
+ epsilon, as.integer(maxit))
+ return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+})
+
+#' Fits a generalized linear model (R-compliant).
+#'
#' Fits a generalized linear model, similarly to R's glm().
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
@@ -64,23 +120,7 @@ setClass("KMeansModel", representation(jobj = "jobj"))
#' }
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) {
- if (is.character(family)) {
- family <- get(family, mode = "function", envir = parent.frame())
- }
- if (is.function(family)) {
- family <- family()
- }
- if (is.null(family$family)) {
- print(family)
- stop("'family' not recognized")
- }
-
- formula <- paste(deparse(formula), collapse = "")
-
- jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
- "fit", formula, data@sdf, family$family, family$link,
- epsilon, as.integer(maxit))
- return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+ spark.glm(data, formula, family, epsilon, maxit)
})
#' Get the summary of a generalized linear model
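The family-normalization chain that moved from glm into spark.glm follows base R's own convention: a character name, a family function, or a family object all reduce to a list carrying $family and $link. A standalone sketch of the same logic (hypothetical helper name, no Spark required):

```r
# Sketch of the family-resolution chain used inside spark.glm;
# resolveFamily is a hypothetical name for illustration only.
resolveFamily <- function(family) {
  if (is.character(family)) {
    family <- get(family, mode = "function", envir = parent.frame())
  }
  if (is.function(family)) {
    family <- family()
  }
  if (is.null(family$family)) {
    stop("'family' not recognized")
  }
  family
}

resolveFamily("gaussian")$link                  # "identity"
resolveFamily(poisson)$link                     # "log"
resolveFamily(binomial(link = "logit"))$family  # "binomial"
```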
@@ -188,7 +228,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
#' @export
#' @examples
#' \dontrun{
-#' model <- naiveBayes(y ~ x, trainingData)
+#' model <- spark.naiveBayes(trainingData, y ~ x)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
@@ -208,7 +248,7 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
#' @export
#' @examples
#' \dontrun{
-#' model <- naiveBayes(y ~ x, trainingData)
+#' model <- spark.naiveBayes(trainingData, y ~ x)
#' summary(model)
#'}
setMethod("summary", signature(object = "NaiveBayesModel"),
@@ -230,23 +270,23 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#'
#' Fit a k-means model, similarly to R's kmeans().
#'
-#' @param x SparkDataFrame for training
-#' @param centers Number of centers
-#' @param iter.max Maximum iteration number
-#' @param algorithm Algorithm choosen to fit the model
+#' @param data SparkDataFrame for training
+#' @param k Number of centers
+#' @param maxIter Maximum iteration number
+#' @param initializationMode Algorithm chosen to fit the model
#' @return A fitted k-means model
-#' @rdname kmeans
+#' @rdname spark.kmeans
#' @export
#' @examples
#' \dontrun{
-#' model <- kmeans(x, centers = 2, algorithm="random")
+#' model <- spark.kmeans(data, k = 2, initializationMode = "random")
#' }
-setMethod("kmeans", signature(x = "SparkDataFrame"),
- function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
- columnNames <- as.array(colnames(x))
- algorithm <- match.arg(algorithm)
- jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
- centers, iter.max, algorithm, columnNames)
+setMethod("spark.kmeans", signature(data = "SparkDataFrame"),
+ function(data, k, maxIter = 10, initializationMode = c("random", "k-means||")) {
+ columnNames <- as.array(colnames(data))
+ initializationMode <- match.arg(initializationMode)
+ jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf,
+ k, maxIter, initializationMode, columnNames)
return(new("KMeansModel", jobj = jobj))
})
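The new initializationMode argument leans on base R's match.arg, so the first element of the choices vector acts as the default and abbreviations are matched by prefix. A standalone illustration:

```r
# How match.arg resolves initializationMode (standalone, no Spark needed).
pickMode <- function(initializationMode = c("random", "k-means||")) {
  match.arg(initializationMode)
}

pickMode()              # "random": the first choice is the default
pickMode("k-means||")   # explicit choice
# pickMode("kmeans")    # would error: not a prefix of any allowed value
```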
@@ -261,7 +301,7 @@ setMethod("kmeans", signature(x = "SparkDataFrame"),
#' @export
#' @examples
#' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
@@ -288,7 +328,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
#' @export
#' @examples
#' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
#' summary(model)
#' }
setMethod("summary", signature(object = "KMeansModel"),
@@ -322,7 +362,7 @@ setMethod("summary", signature(object = "KMeansModel"),
#' @export
#' @examples
#' \dontrun{
-#' model <- kmeans(trainingData, 2)
+#' model <- spark.kmeans(trainingData, 2)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#' }
@@ -333,30 +373,28 @@ setMethod("predict", signature(object = "KMeansModel"),
#' Fit a Bernoulli naive Bayes model
#'
-#' Fit a Bernoulli naive Bayes model, similarly to R package e1071's naiveBayes() while only
-#' categorical features are supported. The input should be a SparkDataFrame of observations instead
-#' of a contingency table.
+#' Fit a Bernoulli naive Bayes model on a Spark DataFrame (only categorical data is supported).
#'
+#' @param data SparkDataFrame for training
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param data SparkDataFrame for training
#' @param laplace Smoothing parameter
#' @return a fitted naive Bayes model
-#' @rdname naiveBayes
+#' @rdname spark.naiveBayes
#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/}
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(sqlContext, infert)
-#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
#'}
-setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
- function(formula, data, laplace = 0, ...) {
- formula <- paste(deparse(formula), collapse = "")
- jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
- formula, data@sdf, laplace)
- return(new("NaiveBayesModel", jobj = jobj))
- })
+setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, laplace = 0, ...) {
+ formula <- paste(deparse(formula), collapse = "")
+ jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
+ formula, data@sdf, laplace)
+ return(new("NaiveBayesModel", jobj = jobj))
+ })
#' Save the Bernoulli naive Bayes model to the input path.
#'
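All of these wrappers flatten the R formula into a single string before the callJStatic JVM call, because deparse() may split a long formula across several lines. A standalone sketch using the long column names from the adapted tests:

```r
# Sketch of the formula-serialization step shared by spark.glm,
# spark.naiveBayes, and spark.survreg before crossing into the JVM.
f <- LongLongLongLongLongName ~ VeryLongLongLongLonLongName +
  AnotherLongLongLongLongName

deparse(f)                        # character vector, length > 1 for long formulas
paste(deparse(f), collapse = "")  # always a single string for the JVM side
```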
@@ -371,7 +409,7 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
#' @examples
#' \dontrun{
#' df <- createDataFrame(sqlContext, infert)
-#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' model <- spark.naiveBayes(df, education ~ ., laplace = 0)
#' path <- "path/to/model"
#' ml.save(model, path)
#' }
@@ -396,7 +434,7 @@ setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
#' @export
#' @examples
#' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
#' path <- "path/to/model"
#' ml.save(model, path)
#' }
@@ -446,7 +484,7 @@ setMethod("ml.save", signature(object = "GeneralizedLinearRegressionModel", path
#' @export
#' @examples
#' \dontrun{
-#' model <- kmeans(x, centers = 2, algorithm="random")
+#' model <- spark.kmeans(x, k = 2, initializationMode = "random")
#' path <- "path/to/model"
#' ml.save(model, path)
#' }
@@ -489,29 +527,30 @@ ml.load <- function(path) {
#' Fit an accelerated failure time (AFT) survival regression model.
#'
-#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
+#' Fit an accelerated failure time (AFT) survival regression model on a Spark DataFrame.
#'
+#' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' Note that operator '.' is not supported currently.
-#' @param data SparkDataFrame for training.
#' @return a fitted AFT survival regression model
-#' @rdname survreg
+#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(sqlContext, ovarian)
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
+#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)
#' }
-setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
- function(formula, data, ...) {
+setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
+ function(data, formula, ...) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
return(new("AFTSurvivalRegressionModel", jobj = jobj))
})
+
#' Get the summary of an AFT survival regression model
#'
#' Returns the summary of an AFT survival regression model produced by spark.survreg(),
@@ -523,7 +562,7 @@ setMethod("survreg", signature(formula = "formula", data = "SparkDataFrame"),
#' @export
#' @examples
#' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
#' summary(model)
#' }
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
@@ -548,7 +587,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
#' @export
#' @examples
#' \dontrun{
-#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
+#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#' }
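Putting the renamed fitters together with the ml.save/ml.load API documented above, a sketch of a save/load round trip (assumes a live SparkR session and the ovarian dataset from the survival package, as in the roxygen examples):

```r
# Sketch only: requires a running SparkR session and data(ovarian).
df <- createDataFrame(sqlContext, ovarian)
model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx)

modelPath <- tempfile(pattern = "survreg", fileext = ".tmp")
ml.save(model, modelPath)     # errors if the path already exists
model2 <- ml.load(modelPath)  # returns an AFTSurvivalRegressionModel
summary(model2)
unlink(modelPath)
```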
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 6a822be121..18a4e78c99 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -25,6 +25,137 @@ sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
+test_that("formula of spark.glm", {
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ # directly calling the spark API
+ # dot minus and intercept vs native glm
+ model <- spark.glm(training, Sepal_Width ~ . - Species + 0)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+ # feature interaction vs native glm
+ model <- spark.glm(training, Sepal_Width ~ Species:Sepal_Length)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+ # glm should work with long formula
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ training$LongLongLongLongLongName <- training$Sepal_Width
+ training$VeryLongLongLongLonLongName <- training$Sepal_Length
+ training$AnotherLongLongLongLongName <- training$Species
+ model <- spark.glm(training, LongLongLongLongLongName ~ VeryLongLongLongLonLongName +
+ AnotherLongLongLongLongName)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
+test_that("spark.glm and predict", {
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ # gaussian family
+ model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ vals <- collect(select(prediction, "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+ # poisson family
+ model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
+ family = poisson(link = identity))
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ vals <- collect(select(prediction, "prediction"))
+ rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
+ data = iris, family = poisson(link = identity)), iris))
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+
+ # Test stats::predict is working
+ x <- rnorm(15)
+ y <- x + rnorm(15)
+ expect_equal(length(predict(lm(y ~ x))), 15)
+})
+
+test_that("spark.glm summary", {
+ # gaussian family
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species))
+
+ rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
+
+ coefs <- unlist(stats$coefficients)
+ rCoefs <- unlist(rStats$coefficients)
+ expect_true(all(abs(rCoefs - coefs) < 1e-4))
+ expect_true(all(
+ rownames(stats$coefficients) ==
+ c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
+ expect_equal(stats$dispersion, rStats$dispersion)
+ expect_equal(stats$null.deviance, rStats$null.deviance)
+ expect_equal(stats$deviance, rStats$deviance)
+ expect_equal(stats$df.null, rStats$df.null)
+ expect_equal(stats$df.residual, rStats$df.residual)
+ expect_equal(stats$aic, rStats$aic)
+
+ # binomial family
+ df <- suppressWarnings(createDataFrame(sqlContext, iris))
+ training <- df[df$Species %in% c("versicolor", "virginica"), ]
+ stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
+ family = binomial(link = "logit")))
+
+ rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
+ rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
+ family = binomial(link = "logit")))
+
+ coefs <- unlist(stats$coefficients)
+ rCoefs <- unlist(rStats$coefficients)
+ expect_true(all(abs(rCoefs - coefs) < 1e-4))
+ expect_true(all(
+ rownames(stats$coefficients) ==
+ c("(Intercept)", "Sepal_Length", "Sepal_Width")))
+ expect_equal(stats$dispersion, rStats$dispersion)
+ expect_equal(stats$null.deviance, rStats$null.deviance)
+ expect_equal(stats$deviance, rStats$deviance)
+ expect_equal(stats$df.null, rStats$df.null)
+ expect_equal(stats$df.residual, rStats$df.residual)
+ expect_equal(stats$aic, rStats$aic)
+
+ # Test summary works on base GLM models
+ baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
+ baseSummary <- summary(baseModel)
+ expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+})
+
+test_that("spark.glm save/load", {
+ training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species)
+ s <- summary(m)
+
+ modelPath <- tempfile(pattern = "glm", fileext = ".tmp")
+ ml.save(m, modelPath)
+ expect_error(ml.save(m, modelPath))
+ ml.save(m, modelPath, overwrite = TRUE)
+ m2 <- ml.load(modelPath)
+ s2 <- summary(m2)
+
+ expect_equal(s$coefficients, s2$coefficients)
+ expect_equal(rownames(s$coefficients), rownames(s2$coefficients))
+ expect_equal(s$dispersion, s2$dispersion)
+ expect_equal(s$null.deviance, s2$null.deviance)
+ expect_equal(s$deviance, s2$deviance)
+ expect_equal(s$df.null, s2$df.null)
+ expect_equal(s$df.residual, s2$df.residual)
+ expect_equal(s$aic, s2$aic)
+ expect_equal(s$iter, s2$iter)
+ expect_true(!s$is.loaded)
+ expect_true(s2$is.loaded)
+
+ unlink(modelPath)
+})
+
test_that("formula of glm", {
training <- suppressWarnings(createDataFrame(sqlContext, iris))
# dot minus and intercept vs native glm
@@ -153,14 +284,14 @@ test_that("glm save/load", {
unlink(modelPath)
})
-test_that("kmeans", {
+test_that("spark.kmeans", {
newIris <- iris
newIris$Species <- NULL
training <- suppressWarnings(createDataFrame(sqlContext, newIris))
take(training, 1)
- model <- kmeans(x = training, centers = 2)
+ model <- spark.kmeans(data = training, k = 2)
sample <- take(select(predict(model, training), "prediction"), 1)
expect_equal(typeof(sample$prediction), "integer")
expect_equal(sample$prediction, 1)
@@ -235,7 +366,7 @@ test_that("naiveBayes", {
t <- as.data.frame(Titanic)
t1 <- t[t$Freq > 0, -5]
df <- suppressWarnings(createDataFrame(sqlContext, t1))
- m <- naiveBayes(Survived ~ ., data = df)
+ m <- spark.naiveBayes(df, Survived ~ .)
s <- summary(m)
expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6)
expect_equal(sum(s$apriori), 1)
@@ -264,7 +395,7 @@ test_that("naiveBayes", {
}
})
-test_that("survreg", {
+test_that("spark.survreg", {
# R code to reproduce the result.
#
#' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
@@ -290,7 +421,7 @@ test_that("survreg", {
data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex"))
- model <- survreg(Surv(time, status) ~ x + sex, df)
+ model <- spark.survreg(df, Surv(time, status) ~ x + sex)
stats <- summary(model)
coefs <- as.vector(stats$coefficients[, 1])
rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)