diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-11-03 08:32:37 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-11-03 08:32:37 -0800 |
commit | f54ff19b1edd4903950cb334987a447445fa97ef (patch) | |
tree | be982fe7fa9ab98ba9f86aecbb860b6dc8693de1 /mllib/src/main | |
parent | 3434572b141075f00698d94e6ee80febd3093c3b (diff) | |
download | spark-f54ff19b1edd4903950cb334987a447445fa97ef.tar.gz spark-f54ff19b1edd4903950cb334987a447445fa97ef.tar.bz2 spark-f54ff19b1edd4903950cb334987a447445fa97ef.zip |
[SPARK-11349][ML] Support transform string label for RFormula
Currently ```RFormula``` can only handle label with ```NumericType``` or ```BinaryType``` (cast it to ```DoubleType``` as the label of Linear Regression training), we should also support label of ```StringType``` which is needed for Logistic Regression (glm with family = "binomial").
For label of ```StringType```, we should use ```StringIndexer``` to transform it to 0-based index.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9302 from yanboliang/spark-11349.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f9b840097f..5c43a41bee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -132,6 +132,14 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R .setOutputCol($(featuresCol)) encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap) encoderStages += new ColumnPruner(tempColumns.toSet) + + if (dataset.schema.fieldNames.contains(resolvedFormula.label) && + dataset.schema(resolvedFormula.label).dataType == StringType) { + encoderStages += new StringIndexer() + .setInputCol(resolvedFormula.label) + .setOutputCol($(labelCol)) + } + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } @@ -172,7 +180,7 @@ class RFormulaModel private[feature]( override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) - if (hasLabelCol(schema)) { + if (hasLabelCol(withFeatures)) { withFeatures } else if (schema.exists(_.name == resolvedFormula.label)) { val nullable = schema(resolvedFormula.label).dataType match { |