aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2015-09-25 00:43:22 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-25 00:43:22 -0700
commit922338812c03eba43f2f1a6c414d1b6b049811cf (patch)
tree2df940a08de0645e2b88ba69d0c63931f9ec1f2f /mllib/src/test/scala/org
parent21fd12cb17b9e08a0cc49b4fda801af947a4183b (diff)
downloadspark-922338812c03eba43f2f1a6c414d1b6b049811cf.tar.gz
spark-922338812c03eba43f2f1a6c414d1b6b049811cf.tar.bz2
spark-922338812c03eba43f2f1a6c414d1b6b049811cf.zip
[SPARK-9681] [ML] Support R feature interactions in RFormula
This integrates the Interaction feature transformer with SparkR R formula support (i.e. support `:`). To generate reasonable ML attribute names for feature interactions, it was necessary to add the ability to read attribute the original attribute names back from `StructField`, and also to specify custom group prefixes in `VectorAssembler`. This also has the side-benefit of cleaning up the double-underscores in the attributes generated for non-interaction terms. mengxr Author: Eric Liang <ekl@databricks.com> Closes #8830 from ericl/interaction-2.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala89
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala76
2 files changed, 160 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index 436e66bab0..53798c659d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -25,16 +25,24 @@ class RFormulaParserSuite extends SparkFunSuite {
formula: String,
label: String,
terms: Seq[String],
- schema: StructType = null) {
+ schema: StructType = new StructType) {
val resolved = RFormulaParser.parse(formula).resolve(schema)
assert(resolved.label == label)
- assert(resolved.terms == terms)
+ val simpleTerms = terms.map { t =>
+ if (t.contains(":")) {
+ t.split(":").toSeq
+ } else {
+ Seq(t)
+ }
+ }
+ assert(resolved.terms == simpleTerms)
}
test("parse simple formulas") {
checkParse("y ~ x", "y", Seq("x"))
checkParse("y ~ x + x", "y", Seq("x"))
- checkParse("y ~ ._foo ", "y", Seq("._foo"))
+ checkParse("y~x+z", "y", Seq("x", "z"))
+ checkParse("y ~ ._fo..o ", "y", Seq("._fo..o"))
checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
}
@@ -79,4 +87,79 @@ class RFormulaParserSuite extends SparkFunSuite {
assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
}
+
+ test("parse interactions") {
+ checkParse("y ~ a:b", "y", Seq("a:b"))
+ checkParse("y ~ ._a:._x", "y", Seq("._a:._x"))
+ checkParse("y ~ foo:bar", "y", Seq("foo:bar"))
+ checkParse("y ~ a : b : c", "y", Seq("a:b:c"))
+ checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z"))
+ }
+
+ test("parse basic interactions with dot") {
+ val schema = (new StructType)
+ .add("a", "int", true)
+ .add("b", "long", false)
+ .add("c", "string", true)
+ .add("d", "string", true)
+ checkParse("a ~ .:b", "a", Seq("b", "c:b", "d:b"), schema)
+ checkParse("a ~ b:.", "a", Seq("b", "b:c", "b:d"), schema)
+ checkParse("a ~ .:b:.:.:c:d:.", "a", Seq("b:c:d"), schema)
+ }
+
+ // Test data generated in R with terms.formula(y ~ .:., data = iris)
+ test("parse all to all iris interactions") {
+ val schema = (new StructType)
+ .add("Sepal.Length", "double", true)
+ .add("Sepal.Width", "double", true)
+ .add("Petal.Length", "double", true)
+ .add("Petal.Width", "double", true)
+ .add("Species", "string", true)
+ checkParse(
+ "y ~ .:.",
+ "y",
+ Seq(
+ "Sepal.Length",
+ "Sepal.Width",
+ "Petal.Length",
+ "Petal.Width",
+ "Species",
+ "Sepal.Length:Sepal.Width",
+ "Sepal.Length:Petal.Length",
+ "Sepal.Length:Petal.Width",
+ "Sepal.Length:Species",
+ "Sepal.Width:Petal.Length",
+ "Sepal.Width:Petal.Width",
+ "Sepal.Width:Species",
+ "Petal.Length:Petal.Width",
+ "Petal.Length:Species",
+ "Petal.Width:Species"),
+ schema)
+ }
+
+ // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris)
+ test("parse interaction negation with iris") {
+ val schema = (new StructType)
+ .add("Sepal.Length", "double", true)
+ .add("Sepal.Width", "double", true)
+ .add("Petal.Length", "double", true)
+ .add("Petal.Width", "double", true)
+ .add("Species", "string", true)
+ checkParse("y ~ .:. - .:.", "y", Nil, schema)
+ checkParse(
+ "y ~ .:. - Species:.",
+ "y",
+ Seq(
+ "Sepal.Length",
+ "Sepal.Width",
+ "Petal.Length",
+ "Petal.Width",
+ "Sepal.Length:Sepal.Width",
+ "Sepal.Length:Petal.Length",
+ "Sepal.Length:Petal.Width",
+ "Sepal.Width:Petal.Length",
+ "Sepal.Width:Petal.Width",
+ "Petal.Length:Petal.Width"),
+ schema)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 6aed3243af..b56013008b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -118,9 +118,81 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
val expectedAttrs = new AttributeGroup(
"features",
Array(
- new BinaryAttribute(Some("a__bar"), Some(1)),
- new BinaryAttribute(Some("a__foo"), Some(2)),
+ new BinaryAttribute(Some("a_bar"), Some(1)),
+ new BinaryAttribute(Some("a_foo"), Some(2)),
new NumericAttribute(Some("b"), Some(3))))
assert(attrs === expectedAttrs)
}
+
+ test("numeric interaction") {
+ val formula = new RFormula().setFormula("a ~ b:c:d")
+ val original = sqlContext.createDataFrame(
+ Seq((1, 2, 4, 2), (2, 3, 4, 1))
+ ).toDF("a", "b", "c", "d")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, 2, 4, 2, Vectors.dense(16.0), 1.0),
+ (2, 3, 4, 1, Vectors.dense(12.0), 2.0))
+ ).toDF("a", "b", "c", "d", "features", "label")
+ assert(result.collect() === expected.collect())
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1))))
+ assert(attrs === expectedAttrs)
+ }
+
+ test("factor numeric interaction") {
+ val formula = new RFormula().setFormula("id ~ a:b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
+ (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
+ (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0),
+ (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
+ (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
+ (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0))
+ ).toDF("id", "a", "b", "features", "label")
+ assert(result.collect() === expected.collect())
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("a_baz:b"), Some(1)),
+ new NumericAttribute(Some("a_bar:b"), Some(2)),
+ new NumericAttribute(Some("a_foo:b"), Some(3))))
+ assert(attrs === expectedAttrs)
+ }
+
+ test("factor factor interaction") {
+ val formula = new RFormula().setFormula("id ~ a:b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
+ (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
+ (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0))
+ ).toDF("id", "a", "b", "features", "label")
+ assert(result.collect() === expected.collect())
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("a_bar:b_zq"), Some(1)),
+ new NumericAttribute(Some("a_bar:b_zz"), Some(2)),
+ new NumericAttribute(Some("a_foo:b_zq"), Some(3)),
+ new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
+ assert(attrs === expectedAttrs)
+ }
}