From d4d762f275749a923356cd84de549b14c22cc3eb Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 23 Jul 2015 22:35:41 -0700 Subject: [SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable. The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest. Author: Ram Sriharsha Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits: 6591dc6 [Ram Sriharsha] add documentation for params b7024b1 [Ram Sriharsha] cleanup f0e2bfb [Ram Sriharsha] merge with master 108d3d7 [Ram Sriharsha] merge with master 4f74126 [Ram Sriharsha] Allow label/ features columns to be configurable --- .../apache/spark/ml/classification/OneVsRest.scala | 17 ++++++++++++++- .../spark/ml/classification/OneVsRestSuite.scala | 24 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index ea757c5e40..1741f19dc9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams { /** * param for the base binary classifier that we reduce multiclass classification into. + * The base classifier input and output columns are ignored in favor of + * the ones specified in [[OneVsRest]]. * @group param */ val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") @@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String) set(classifier, value.asInstanceOf[ClassifierType]) } + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } @@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String) val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier - classifier.fit(trainingDataset, classifier.labelCol -> labelColName) + val paramMap = new ParamMap() + paramMap.put(classifier.labelCol -> labelColName) + paramMap.put(classifier.featuresCol -> getFeaturesCol) + paramMap.put(classifier.predictionCol -> getPredictionCol) + classifier.fit(trainingDataset, paramMap) }.toArray[ClassificationModel[_, _]] if (handlePersistence) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 75cf5bd4ea..3775292f6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { ova.fit(datasetWithLabelMetadata) } + test("SPARK-8092: ensure label features and prediction cols are configurable") { + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexed") + + val indexedDataset = labelIndexer + .fit(dataset) + .transform(dataset) + .drop("label") + .withColumnRenamed("features", "f") + + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression()) + .setLabelCol(labelIndexer.getOutputCol) + .setFeaturesCol("f") + .setPredictionCol("p") + + val ovaModel = ova.fit(indexedDataset) + val transformedDataset = ovaModel.transform(indexedDataset) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields.contains("p")) + } + test("SPARK-8049: OneVsRest shouldn't output temp columns") { val logReg = new LogisticRegression() .setMaxIter(1) -- cgit v1.2.3