diff options
author | Eric Liang <ekl@databricks.com> | 2016-01-18 12:50:58 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-01-18 12:50:58 -0800 |
commit | 5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c (patch) | |
tree | 62ff8ea27e74acd448182dcba0a414995d118097 /mllib/src/test | |
parent | 44fcf992aa516153a43d7141d3b8e092f0671ba4 (diff) | |
download | spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.tar.gz spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.tar.bz2 spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.zip |
[SPARK-12346][ML] Missing attribute names in GLM for vector-type features
Currently `summary()` fails on a GLM model fitted over a vector feature missing ML attrs, since the output feature attrs will also have no name. We can avoid this situation by forcing `VectorAssembler` to make up suitable names when inputs are missing names.
cc mengxr
Author: Eric Liang <ekl@databricks.com>
Closes #10323 from ericl/spark-12346.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 38 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 4 |
2 files changed, 40 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index dc20a5ec21..16e565d8b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(attrs === expectedAttrs) } + test("vector attribute generation") { + val formula = new RFormula().setFormula("id ~ vec") + val original = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec_0"), Some(1)), + new NumericAttribute(Some("vec_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("vector attribute generation with unnamed input attrs") { + val formula = new RFormula().setFormula("id ~ vec2") + val base = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val metadata = new AttributeGroup( + "vec2", + Array[Attribute]( + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr)).toMetadata + val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec2_0"), Some(1)), + new NumericAttribute(Some("vec2_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = sqlContext.createDataFrame( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index f7de7c1e93..dce994fdbd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -111,8 +111,8 @@ class VectorAssemblerSuite assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) val userSalaryOut = features.getAttr(4) assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) - assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) - assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0")) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1")) } test("read/write") { |