about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 10
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 19
2 files changed, 28 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index f9b840097f..5c43a41bee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,14 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
.setOutputCol($(featuresCol))
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)
+
+ if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
+ dataset.schema(resolvedFormula.label).dataType == StringType) {
+ encoderStages += new StringIndexer()
+ .setInputCol(resolvedFormula.label)
+ .setOutputCol($(labelCol))
+ }
+
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}
@@ -172,7 +180,7 @@ class RFormulaModel private[feature](
override def transformSchema(schema: StructType): StructType = {
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
- if (hasLabelCol(schema)) {
+ if (hasLabelCol(withFeatures)) {
withFeatures
} else if (schema.exists(_.name == resolvedFormula.label)) {
val nullable = schema(resolvedFormula.label).dataType match {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index b56013008b..dc20a5ec21 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -107,6 +107,25 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(result.collect() === expected.collect())
}
+ test("index string label") {
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val original = sqlContext.createDataFrame(
+ Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val resultSchema = model.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq(
+ ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+ ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
+ ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
+ ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0))
+ ).toDF("id", "a", "b", "features", "label")
+ // assert(result.schema.toString == resultSchema.toString)
+ assert(result.collect() === expected.collect())
+ }
+
test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(