aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-06-02 16:51:17 -0700
committerXiangrui Meng <meng@databricks.com>2015-06-02 16:51:17 -0700
commit89f21f66b5549524d1a6e4fb576a4f80d9fef903 (patch)
tree0457e47897eca1387f109577950e8d46f9f9462a /mllib
parent605ddbb27c8482fc0107b21c19d4e4ae19348f35 (diff)
downloadspark-89f21f66b5549524d1a6e4fb576a4f80d9fef903.tar.gz
spark-89f21f66b5549524d1a6e4fb576a4f80d9fef903.tar.bz2
spark-89f21f66b5549524d1a6e4fb576a4f80d9fef903.zip
[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output
The temporary column should be dropped after we get the prediction column. harsha2010 Author: Xiangrui Meng <meng@databricks.com> Closes #6592 from mengxr/SPARK-8049 and squashes the following commits: 1d89107 [Xiangrui Meng] use SparkFunSuite 6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala9
2 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7b726da388..825f9ed1b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ .drop(accColName)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index f439f3261f..1d04ccb509 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
+
+ test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+ val logReg = new LogisticRegression()
+ .setMaxIter(1)
+ val ovr = new OneVsRest()
+ .setClassifier(logReg)
+ val output = ovr.fit(dataset).transform(dataset)
+ assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {