diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-03-25 16:00:09 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-25 16:00:09 -0700 |
commit | 54d13bed87fcf2968f77e1f1153e85184ec91d78 (patch) | |
tree | fe1aba5bcf9c5af86db892a6fd2facfa7114bab0 /mllib/src/main/scala | |
parent | ff7cc45f521c63ce40f955b8995d52a79dca17b4 (diff) | |
download | spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.gz spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.bz2 spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.zip |
[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula
## What changes were proposed in this pull request?
StringIndexerModel.transform sets the output column metadata to use the name of the inputCol. It should not. Fixing this causes a problem with the metadata produced by RFormula.
Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I modified VectorAttributeRewriter to find and replace all "prefixes" since attributes collect multiple prefixes from StringIndexer + Interaction.
Note that "prefixes" is no longer accurate since internal strings may be replaced.
## How was this patch tested?
Unit test which failed before this fix.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11965 from jkbradley/StringIndexer-fix.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 15 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 7 |
2 files changed, 9 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index e7ca7ada74..12a76dbbfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -125,6 +125,7 @@ class RFormula(override val uid: String) encoderStages += new StringIndexer() .setInputCol(term) .setOutputCol(indexCol) + prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) case _ => (term, term) @@ -229,7 +230,7 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" + override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" private def transformLabel(dataset: DataFrame): DataFrame = { val labelName = resolvedFormula.label @@ -400,14 +401,10 @@ private class VectorAttributeRewriter( val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => if (attr.name.isDefined) { - val name = attr.name.get - val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) } - if (replacement.nonEmpty) { - val (k, v) = replacement.headOption.get - attr.withName(v + name.stripPrefix(k)) - } else { - attr + val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) => + curName.replace(from, to) } + attr.withName(name) } else { attr } @@ -416,7 +413,7 @@ private class VectorAttributeRewriter( } val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col) val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata) - dataset.select((otherCols :+ rewrittenCol): _*) + dataset.select(otherCols :+ rewrittenCol : _*) } override def transformSchema(schema: StructType): StructType = { diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index c579a0d68e..faa0f6f407 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -161,15 +161,14 @@ class StringIndexerModel ( } val metadata = NominalAttribute.defaultAttr - .withName($(inputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(labels).toMetadata() // If we are skipping invalid records, filter them out. - val filteredDataset = (getHandleInvalid) match { - case "skip" => { + val filteredDataset = getHandleInvalid match { + case "skip" => val filterer = udf { label: String => labelToIndex.contains(label) } dataset.where(filterer(dataset($(inputCol)))) - } case _ => dataset } filteredDataset.select(col("*"), |