aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-01-18 12:50:58 -0800
committerXiangrui Meng <meng@databricks.com>2016-01-18 12:50:58 -0800
commit5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c (patch)
tree62ff8ea27e74acd448182dcba0a414995d118097 /mllib
parent44fcf992aa516153a43d7141d3b8e092f0671ba4 (diff)
downloadspark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.tar.gz
spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.tar.bz2
spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.zip
[SPARK-12346][ML] Missing attribute names in GLM for vector-type features
Currently `summary()` fails on a GLM model fitted over a vector feature missing ML attrs, since the output feature attrs will also have no name. We can avoid this situation by forcing `VectorAssembler` to make up suitable names when inputs are missing names. cc mengxr Author: Eric Liang <ekl@databricks.com> Closes #10323 from ericl/spark-12346.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala38
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala4
3 files changed, 43 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 716bc63e00..7ff5ad143f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -70,19 +70,19 @@ class VectorAssembler(override val uid: String)
val group = AttributeGroup.fromStructField(field)
if (group.attributes.isDefined) {
// If attributes are defined, copy them with updated names.
- group.attributes.get.map { attr =>
+ group.attributes.get.zipWithIndex.map { case (attr, i) =>
if (attr.name.isDefined) {
// TODO: Define a rigorous naming scheme.
attr.withName(c + "_" + attr.name.get)
} else {
- attr
+ attr.withName(c + "_" + i)
}
}
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
- Array.fill(numAttrs)(NumericAttribute.defaultAttr)
+ Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index dc20a5ec21..16e565d8b5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(attrs === expectedAttrs)
}
+ test("vector attribute generation") {
+ val formula = new RFormula().setFormula("id ~ vec")
+ val original = sqlContext.createDataFrame(
+ Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
+ ).toDF("id", "vec")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("vec_0"), Some(1)),
+ new NumericAttribute(Some("vec_1"), Some(2))))
+ assert(attrs === expectedAttrs)
+ }
+
+ test("vector attribute generation with unnamed input attrs") {
+ val formula = new RFormula().setFormula("id ~ vec2")
+ val base = sqlContext.createDataFrame(
+ Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0)))
+ ).toDF("id", "vec")
+ val metadata = new AttributeGroup(
+ "vec2",
+ Array[Attribute](
+ NumericAttribute.defaultAttr,
+ NumericAttribute.defaultAttr)).toMetadata
+ val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata))
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("vec2_0"), Some(1)),
+ new NumericAttribute(Some("vec2_1"), Some(2))))
+ assert(attrs === expectedAttrs)
+ }
+
test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
val original = sqlContext.createDataFrame(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index f7de7c1e93..dce994fdbd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -111,8 +111,8 @@ class VectorAssemblerSuite
assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
val userSalaryOut = features.getAttr(4)
assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
- assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
- assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
+ assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
+ assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
}
test("read/write") {