diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 3558290b23..e0293dbc4b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -49,16 +49,23 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext .map(x => (x._1.label, x._1.features, x._2)) .toDF("label", "data", "preFilteredData") - val model = new ChiSqSelector() + val selector = new ChiSqSelector() .setNumTopFeatures(1) .setFeaturesCol("data") .setLabelCol("label") .setOutputCol("filtered") - model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { + selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 ~== vec2 absTol 1e-1) } + + selector.setPercentile(0.34).fit(df).transform(df) + .select("filtered", "preFilteredData").collect().foreach { + case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } test("ChiSqSelector read/write") { |