about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
author: Eric Liang <ekl@databricks.com> 2016-01-18 12:50:58 -0800
committer: Xiangrui Meng <meng@databricks.com> 2016-01-18 12:50:58 -0800
commit: 5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c (patch)
tree: 62ff8ea27e74acd448182dcba0a414995d118097 /mllib/src/test/scala/org/apache
parent: 44fcf992aa516153a43d7141d3b8e092f0671ba4 (diff)
download: spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.tar.gz
spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.tar.bz2
spark-5e492e9d5bc0992cbcffe64a9aaf3b334b173d2c.zip
[SPARK-12346][ML] Missing attribute names in GLM for vector-type features
Currently `summary()` fails on a GLM model fitted over a vector feature missing ML attrs, since the output feature attrs will also have no name. We can avoid this situation by forcing `VectorAssembler` to make up suitable names when inputs are missing names.

cc mengxr

Author: Eric Liang <ekl@databricks.com>

Closes #10323 from ericl/spark-12346.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 38
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 4
2 files changed, 40 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index dc20a5ec21..16e565d8b5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -143,6 +143,44 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(attrs === expectedAttrs)
}
+ test("vector attribute generation") { // Checks the attributes found on the transformed "features" column when the input is a vector column.
+ val formula = new RFormula().setFormula("id ~ vec")
+ val original = sqlContext.createDataFrame(
+ Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) // two rows, each holding a 2-element dense vector
+ ).toDF("id", "vec") // only column names are set; this test attaches no metadata to "vec"
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features")) // attribute group recovered from the output schema
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("vec_0"), Some(1)), // expected names follow the <column>_<position> pattern
+ new NumericAttribute(Some("vec_1"), Some(2))))
+ assert(attrs === expectedAttrs) // compares the whole attribute group, not only the names
+ }
+
+ test("vector attribute generation with unnamed input attrs") { // Same check, but the input column carries attribute metadata built from default attributes.
+ val formula = new RFormula().setFormula("id ~ vec2")
+ val base = sqlContext.createDataFrame(
+ Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) // two rows, each holding a 2-element dense vector
+ ).toDF("id", "vec")
+ val metadata = new AttributeGroup(
+ "vec2",
+ Array[Attribute](
+ NumericAttribute.defaultAttr, // default attributes stand in for the "unnamed" inputs in the test title
+ NumericAttribute.defaultAttr)).toMetadata
+ val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) // re-expose "vec" as "vec2" with that metadata attached
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features")) // attribute group recovered from the output schema
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array[Attribute](
+ new NumericAttribute(Some("vec2_0"), Some(1)), // expected names are built from the aliased column name "vec2"
+ new NumericAttribute(Some("vec2_1"), Some(2))))
+ assert(attrs === expectedAttrs) // compares the whole attribute group, not only the names
+ }
+
test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")
val original = sqlContext.createDataFrame(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index f7de7c1e93..dce994fdbd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -111,8 +111,8 @@ class VectorAssemblerSuite
assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3))
val userSalaryOut = features.getAttr(4)
assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4))
- assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
- assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
+ assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5).withName("ad_0"))
+ assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6).withName("ad_1"))
}
test("read/write") {