diff options
author | Davies Liu <davies@databricks.com> | 2015-11-05 16:34:10 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-11-05 16:34:10 -0800 |
commit | 244010624200eddea6dfd1b2c89f40be45212e96 (patch) | |
tree | b21694f1ea2bf1a0a6ce632171230fbd1f34d166 /R/pkg | |
parent | b6974f8fed1726a381636e996834111a8e7ced8d (diff) | |
download | spark-244010624200eddea6dfd1b2c89f40be45212e96.tar.gz spark-244010624200eddea6dfd1b2c89f40be45212e96.tar.bz2 spark-244010624200eddea6dfd1b2c89f40be45212e96.zip |
[SPARK-11542] [SPARKR] fix glm with long fomular
Because deparse() will break the long string into multiple lines, the deserialization will fail
Author: Davies Liu <davies@databricks.com>
Closes #9510 from davies/fix_glm.
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/R/mllib.R | 3 | ||||
-rw-r--r-- | R/pkg/inst/tests/test_mllib.R | 12 |
2 files changed, 14 insertions, 1 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 60bfadb8e7..b0d73dd93a 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -48,8 +48,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, standardize = TRUE, solver = "auto") { family <- match.arg(family) + formula <- paste(deparse(formula), collapse="") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + "fitRModelFormula", formula, data@sdf, family, lambda, alpha, standardize, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032cfef061..4761e285a2 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -33,6 +33,18 @@ test_that("glm and predict", { expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") }) +test_that("glm should work with long formula", { + training <- createDataFrame(sqlContext, iris) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) |