From 5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 18 Jan 2016 12:50:58 -0800 Subject: [SPARK-12346][ML] Missing attribute names in GLM for vector-type features Currently `summary()` fails on a GLM model fitted over a vector feature missing ML attrs, since the output feature attrs will also have no name. We can avoid this situation by forcing `VectorAssembler` to make up suitable names when inputs are missing names. cc mengxr Author: Eric Liang Closes #10323 from ericl/spark-12346. --- .../apache/spark/ml/feature/RFormulaSuite.scala | 38 ++++++++++++++++++++++ .../spark/ml/feature/VectorAssemblerSuite.scala | 4 +-- 2 files changed, 40 insertions(+), 2 deletions(-) (limited to 'mllib/src/test') diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index dc20a5ec21..16e565d8b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(attrs === expectedAttrs) } + test("vector attribute generation") { + val formula = new RFormula().setFormula("id ~ vec") + val original = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec_0"), Some(1)), + new NumericAttribute(Some("vec_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("vector attribute generation with unnamed input attrs") { + val formula = new RFormula().setFormula("id ~ vec2") + val base = sqlContext.createDataFrame( + Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) + ).toDF("id", "vec") + val metadata = new AttributeGroup( + "vec2", + Array[Attribute]( + NumericAttribute.defaultAttr, + NumericAttribute.defaultAttr)).toMetadata + val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("vec2_0"), Some(1)), + new NumericAttribute(Some("vec2_1"), Some(2)))) + assert(attrs === expectedAttrs) + } + test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") val original = sqlContext.createDataFrame( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index f7de7c1e93..dce994fdbd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -111,8 +111,8 @@ class VectorAssemblerSuite assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) val userSalaryOut = features.getAttr(4) assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) - assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) - assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0")) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1")) } test("read/write") { -- cgit v1.2.3