aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/R/mllib.R2
-rw-r--r--R/pkg/inst/tests/test_mllib.R8
2 files changed, 9 insertions, 1 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 258e354081..6a8bacaa55 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
-#' operators are supported, including '~' and '+'.
+#' operators are supported, including '~', '+', '-', and '.'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 29152a1168..3bef693247 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -40,3 +40,11 @@ test_that("predictions match with native glm", {
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+
+test_that("dot minus and intercept vs native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ . - Species + 0, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})