aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: Peng <peng.meng@intel.com> 2016-10-14 12:48:57 +0100
committer: Sean Owen <sowen@cloudera.com> 2016-10-14 12:48:57 +0100
commit: c8b612decba28e51789891f7881b6d4ebc50e2bb (patch)
tree: 33a908908c1647bc1636d6c372cf381510be902e /mllib
parent: a1b136d05c6c458ae8211b0844bfc98d7693fa42 (diff)
downloadspark-c8b612decba28e51789891f7881b6d4ebc50e2bb.tar.gz
spark-c8b612decba28e51789891f7881b6d4ebc50e2bb.tar.bz2
spark-c8b612decba28e51789891f7881b6d4ebc50e2bb.zip
[SPARK-17870][MLLIB][ML] Change statistic to pValue for SelectKBest and SelectPercentile because of DoF difference
## What changes were proposed in this pull request? The feature selection method ChiSqSelector selects features based on ChiSquareTestResult.statistic (the chi-square value): it selects the features with the largest chi-square values. However, the degrees of freedom (df) of the chi-square values returned by Statistics.chiSqTest(RDD) can differ from feature to feature, and chi-square values with different df are not comparable, so they cannot be used to rank features. This change therefore sorts by pValue (smallest first) instead of by statistic for SelectKBest and SelectPercentile. ## How was this patch tested? Changed the existing tests. Author: Peng <peng.meng@intel.com> Closes #15444 from mpjlu/chisqure-bug.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala 6
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala 8
3 files changed, 9 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index c305b36278..f8276de4f2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -234,11 +234,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
val features = selectorType match {
case ChiSqSelector.KBest =>
chiSqTestResult
- .sortBy { case (res, _) => -res.statistic }
+ .sortBy { case (res, _) => res.pValue }
.take(numTopFeatures)
case ChiSqSelector.Percentile =>
chiSqTestResult
- .sortBy { case (res, _) => -res.statistic }
+ .sortBy { case (res, _) => res.pValue }
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
chiSqTestResult
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index dfebfc87ea..6af06d82d6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -38,10 +38,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
)
val preFilteredData = Seq(
- Vectors.dense(0.0),
- Vectors.dense(6.0),
Vectors.dense(8.0),
- Vectors.dense(5.0)
+ Vectors.dense(0.0),
+ Vectors.dense(0.0),
+ Vectors.dense(8.0)
)
val df = sc.parallelize(data.zip(preFilteredData))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index ec23a4aa73..ac702b4b7c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -54,10 +54,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
val preFilteredData =
- Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
- LabeledPoint(1.0, Vectors.dense(Array(6.0))),
- LabeledPoint(1.0, Vectors.dense(Array(8.0))),
- LabeledPoint(2.0, Vectors.dense(Array(5.0))))
+ Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
+ LabeledPoint(1.0, Vectors.dense(Array(0.0))),
+ LabeledPoint(1.0, Vectors.dense(Array(0.0))),
+ LabeledPoint(2.0, Vectors.dense(Array(8.0))))
val model = new ChiSqSelector(1).fit(labeledDiscreteData)
val filteredData = labeledDiscreteData.map { lp =>
LabeledPoint(lp.label, model.transform(lp.features))