diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-11-03 08:32:37 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-03 08:32:37 -0800 |
commit | f54ff19b1edd4903950cb334987a447445fa97ef (patch) | |
tree | be982fe7fa9ab98ba9f86aecbb860b6dc8693de1 /mllib/src/test/scala/org | |
parent | 3434572b141075f00698d94e6ee80febd3093c3b (diff) | |
download | spark-f54ff19b1edd4903950cb334987a447445fa97ef.tar.gz spark-f54ff19b1edd4903950cb334987a447445fa97ef.tar.bz2 spark-f54ff19b1edd4903950cb334987a447445fa97ef.zip |
[SPARK-11349][ML] Support transform string label for RFormula
Currently ```RFormula``` can only handle label with ```NumericType``` or ```BinaryType``` (cast it to ```DoubleType``` as the label of Linear Regression training), we should also support label of ```StringType``` which is needed for Logistic Regression (glm with family = "binomial").
For label of ```StringType```, we should use ```StringIndexer``` to transform it to 0-based index.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9302 from yanboliang/spark-11349.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index b56013008b..dc20a5ec21 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -107,6 +107,25 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.collect() === expected.collect()) } + test("index string label") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), + ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)) + ).toDF("id", "a", "b", "features", "label") + // assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } + test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") val original = sqlContext.createDataFrame( |