diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala | 89 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 76 |
2 files changed, 160 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 436e66bab0..53798c659d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -25,16 +25,24 @@ class RFormulaParserSuite extends SparkFunSuite { formula: String, label: String, terms: Seq[String], - schema: StructType = null) { + schema: StructType = new StructType) { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) - assert(resolved.terms == terms) + val simpleTerms = terms.map { t => + if (t.contains(":")) { + t.split(":").toSeq + } else { + Seq(t) + } + } + assert(resolved.terms == simpleTerms) } test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) checkParse("y ~ x + x", "y", Seq("x")) - checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("y~x+z", "y", Seq("x", "z")) + checkParse("y ~ ._fo..o ", "y", Seq("._fo..o")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } @@ -79,4 +87,79 @@ class RFormulaParserSuite extends SparkFunSuite { assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) } + + test("parse interactions") { + checkParse("y ~ a:b", "y", Seq("a:b")) + checkParse("y ~ ._a:._x", "y", Seq("._a:._x")) + checkParse("y ~ foo:bar", "y", Seq("foo:bar")) + checkParse("y ~ a : b : c", "y", Seq("a:b:c")) + checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z")) + } + + test("parse basic interactions with dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + .add("d", "string", true) + checkParse("a ~ .:b", "a", Seq("b", "c:b", "d:b"), schema) + checkParse("a ~ b:.", "a", Seq("b", "b:c", "b:d"), schema) + checkParse("a ~ .:b:.:.:c:d:.", "a", Seq("b:c:d"), schema) + } + + // Test data generated in R with terms.formula(y ~ .:., data = iris) + test("parse all to all iris interactions") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse( + "y ~ .:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Species", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Length:Species", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Sepal.Width:Species", + "Petal.Length:Petal.Width", + "Petal.Length:Species", + "Petal.Width:Species"), + schema) + } + + // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris) + test("parse interaction negation with iris") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse("y ~ .:. - .:.", "y", Nil, schema) + checkParse( + "y ~ .:. - Species:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Petal.Length:Petal.Width"), + schema) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 6aed3243af..b56013008b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -118,9 +118,81 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val expectedAttrs = new AttributeGroup( "features", Array( - new BinaryAttribute(Some("a__bar"), Some(1)), - new BinaryAttribute(Some("a__foo"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(1)), + new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) assert(attrs === expectedAttrs) } + + test("numeric interaction") { + val formula = new RFormula().setFormula("a ~ b:c:d") + val original = sqlContext.createDataFrame( + Seq((1, 2, 4, 2), (2, 3, 4, 1)) + ).toDF("a", "b", "c", "d") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, 2, 4, 2, Vectors.dense(16.0), 1.0), + (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) + ).toDF("a", "b", "c", "d", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) + assert(attrs === expectedAttrs) + } + + test("factor numeric interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_baz:b"), Some(1)), + new NumericAttribute(Some("a_bar:b"), Some(2)), + new NumericAttribute(Some("a_foo:b"), Some(3)))) + assert(attrs === expectedAttrs) + } + + test("factor factor interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), + (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_bar:b_zq"), Some(1)), + new NumericAttribute(Some("a_bar:b_zz"), Some(2)), + new NumericAttribute(Some("a_foo:b_zq"), Some(3)), + new NumericAttribute(Some("a_foo:b_zz"), Some(4)))) + assert(attrs === expectedAttrs) + } } |