aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-11-01 17:00:00 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-11-01 17:00:00 -0700
commit91c33a0ca5c8287f710076ed7681e5aa13ca068f (patch)
treeea3e24b067e3b7ba1f340f0ed7906c80a64a36bd /mllib
parentb929537b6eb0f8f34497c3dbceea8045bf5dffdb (diff)
downloadspark-91c33a0ca5c8287f710076ed7681e5aa13ca068f.tar.gz
spark-91c33a0ca5c8287f710076ed7681e5aa13ca068f.tar.bz2
spark-91c33a0ca5c8287f710076ed7681e5aa13ca068f.zip
[SPARK-18088][ML] Various ChiSqSelector cleanups
## What changes were proposed in this pull request? - Renamed kbest to numTopFeatures - Renamed alpha to fpr - Added missing Since annotations - Doc cleanups ## How was this patch tested? Added new standardized unit tests for spark.ml. Improved existing unit test coverage a bit. Author: Joseph K. Bradley <joseph@databricks.com> Closes #15647 from jkbradley/chisqselector-follow-ups.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala59
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala45
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala135
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala17
5 files changed, 139 insertions, 121 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index d0385e220e..653fa41124 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params
with HasFeaturesCol with HasOutputCol with HasLabelCol {
/**
- * Number of features that selector will select (ordered by statistic value descending). If the
+ * Number of features that selector will select, ordered by ascending p-value. If the
* number of features is less than numTopFeatures, then this will select all features.
- * Only applicable when selectorType = "kbest".
+ * Only applicable when selectorType = "numTopFeatures".
* The default value of numTopFeatures is 50.
*
* @group param
*/
+ @Since("1.6.0")
final val numTopFeatures = new IntParam(this, "numTopFeatures",
- "Number of features that selector will select, ordered by statistics value descending. If the" +
+ "Number of features that selector will select, ordered by ascending p-value. If the" +
" number of features is < numTopFeatures, then this will select all features.",
ParamValidators.gtEq(1))
setDefault(numTopFeatures -> 50)
/** @group getParam */
+ @Since("1.6.0")
def getNumTopFeatures: Int = $(numTopFeatures)
/**
* Percentile of features that selector will select, ordered by statistics value descending.
* Only applicable when selectorType = "percentile".
* Default value is 0.1.
+ * @group param
*/
+ @Since("2.1.0")
final val percentile = new DoubleParam(this, "percentile",
- "Percentile of features that selector will select, ordered by statistics value descending.",
+ "Percentile of features that selector will select, ordered by ascending p-value.",
ParamValidators.inRange(0, 1))
setDefault(percentile -> 0.1)
/** @group getParam */
+ @Since("2.1.0")
def getPercentile: Double = $(percentile)
/**
* The highest p-value for features to be kept.
* Only applicable when selectorType = "fpr".
* Default value is 0.05.
+ * @group param
*/
- final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
+ final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.",
ParamValidators.inRange(0, 1))
- setDefault(alpha -> 0.05)
+ setDefault(fpr -> 0.05)
/** @group getParam */
- def getAlpha: Double = $(alpha)
+ def getFpr: Double = $(fpr)
/**
* The selector type of the ChisqSelector.
- * Supported options: "kbest" (default), "percentile" and "fpr".
+ * Supported options: "numTopFeatures" (default), "percentile", "fpr".
+ * @group param
*/
+ @Since("2.1.0")
final val selectorType = new Param[String](this, "selectorType",
"The selector type of the ChisqSelector. " +
- "Supported options: kbest (default), percentile and fpr.",
- ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
- setDefault(selectorType -> OldChiSqSelector.KBest)
+ "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "),
+ ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes))
+ setDefault(selectorType -> OldChiSqSelector.NumTopFeatures)
/** @group getParam */
+ @Since("2.1.0")
def getSelectorType: String = $(selectorType)
}
/**
* Chi-Squared feature selection, which selects categorical features to use for predicting a
* categorical label.
- * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
- * `kbest` chooses the `k` top features according to a chi-squared test.
- * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `fpr` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `kbest`, the default number of top features is 50.
+ * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
+ * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
+ * - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
+ * positive rate of selection.
+ * By default, the selection method is `numTopFeatures`, with the default number of top features
+ * set to 50.
*/
@Since("1.6.0")
final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -114,10 +125,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
def this() = this(Identifiable.randomUID("chiSqSelector"))
/** @group setParam */
- @Since("2.1.0")
- def setSelectorType(value: String): this.type = set(selectorType, value)
-
- /** @group setParam */
@Since("1.6.0")
def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
@@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
/** @group setParam */
@Since("2.1.0")
- def setAlpha(value: Double): this.type = set(alpha, value)
+ def setFpr(value: Double): this.type = set(fpr, value)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setSelectorType(value: String): this.type = set(selectorType, value)
/** @group setParam */
@Since("1.6.0")
@@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
.setSelectorType($(selectorType))
.setNumTopFeatures($(numTopFeatures))
.setPercentile($(percentile))
- .setAlpha($(alpha))
+ .setFpr($(fpr))
val model = selector.fit(input)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
- otherPairs.foreach { case (_, paramName: String) =>
+ val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType))
+ otherPairs.foreach { paramName: String =>
if (isSet(getParam(paramName))) {
logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 904000f50d..034e3625e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable {
selectorType: String,
numTopFeatures: Int,
percentile: Double,
- alpha: Double,
+ fpr: Double,
data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
new ChiSqSelector()
.setSelectorType(selectorType)
.setNumTopFeatures(numTopFeatures)
.setPercentile(percentile)
- .setAlpha(alpha)
+ .setFpr(fpr)
.fit(data.rdd)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index f8276de4f2..f9156b6427 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
Loader.checkSchema[Data](dataFrame.schema)
val features = dataArray.rdd.map {
- case Row(feature: Int) => (feature)
+ case Row(feature: Int) => feature
}.collect()
new ChiSqSelectorModel(features)
@@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
/**
* Creates a ChiSquared feature selector.
- * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
- * `kbest` chooses the `k` top features according to a chi-squared test.
- * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `fpr` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `kbest`, the default number of top features is 50.
+ * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
+ * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
+ * - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
+ * positive rate of selection.
+ * By default, the selection method is `numTopFeatures`, with the default number of top features
+ * set to 50.
*/
@Since("1.3.0")
class ChiSqSelector @Since("2.1.0") () extends Serializable {
var numTopFeatures: Int = 50
var percentile: Double = 0.1
- var alpha: Double = 0.05
- var selectorType = ChiSqSelector.KBest
+ var fpr: Double = 0.05
+ var selectorType = ChiSqSelector.NumTopFeatures
/**
* The is the same to call this() and setNumTopFeatures(numTopFeatures)
@@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
}
@Since("2.1.0")
- def setAlpha(value: Double): this.type = {
- require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
- alpha = value
+ def setFpr(value: Double): this.type = {
+ require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]")
+ fpr = value
this
}
@Since("2.1.0")
def setSelectorType(value: String): this.type = {
- require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value),
+ require(ChiSqSelector.supportedSelectorTypes.contains(value),
s"ChiSqSelector Type: $value was not supported.")
selectorType = value
this
@@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
val features = selectorType match {
- case ChiSqSelector.KBest =>
+ case ChiSqSelector.NumTopFeatures =>
chiSqTestResult
.sortBy { case (res, _) => res.pValue }
.take(numTopFeatures)
@@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
chiSqTestResult
- .filter { case (res, _) => res.pValue < alpha }
+ .filter { case (res, _) => res.pValue < fpr }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
@@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
}
}
-@Since("2.1.0")
-object ChiSqSelector {
+private[spark] object ChiSqSelector {
- /** String name for `kbest` selector type. */
- private[spark] val KBest: String = "kbest"
+ /** String name for `numTopFeatures` selector type. */
+ val NumTopFeatures: String = "numTopFeatures"
/** String name for `percentile` selector type. */
- private[spark] val Percentile: String = "percentile"
+ val Percentile: String = "percentile"
/** String name for `fpr` selector type. */
private[spark] val FPR: String = "fpr"
- /** Set of selector type and param pairs that ChiSqSelector supports. */
- private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures",
- Percentile -> "percentile", FPR -> "alpha")
-
/** Set of selector types that ChiSqSelector supports. */
- private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1)
+ val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 6af06d82d6..80970fd744 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -19,85 +19,72 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.feature
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Dataset, Row}
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
- test("Test Chi-Square selector") {
- import testImplicits._
- val data = Seq(
- LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
- LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
- LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
- LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
- )
+ @transient var dataset: Dataset[_] = _
- val preFilteredData = Seq(
- Vectors.dense(8.0),
- Vectors.dense(0.0),
- Vectors.dense(0.0),
- Vectors.dense(8.0)
- )
+ override def beforeAll(): Unit = {
+ super.beforeAll()
- val df = sc.parallelize(data.zip(preFilteredData))
- .map(x => (x._1.label, x._1.features, x._2))
- .toDF("label", "data", "preFilteredData")
-
- val selector = new ChiSqSelector()
- .setSelectorType("kbest")
- .setNumTopFeatures(1)
- .setFeaturesCol("data")
- .setLabelCol("label")
- .setOutputCol("filtered")
-
- selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
- case Row(vec1: Vector, vec2: Vector) =>
- assert(vec1 ~== vec2 absTol 1e-1)
- }
-
- selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
- .select("filtered", "preFilteredData").collect().foreach {
- case Row(vec1: Vector, vec2: Vector) =>
- assert(vec1 ~== vec2 absTol 1e-1)
- }
+ // Toy dataset, including the top feature for a chi-squared test.
+ // These data are chosen such that each feature's test has a distinct p-value.
+ /* To verify the results with R, run:
+ library(stats)
+ x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0)
+ x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0)
+ x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0)
+ y <- c(0.0, 1.0, 1.0, 2.0, 2.0)
+ chisq.test(x1,y)
+ chisq.test(x2,y)
+ chisq.test(x3,y)
+ */
+ dataset = spark.createDataFrame(Seq(
+ (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)),
+ (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)),
+ (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)),
+ (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)),
+ (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0))
+ )).toDF("label", "features", "topFeature")
+ }
- val preFilteredData2 = Seq(
- Vectors.dense(8.0, 7.0),
- Vectors.dense(0.0, 9.0),
- Vectors.dense(0.0, 9.0),
- Vectors.dense(8.0, 9.0)
- )
+ test("params") {
+ ParamsSuite.checkParams(new ChiSqSelector)
+ val model = new ChiSqSelectorModel("myModel",
+ new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4)))
+ ParamsSuite.checkParams(model)
+ }
- val df2 = sc.parallelize(data.zip(preFilteredData2))
- .map(x => (x._1.label, x._1.features, x._2))
- .toDF("label", "data", "preFilteredData")
+ test("Test Chi-Square selector: numTopFeatures") {
+ val selector = new ChiSqSelector()
+ .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
+ ChiSqSelectorSuite.testSelector(selector, dataset)
+ }
- selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
- .select("filtered", "preFilteredData").collect().foreach {
- case Row(vec1: Vector, vec2: Vector) =>
- assert(vec1 ~== vec2 absTol 1e-1)
- }
+ test("Test Chi-Square selector: percentile") {
+ val selector = new ChiSqSelector()
+ .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34)
+ ChiSqSelectorSuite.testSelector(selector, dataset)
}
- test("ChiSqSelector read/write") {
- val t = new ChiSqSelector()
- .setFeaturesCol("myFeaturesCol")
- .setLabelCol("myLabelCol")
- .setOutputCol("myOutputCol")
- .setNumTopFeatures(2)
- testDefaultReadWrite(t)
+ test("Test Chi-Square selector: fpr") {
+ val selector = new ChiSqSelector()
+ .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2)
+ ChiSqSelectorSuite.testSelector(selector, dataset)
}
- test("ChiSqSelectorModel read/write") {
- val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
- val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
- val newInstance = testDefaultReadWrite(instance)
- assert(newInstance.selectedFeatures === instance.selectedFeatures)
+ test("read/write") {
+ def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = {
+ assert(model.selectedFeatures === model2.selectedFeatures)
+ }
+ val nb = new ChiSqSelector
+ testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and not support other types") {
@@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
}
+
+object ChiSqSelectorSuite {
+
+ private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = {
+ selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect()
+ .foreach { case Row(vec1: Vector, vec2: Vector) =>
+ assert(vec1 ~== vec2 absTol 1e-1)
+ }
+ }
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "selectorType" -> "percentile",
+ "numTopFeatures" -> 1,
+ "percentile" -> 0.12,
+ "outputCol" -> "myOutput"
+ )
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index ac702b4b7c..77219e5006 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
val preFilteredData =
- Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
+ Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
LabeledPoint(1.0, Vectors.dense(Array(0.0))),
LabeledPoint(1.0, Vectors.dense(Array(0.0))),
LabeledPoint(2.0, Vectors.dense(Array(8.0))))
val model = new ChiSqSelector(1).fit(labeledDiscreteData)
val filteredData = labeledDiscreteData.map { lp =>
LabeledPoint(lp.label, model.transform(lp.features))
- }.collect().toSet
- assert(filteredData == preFilteredData)
+ }.collect().toSeq
+ assert(filteredData === preFilteredData)
}
- test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
+ test("ChiSqSelector by fpr transform test (sparse & dense vector)") {
val labeledDiscreteData = sc.parallelize(
Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
val preFilteredData =
- Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
+ Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(2.0, Vectors.dense(Array(9.0))))
- val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
+ val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr")
+ .setFpr(0.1).fit(labeledDiscreteData)
val filteredData = labeledDiscreteData.map { lp =>
LabeledPoint(lp.label, model.transform(lp.features))
- }.collect().toSet
- assert(filteredData == preFilteredData)
+ }.collect().toSeq
+ assert(filteredData === preFilteredData)
}
test("model load / save") {