aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala45
-rw-r--r--project/MimaExcludes.scala3
2 files changed, 30 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 0f7c6e8bc0..706ce78f26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
- * @param selectedFeatures list of indices to select (filter).
+ * @param selectedFeatures list of indices to select (filter). Must be sorted in ascending order.
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
+ require(isSorted(selectedFeatures), "Array has to be sorted asc")
+
+ /** True iff `array` is in non-decreasing order; empty and one-element arrays qualify. */
+ protected def isSorted(array: Array[Int]): Boolean = {
+   var pos = 1
+   // Walk forward while each element is >= its predecessor; halt at the first descending pair.
+   while (pos < array.length && array(pos) >= array(pos - 1)) {
+     pos += 1
+   }
+   pos >= array.length
+ }
+
/**
* Applies transformation on a vector.
*
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
- * @param filterIndices indices of features to filter
+ * @param filterIndices indices of features to filter, must be sorted in ascending order
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
- val orderedIndices = filterIndices.sorted
features match {
case SparseVector(size, indices, values) =>
- val newSize = orderedIndices.length
+ val newSize = filterIndices.length
val newValues = new ArrayBuilder.ofDouble
val newIndices = new ArrayBuilder.ofInt
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
- while (i < indices.length && j < orderedIndices.length) {
+ while (i < indices.length && j < filterIndices.length) {
indicesIdx = indices(i)
- filterIndicesIdx = orderedIndices(j)
+ filterIndicesIdx = filterIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
- Vectors.dense(orderedIndices.map(i => values(i)))
+ Vectors.dense(filterIndices.map(i => values(i)))
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val chiSqTestResult = Statistics.chiSqTest(data)
- .zipWithIndex.sortBy { case (res, _) => -res.statistic }
val features = selectorType match {
- case ChiSqSelector.KBest => chiSqTestResult
- .take(numTopFeatures)
- case ChiSqSelector.Percentile => chiSqTestResult
- .take((chiSqTestResult.length * percentile).toInt)
- case ChiSqSelector.FPR => chiSqTestResult
- .filter{ case (res, _) => res.pValue < alpha }
+ case ChiSqSelector.KBest =>
+ chiSqTestResult.zipWithIndex
+ .sortBy { case (res, _) => -res.statistic }
+ .take(numTopFeatures)
+ case ChiSqSelector.Percentile =>
+ chiSqTestResult.zipWithIndex
+ .sortBy { case (res, _) => -res.statistic }
+ .take((chiSqTestResult.length * percentile).toInt)
+ case ChiSqSelector.FPR =>
+ chiSqTestResult.zipWithIndex
+ .filter{ case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
- val indices = features.map { case (_, indices) => indices }
+ val indices = features.map { case (_, indices) => indices }.sorted
new ChiSqSelectorModel(indices)
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8024fbd21b..4db3edb733 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -818,9 +818,6 @@ object MimaExcludes {
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
) ++ Seq(
- // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
- ) ++ Seq(
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
) ++ Seq(