diff options
author | Eric Liang <ekl@databricks.com> | 2015-07-30 16:15:43 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-30 16:15:43 -0700 |
commit | e7905a9395c1a002f50bab29e16a729e14d4ed6f (patch) | |
tree | 37758d36fd51f330ca7b4ce2b9f9bb47784a2dcb /mllib/src/test | |
parent | be7be6d4c7d978c20e601d1f5f56ecb3479814cb (diff) | |
download | spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.tar.gz spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.tar.bz2 spark-e7905a9395c1a002f50bab29e16a729e14d4ed6f.zip |
[SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula
Preview:
```
> summary(m)
features coefficients
1 (Intercept) 1.6765001
2 Sepal_Length 0.3498801
3 Species.versicolor -0.9833885
4 Species.virginica -1.0075104
```
Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit
cc mengxr
Author: Eric Liang <ekl@databricks.com>
Closes #7771 from ericl/summary and squashes the following commits:
ccd54c3 [Eric Liang] second pass
a5ca93b [Eric Liang] comments
2772111 [Eric Liang] clean up
70483ef [Eric Liang] fix test
7c247d4 [Eric Liang] Merge branch 'master' into summary
3c55024 [Eric Liang] working
8c539aa [Eric Liang] first pass
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala | 8 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 18 |
2 files changed, 22 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b..321eeb8439 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 8148c553e9..6aed3243af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } } |