about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author: Ram Sriharsha <rsriharsha@hw11853.local> 2015-07-23 22:35:41 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2015-07-23 22:35:41 -0700
commit: d4d762f275749a923356cd84de549b14c22cc3eb (patch)
tree: cb180273937502b6de0f3a40c17a2e4c78d05685 /mllib
parent: d249636e59fabd8ca57a47dc2cbad9c4a4e7a750 (diff)
download: spark-d4d762f275749a923356cd84de549b14c22cc3eb.tar.gz
spark-d4d762f275749a923356cd84de549b14c22cc3eb.tar.bz2
spark-d4d762f275749a923356cd84de549b14c22cc3eb.zip
[SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.
The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest.

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits:

6591dc6 [Ram Sriharsha] add documentation for params
b7024b1 [Ram Sriharsha] cleanup
f0e2bfb [Ram Sriharsha] merge with master
108d3d7 [Ram Sriharsha] merge with master
4f74126 [Ram Sriharsha] Allow label/ features columns to be configurable
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala 17
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala 24
2 files changed, 40 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index ea757c5e40..1741f19dc9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
/**
* param for the base binary classifier that we reduce multiclass classification into.
+ * The base classifier input and output columns are ignored in favor of
+ * the ones specified in [[OneVsRest]].
* @group param
*/
val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
@@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String)
set(classifier, value.asInstanceOf[ClassifierType])
}
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}
@@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String)
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val classifier = getClassifier
- classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+ val paramMap = new ParamMap()
+ paramMap.put(classifier.labelCol -> labelColName)
+ paramMap.put(classifier.featuresCol -> getFeaturesCol)
+ paramMap.put(classifier.predictionCol -> getPredictionCol)
+ classifier.fit(trainingDataset, paramMap)
}.toArray[ClassificationModel[_, _]]
if (handlePersistence) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 75cf5bd4ea..3775292f6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
ova.fit(datasetWithLabelMetadata)
}
+ test("SPARK-8092: ensure label features and prediction cols are configurable") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("indexed")
+
+ val indexedDataset = labelIndexer
+ .fit(dataset)
+ .transform(dataset)
+ .drop("label")
+ .withColumnRenamed("features", "f")
+
+ val ova = new OneVsRest()
+ ova.setClassifier(new LogisticRegression())
+ .setLabelCol(labelIndexer.getOutputCol)
+ .setFeaturesCol("f")
+ .setPredictionCol("p")
+
+ val ovaModel = ova.fit(indexedDataset)
+ val transformedDataset = ovaModel.transform(indexedDataset)
+ val outputFields = transformedDataset.schema.fieldNames.toSet
+ assert(outputFields.contains("p"))
+ }
+
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
val logReg = new LogisticRegression()
.setMaxIter(1)