diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-11-01 17:00:00 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-11-01 17:00:00 -0700 |
commit | 91c33a0ca5c8287f710076ed7681e5aa13ca068f (patch) | |
tree | ea3e24b067e3b7ba1f340f0ed7906c80a64a36bd /mllib | |
parent | b929537b6eb0f8f34497c3dbceea8045bf5dffdb (diff) | |
download | spark-91c33a0ca5c8287f710076ed7681e5aa13ca068f.tar.gz spark-91c33a0ca5c8287f710076ed7681e5aa13ca068f.tar.bz2 spark-91c33a0ca5c8287f710076ed7681e5aa13ca068f.zip |
[SPARK-18088][ML] Various ChiSqSelector cleanups
## What changes were proposed in this pull request?
- Renamed kbest to numTopFeatures
- Renamed alpha to fpr
- Added missing Since annotations
- Doc cleanups
## How was this patch tested?
Added new standardized unit tests for spark.ml.
Improved existing unit test coverage a bit.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #15647 from jkbradley/chisqselector-follow-ups.
Diffstat (limited to 'mllib')
5 files changed, 139 insertions, 121 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index d0385e220e..653fa41124 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol { /** - * Number of features that selector will select (ordered by statistic value descending). If the + * Number of features that selector will select, ordered by ascending p-value. If the * number of features is less than numTopFeatures, then this will select all features. - * Only applicable when selectorType = "kbest". + * Only applicable when selectorType = "numTopFeatures". * The default value of numTopFeatures is 50. * * @group param */ + @Since("1.6.0") final val numTopFeatures = new IntParam(this, "numTopFeatures", - "Number of features that selector will select, ordered by statistics value descending. If the" + + "Number of features that selector will select, ordered by ascending p-value. If the" + " number of features is < numTopFeatures, then this will select all features.", ParamValidators.gtEq(1)) setDefault(numTopFeatures -> 50) /** @group getParam */ + @Since("1.6.0") def getNumTopFeatures: Int = $(numTopFeatures) /** * Percentile of features that selector will select, ordered by statistics value descending. * Only applicable when selectorType = "percentile". * Default value is 0.1. + * @group param */ + @Since("2.1.0") final val percentile = new DoubleParam(this, "percentile", - "Percentile of features that selector will select, ordered by statistics value descending.", + "Percentile of features that selector will select, ordered by ascending p-value.", ParamValidators.inRange(0, 1)) setDefault(percentile -> 0.1) /** @group getParam */ + @Since("2.1.0") def getPercentile: Double = $(percentile) /** * The highest p-value for features to be kept. * Only applicable when selectorType = "fpr". * Default value is 0.05. + * @group param */ - final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", + final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) - setDefault(alpha -> 0.05) + setDefault(fpr -> 0.05) /** @group getParam */ - def getAlpha: Double = $(alpha) + def getFpr: Double = $(fpr) /** * The selector type of the ChisqSelector. - * Supported options: "kbest" (default), "percentile" and "fpr". + * Supported options: "numTopFeatures" (default), "percentile", "fpr". + * @group param */ + @Since("2.1.0") final val selectorType = new Param[String](this, "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", - ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) - setDefault(selectorType -> OldChiSqSelector.KBest) + "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "), + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes)) + setDefault(selectorType -> OldChiSqSelector.NumTopFeatures) /** @group getParam */ + @Since("2.1.0") def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -114,10 +125,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str def this() = this(Identifiable.randomUID("chiSqSelector")) /** @group setParam */ - @Since("2.1.0") - def setSelectorType(value: String): this.type = set(selectorType, value) - - /** @group setParam */ @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) @@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = set(alpha, value) + def setFpr(value: Double): this.type = set(fpr, value) + + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) /** @group setParam */ @Since("1.6.0") @@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str .setSelectorType($(selectorType)) .setNumTopFeatures($(numTopFeatures)) .setPercentile($(percentile)) - .setAlpha($(alpha)) + .setFpr($(fpr)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) - otherPairs.foreach { case (_, paramName: String) => + val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType)) + otherPairs.foreach { paramName: String => if (isSet(getParam(paramName))) { logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 904000f50d..034e3625e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable { selectorType: String, numTopFeatures: Int, percentile: Double, - alpha: Double, + fpr: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { new ChiSqSelector() .setSelectorType(selectorType) .setNumTopFeatures(numTopFeatures) .setPercentile(percentile) - .setAlpha(alpha) + .setFpr(fpr) .fit(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f8276de4f2..f9156b6427 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { Loader.checkSchema[Data](dataFrame.schema) val features = dataArray.rdd.map { - case Row(feature: Int) => (feature) + case Row(feature: Int) => feature }.collect() new ChiSqSelectorModel(features) @@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 - var alpha: Double = 0.05 - var selectorType = ChiSqSelector.KBest + var fpr: Double = 0.05 + var selectorType = ChiSqSelector.NumTopFeatures /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } @Since("2.1.0") - def setAlpha(value: Double): this.type = { - require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") - alpha = value + def setFpr(value: Double): this.type = { + require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]") + fpr = value this } @Since("2.1.0") def setSelectorType(value: String): this.type = { - require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + require(ChiSqSelector.supportedSelectorTypes.contains(value), s"ChiSqSelector Type: $value was not supported.") selectorType = value this @@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { - case ChiSqSelector.KBest => + case ChiSqSelector.NumTopFeatures => chiSqTestResult .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) @@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult - .filter { case (res, _) => res.pValue < alpha } + .filter { case (res, _) => res.pValue < fpr } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } @@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } -@Since("2.1.0") -object ChiSqSelector { +private[spark] object ChiSqSelector { - /** String name for `kbest` selector type. */ - private[spark] val KBest: String = "kbest" + /** String name for `numTopFeatures` selector type. */ + val NumTopFeatures: String = "numTopFeatures" /** String name for `percentile` selector type. */ - private[spark] val Percentile: String = "percentile" + val Percentile: String = "percentile" /** String name for `fpr` selector type. */ private[spark] val FPR: String = "fpr" - /** Set of selector type and param pairs that ChiSqSelector supports. */ - private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", - Percentile -> "percentile", FPR -> "alpha") - /** Set of selector types that ChiSqSelector supports. */ - private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) + val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6af06d82d6..80970fd744 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -19,85 +19,72 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("Test Chi-Square selector") { - import testImplicits._ - val data = Seq( - LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), - LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), - LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), - LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) - ) + @transient var dataset: Dataset[_] = _ - val preFilteredData = Seq( - Vectors.dense(8.0), - Vectors.dense(0.0), - Vectors.dense(0.0), - Vectors.dense(8.0) - ) + override def beforeAll(): Unit = { + super.beforeAll() - val df = sc.parallelize(data.zip(preFilteredData)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") - - val selector = new ChiSqSelector() - .setSelectorType("kbest") - .setNumTopFeatures(1) - .setFeaturesCol("data") - .setLabelCol("label") - .setOutputCol("filtered") - - selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } - - selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + // Toy dataset, including the top feature for a chi-squared test. + // These data are chosen such that each feature's test has a distinct p-value. + /* To verify the results with R, run: + library(stats) + x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0) + x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0) + x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0) + y <- c(0.0, 1.0, 1.0, 2.0, 2.0) + chisq.test(x1,y) + chisq.test(x2,y) + chisq.test(x3,y) + */ + dataset = spark.createDataFrame(Seq( + (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)), + (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)), + (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)), + (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)), + (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0)) + )).toDF("label", "features", "topFeature") + } - val preFilteredData2 = Seq( - Vectors.dense(8.0, 7.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(8.0, 9.0) - ) + test("params") { + ParamsSuite.checkParams(new ChiSqSelector) + val model = new ChiSqSelectorModel("myModel", + new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4))) + ParamsSuite.checkParams(model) + } - val df2 = sc.parallelize(data.zip(preFilteredData2)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") + test("Test Chi-Square selector: numTopFeatures") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) + ChiSqSelectorSuite.testSelector(selector, dataset) + } - selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + test("Test Chi-Square selector: percentile") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelector read/write") { - val t = new ChiSqSelector() - .setFeaturesCol("myFeaturesCol") - .setLabelCol("myLabelCol") - .setOutputCol("myOutputCol") - .setNumTopFeatures(2) - testDefaultReadWrite(t) + test("Test Chi-Square selector: fpr") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelectorModel read/write") { - val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) - val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) - val newInstance = testDefaultReadWrite(instance) - assert(newInstance.selectedFeatures === instance.selectedFeatures) + test("read/write") { + def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = { + assert(model.selectedFeatures === model2.selectedFeatures) + } + val nb = new ChiSqSelector + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { @@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext } } } + +object ChiSqSelectorSuite { + + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { + selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + .foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "selectorType" -> "percentile", + "numTopFeatures" -> 1, + "percentile" -> 0.12, + "outputCol" -> "myOutput" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ac702b4b7c..77219e5006 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(0.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } - test("ChiSqSelector by FPR transform test (sparse & dense vector)") { + test("ChiSqSelector by fpr transform test (sparse & dense vector)") { val labeledDiscreteData = sc.parallelize( Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))), LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(1.0, Vectors.dense(Array(4.0))), LabeledPoint(2.0, Vectors.dense(Array(9.0)))) - val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData) + val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr") + .setFpr(0.1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) - }.collect().toSet - assert(filteredData == preFilteredData) + }.collect().toSeq + assert(filteredData === preFilteredData) } test("model load / save") { |