aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2015-07-27 17:17:49 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-27 17:17:49 -0700
commit8ddfa52c208bf329c2b2c8909c6be04301e36083 (patch)
treee8482d5cee69d187b7f30c5807766c90539e518c /mllib/src/test
parentdafe8d857dff4c61981476282cbfe11f5c008078 (diff)
downloadspark-8ddfa52c208bf329c2b2c8909c6be04301e36083.tar.gz
spark-8ddfa52c208bf329c2b2c8909c6be04301e36083.tar.bz2
spark-8ddfa52c208bf329c2b2c8909c6be04301e36083.zip
[SPARK-9230] [ML] Support StringType features in RFormula
This adds StringType feature support via OneHotEncoder. As part of this task it was necessary to change RFormula to an Estimator, so that factor levels could be determined from the training dataset. Not sure if I am using uids correctly here, would be good to get reviewer help on that. cc mengxr Umbrella design doc: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit# Author: Eric Liang <ekl@databricks.com> Closes #7574 from ericl/string-features and squashes the following commits: f99131a [Eric Liang] comments 0bf3c26 [Eric Liang] update docs c302a2c [Eric Liang] fix tests 9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features e713da3 [Eric Liang] comments 4d79193 [Eric Liang] revert to seq + distinct 169a085 [Eric Liang] tweak functional test a230a47 [Eric Liang] Merge branch 'master' into string-features 72bd6f3 [Eric Liang] fix merge d841cec [Eric Liang] Merge branch 'master' into string-features 5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015 b01c7c5 [Eric Liang] add test 8a637db [Eric Liang] encoder wip a1d03f4 [Eric Liang] refactor into estimator
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala64
2 files changed, 36 insertions, 29 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index c8d065f37a..c4b45aee06 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -28,6 +28,7 @@ class RFormulaParserSuite extends SparkFunSuite {
test("parse simple formulas") {
checkParse("y ~ x", "y", Seq("x"))
+ checkParse("y ~ x + x", "y", Seq("x"))
checkParse("y ~ ._foo ", "y", Seq("._foo"))
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 79c4ccf02d..8148c553e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -31,72 +31,78 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = sqlContext.createDataFrame(
Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
- val result = formula.transform(original)
- val resultSchema = formula.transformSchema(original.schema)
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
Seq(
- (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0),
- (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0))
+ (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0),
+ (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0))
).toDF("id", "v1", "v2", "features", "label")
// TODO(ekl) make schema comparisons ignore metadata, to avoid .toString
assert(result.schema.toString == resultSchema.toString)
assert(resultSchema == expected.schema)
- assert(result.collect().toSeq == expected.collect().toSeq)
+ assert(result.collect() === expected.collect())
}
test("features column already exists") {
val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
intercept[IllegalArgumentException] {
- formula.transformSchema(original.schema)
+ formula.fit(original)
}
intercept[IllegalArgumentException] {
- formula.transform(original)
+ formula.fit(original)
}
}
test("label column already exists") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
- val resultSchema = formula.transformSchema(original.schema)
+ val model = formula.fit(original)
+ val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
- assert(resultSchema.toString == formula.transform(original).schema.toString)
+ assert(resultSchema.toString == model.transform(original).schema.toString)
}
test("label column already exists but is not double type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
+ val model = formula.fit(original)
intercept[IllegalArgumentException] {
- formula.transformSchema(original.schema)
+ model.transformSchema(original.schema)
}
intercept[IllegalArgumentException] {
- formula.transform(original)
+ model.transform(original)
}
}
test("allow missing label column for test datasets") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
- val resultSchema = formula.transformSchema(original.schema)
+ val model = formula.fit(original)
+ val resultSchema = model.transformSchema(original.schema)
assert(resultSchema.length == 3)
assert(!resultSchema.exists(_.name == "label"))
- assert(resultSchema.toString == formula.transform(original).schema.toString)
+ assert(resultSchema.toString == model.transform(original).schema.toString)
}
-// TODO(ekl) enable after we implement string label support
-// test("transform string label") {
-// val formula = new RFormula().setFormula("name ~ id")
-// val original = sqlContext.createDataFrame(
-// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name")
-// val result = formula.transform(original)
-// val resultSchema = formula.transformSchema(original.schema)
-// val expected = sqlContext.createDataFrame(
-// Seq(
-// (1, "foo", Vectors.dense(Array(1.0)), 1.0),
-// (2, "bar", Vectors.dense(Array(2.0)), 0.0),
-// (3, "bar", Vectors.dense(Array(3.0)), 0.0))
-// ).toDF("id", "name", "features", "label")
-// assert(result.schema.toString == resultSchema.toString)
-// assert(result.collect().toSeq == expected.collect().toSeq)
-// }
+ test("encodes string terms") {
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+ (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+ (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+ (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+ ).toDF("id", "a", "b", "features", "label")
+ assert(result.schema.toString == resultSchema.toString)
+ assert(result.collect() === expected.collect())
+ }
}