author | Eric Liang <ekl@databricks.com> | 2015-07-27 17:17:49 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-27 17:17:49 -0700 |
commit | 8ddfa52c208bf329c2b2c8909c6be04301e36083 | |
tree | e8482d5cee69d187b7f30c5807766c90539e518c /R/pkg | |
parent | dafe8d857dff4c61981476282cbfe11f5c008078 | |
[SPARK-9230] [ML] Support StringType features in RFormula
This adds StringType feature support via OneHotEncoder. As part of this task it was necessary to change RFormula to an Estimator, so that factor levels could be determined from the training dataset.
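As a rough illustration of why an Estimator is needed here, the sketch below (plain Python, not Spark code) learns the distinct string levels in a fit step and only then one-hot encodes rows. The names `fit_one_hot` and `OneHotModel` are hypothetical, and the encoding is deliberately simplified; it does not try to reproduce the level ordering or output layout of Spark's OneHotEncoder.

```python
class OneHotModel:
    """Hypothetical fitted model: holds the levels seen during fitting."""

    def __init__(self, column, levels):
        self.column = column
        self.levels = list(levels)
        # Map each level to its position in the output vector.
        self.index = {level: i for i, level in enumerate(self.levels)}

    def transform(self, row):
        """Return a one-hot list for the string value in row[self.column]."""
        vec = [0.0] * len(self.levels)
        vec[self.index[row[self.column]]] = 1.0
        return vec


def fit_one_hot(rows, column):
    """Hypothetical fit step: collect the distinct levels from training rows.

    Levels are sorted here only to make the sketch deterministic; this is an
    assumption of the example, not a statement about Spark's behaviour.
    """
    levels = sorted({row[column] for row in rows})
    return OneHotModel(column, levels)


# Usage: the levels come from the training rows, not from the formula text.
training = [
    {"Species": "setosa"},
    {"Species": "virginica"},
    {"Species": "versicolor"},
    {"Species": "setosa"},
]
model = fit_one_hot(training, "Species")
encoded = [model.transform(row) for row in training]
```

In this sketch a value absent from the training rows raises `KeyError`, since no output position was reserved for it at fit time.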
Not sure if I am using uids correctly here; it would be good to get reviewer help on that.
cc mengxr
Umbrella design doc: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit#
Author: Eric Liang <ekl@databricks.com>
Closes #7574 from ericl/string-features and squashes the following commits:
f99131a [Eric Liang] comments
0bf3c26 [Eric Liang] update docs
c302a2c [Eric Liang] fix tests
9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features
e713da3 [Eric Liang] comments
4d79193 [Eric Liang] revert to seq + distinct
169a085 [Eric Liang] tweak functional test
a230a47 [Eric Liang] Merge branch 'master' into string-features
72bd6f3 [Eric Liang] fix merge
d841cec [Eric Liang] Merge branch 'master' into string-features
5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015
b01c7c5 [Eric Liang] add test
8a637db [Eric Liang] encoder wip
a1d03f4 [Eric Liang] refactor into estimator
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/inst/tests/test_mllib.R | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index a492763344..29152a1168 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -35,8 +35,8 @@ test_that("glm and predict", {
 test_that("predictions match with native glm", {
   training <- createDataFrame(sqlContext, iris)
-  model <- glm(Sepal_Width ~ Sepal_Length, data = training)
+  model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
   vals <- collect(select(predict(model, training), "prediction"))
-  rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
-  expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
+  rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })