aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-12-05 00:32:58 -0800
committerYanbo Liang <ybliang8@gmail.com>2016-12-05 00:32:58 -0800
commitbdfe7f67468ecfd9927a1fec60d6605dd05ebe3f (patch)
treef800250b34c14ae4021ed8fb5f3e3c952f7c6e48 /mllib
parente9730b707ddf6e344de3b3b8f43487f7b0f18e25 (diff)
downloadspark-bdfe7f67468ecfd9927a1fec60d6605dd05ebe3f.tar.gz
spark-bdfe7f67468ecfd9927a1fec60d6605dd05ebe3f.tar.bz2
spark-bdfe7f67468ecfd9927a1fec60d6605dd05ebe3f.zip
[SPARK-18625][ML] OneVsRestModel should support setFeaturesCol and setPredictionCol
## What changes were proposed in this pull request? add `setFeaturesCol` and `setPredictionCol` for `OneVsRestModel` ## How was this patch tested? added tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #16059 from zhengruifeng/ovrm_setCol.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala14
2 files changed, 22 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index f4ab0a074c..e58b30d665 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -140,6 +140,14 @@ final class OneVsRestModel private[ml] (
this(uid, Metadata.empty, models.asScala.toArray)
}
+ /** @group setParam */
+ @Since("2.1.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -175,6 +183,7 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}
+ model.setFeaturesCol($(featuresCol))
val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
.withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3f9bcec427..aacb7921b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -136,6 +137,17 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(outputFields.contains("p"))
}
+ test("SPARK-18625 : OneVsRestModel should support setFeaturesCol and setPredictionCol") {
+ val ova = new OneVsRest().setClassifier(new LogisticRegression)
+ val ovaModel = ova.fit(dataset)
+ val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
+ ovaModel.setFeaturesCol("fea")
+ ovaModel.setPredictionCol("pred")
+ val transformedDataset = ovaModel.transform(dataset2)
+ val outputFields = transformedDataset.schema.fieldNames.toSet
+ assert(outputFields === Set("y", "fea", "pred"))
+ }
+
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
val logReg = new LogisticRegression()
.setMaxIter(1)