about summary refs log tree commit diff
path: root/mllib/src
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph@databricks.com> 2016-03-25 16:00:09 -0700
committer: Xiangrui Meng <meng@databricks.com> 2016-03-25 16:00:09 -0700
commit: 54d13bed87fcf2968f77e1f1153e85184ec91d78 (patch)
tree: fe1aba5bcf9c5af86db892a6fd2facfa7114bab0 /mllib/src
parent: ff7cc45f521c63ce40f955b8995d52a79dca17b4 (diff)
download: spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.gz
spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.tar.bz2
spark-54d13bed87fcf2968f77e1f1153e85184ec91d78.zip
[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula
## What changes were proposed in this pull request? StringIndexerModel.transform sets the output column metadata to use name inputCol. It should not. Fixing this causes a problem with the metadata produced by RFormula. Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I modified VectorAttributeRewriter to find and replace all "prefixes" since attributes collect multiple prefixes from StringIndexer + Interaction. Note that "prefixes" is no longer accurate since internal strings may be replaced. ## How was this patch tested? Unit test which failed before this fix. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11965 from jkbradley/StringIndexer-fix.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 15
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 7
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 13
3 files changed, 22 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index e7ca7ada74..12a76dbbfb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -125,6 +125,7 @@ class RFormula(override val uid: String)
encoderStages += new StringIndexer()
.setInputCol(term)
.setOutputCol(indexCol)
+ prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _ =>
(term, term)
@@ -229,7 +230,7 @@ class RFormulaModel private[feature](
override def copy(extra: ParamMap): RFormulaModel = copyValues(
new RFormulaModel(uid, resolvedFormula, pipelineModel))
- override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)"
+ override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
private def transformLabel(dataset: DataFrame): DataFrame = {
val labelName = resolvedFormula.label
@@ -400,14 +401,10 @@ private class VectorAttributeRewriter(
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
val attrs = group.attributes.get.map { attr =>
if (attr.name.isDefined) {
- val name = attr.name.get
- val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) }
- if (replacement.nonEmpty) {
- val (k, v) = replacement.headOption.get
- attr.withName(v + name.stripPrefix(k))
- } else {
- attr
+ val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) =>
+ curName.replace(from, to)
}
+ attr.withName(name)
} else {
attr
}
@@ -416,7 +413,7 @@ private class VectorAttributeRewriter(
}
val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
- dataset.select((otherCols :+ rewrittenCol): _*)
+ dataset.select(otherCols :+ rewrittenCol : _*)
}
override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index c579a0d68e..faa0f6f407 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -161,15 +161,14 @@ class StringIndexerModel (
}
val metadata = NominalAttribute.defaultAttr
- .withName($(inputCol)).withValues(labels).toMetadata()
+ .withName($(outputCol)).withValues(labels).toMetadata()
// If we are skipping invalid records, filter them out.
- val filteredDataset = (getHandleInvalid) match {
- case "skip" => {
+ val filteredDataset = getHandleInvalid match {
+ case "skip" =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
dataset.where(filterer(dataset($(inputCol))))
- }
case _ => dataset
}
filteredDataset.select(col("*"),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d40e69dced..2c3255ef33 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -210,4 +210,17 @@ class StringIndexerSuite
.setLabels(Array("a", "b", "c"))
testDefaultReadWrite(t)
}
+
+ test("StringIndexer metadata") {
+ val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+ val df = sqlContext.createDataFrame(data).toDF("id", "label")
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ .fit(df)
+ val transformed = indexer.transform(df)
+ val attrs =
+ NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
+ assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+ }
}