aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala11
1 files changed, 9 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 3558290b23..e0293dbc4b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -49,16 +49,23 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
.map(x => (x._1.label, x._1.features, x._2))
.toDF("label", "data", "preFilteredData")
- val model = new ChiSqSelector()
+ val selector = new ChiSqSelector()
.setNumTopFeatures(1)
.setFeaturesCol("data")
.setLabelCol("label")
.setOutputCol("filtered")
- model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
+ selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}
+
+ selector.setPercentile(0.34).fit(df).transform(df)
+ .select("filtered", "preFilteredData").collect().foreach {
+ case Row(vec1: Vector, vec2: Vector) =>
+ assert(vec1 ~== vec2 absTol 1e-1)
+ }
+
}
test("ChiSqSelector read/write") {