aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-11-03 08:32:37 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-03 08:32:37 -0800
commitf54ff19b1edd4903950cb334987a447445fa97ef (patch)
treebe982fe7fa9ab98ba9f86aecbb860b6dc8693de1 /mllib/src/main/scala/org
parent3434572b141075f00698d94e6ee80febd3093c3b (diff)
downloadspark-f54ff19b1edd4903950cb334987a447445fa97ef.tar.gz
spark-f54ff19b1edd4903950cb334987a447445fa97ef.tar.bz2
spark-f54ff19b1edd4903950cb334987a447445fa97ef.zip
[SPARK-11349][ML] Support transform string label for RFormula
Currently ```RFormula``` can only handle label with ```NumericType``` or ```BinaryType``` (cast it to ```DoubleType``` as the label of Linear Regression training), we should also support label of ```StringType``` which is needed for Logistic Regression (glm with family = "binomial"). For label of ```StringType```, we should use ```StringIndexer``` to transform it to 0-based index. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9302 from yanboliang/spark-11349.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala10
1 files changed, 9 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f9b840097f..5c43a41bee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,14 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
.setOutputCol($(featuresCol))
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)
+
+ if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
+ dataset.schema(resolvedFormula.label).dataType == StringType) {
+ encoderStages += new StringIndexer()
+ .setInputCol(resolvedFormula.label)
+ .setOutputCol($(labelCol))
+ }
+
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}
@@ -172,7 +180,7 @@ class RFormulaModel private[feature](
override def transformSchema(schema: StructType): StructType = {
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
- if (hasLabelCol(schema)) {
+ if (hasLabelCol(withFeatures)) {
withFeatures
} else if (schema.exists(_.name == resolvedFormula.label)) {
val nullable = schema(resolvedFormula.label).dataType match {