aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
author: Eric Liang <ekl@databricks.com> 2015-09-25 00:43:22 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-09-25 00:43:22 -0700
commit: 922338812c03eba43f2f1a6c414d1b6b049811cf (patch)
tree: 2df940a08de0645e2b88ba69d0c63931f9ec1f2f /R/pkg
parent: 21fd12cb17b9e08a0cc49b4fda801af947a4183b (diff)
download: spark-922338812c03eba43f2f1a6c414d1b6b049811cf.tar.gz
spark-922338812c03eba43f2f1a6c414d1b6b049811cf.tar.bz2
spark-922338812c03eba43f2f1a6c414d1b6b049811cf.zip
[SPARK-9681] [ML] Support R feature interactions in RFormula
This integrates the Interaction feature transformer with SparkR R formula support (i.e. support `:`). To generate reasonable ML attribute names for feature interactions, it was necessary to add the ability to read the original attribute names back from `StructField`, and also to specify custom group prefixes in `VectorAssembler`. This also has the side-benefit of cleaning up the double-underscores in the attributes generated for non-interaction terms. mengxr Author: Eric Liang <ekl@databricks.com> Closes #8830 from ericl/interaction-2.
Diffstat (limited to 'R/pkg')
-rw-r--r-- R/pkg/R/mllib.R 2
-rw-r--r-- R/pkg/inst/tests/test_mllib.R 10
2 files changed, 10 insertions, 2 deletions
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cea3d760d0..474ada5956 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
-#' operators are supported, including '~', '+', '-', and '.'.
+#' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index f272de78ad..032f8ec68b 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -49,6 +49,14 @@ test_that("dot minus and intercept vs native glm", {
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+test_that("feature interaction vs native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
@@ -57,5 +65,5 @@ test_that("summary coefficients match with native glm", {
expect_true(all(abs(rCoefs - coefs) < 1e-6))
expect_true(all(
as.character(stats$features) ==
- c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+ c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})