aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-09-29 04:30:42 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-09-29 04:30:42 -0700
commitf7082ac12518ae84d6d1d4b7330a9f12cf95e7c1 (patch)
treec657915e4a09298fb6e8ca77d127bbfd3f7c35e3
parenta19a1bb59411177caaf99581e89098826b7d0c7b (diff)
downloadspark-f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1.tar.gz
spark-f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1.tar.bz2
spark-f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1.zip
[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.
## What changes were proposed in this pull request? Several performance improvement for ```ChiSqSelector```: 1, Keep ```selectedFeatures``` ordered ascendent. ```ChiSqSelectorModel.transform``` need ```selectedFeatures``` ordered to make prediction. We should sort it when training model rather than making prediction, since users usually train model once and use the model to do prediction multiple times. 2, When training ```fpr``` type ```ChiSqSelectorModel```, it's not necessary to sort the ChiSq test result by statistic. ## How was this patch tested? Existing unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #15277 from yanboliang/spark-17704.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala45
-rw-r--r--project/MimaExcludes.scala3
2 files changed, 30 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 0f7c6e8bc0..706ce78f26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
- * @param selectedFeatures list of indices to select (filter).
+ * @param selectedFeatures list of indices to select (filter). Must be ordered asc
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
+ require(isSorted(selectedFeatures), "Array has to be sorted asc")
+
+ protected def isSorted(array: Array[Int]): Boolean = {
+ var i = 1
+ val len = array.length
+ while (i < len) {
+ if (array(i) < array(i-1)) return false
+ i += 1
+ }
+ true
+ }
+
/**
* Applies transformation on a vector.
*
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
- * @param filterIndices indices of features to filter
+ * @param filterIndices indices of features to filter, must be ordered asc
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
- val orderedIndices = filterIndices.sorted
features match {
case SparseVector(size, indices, values) =>
- val newSize = orderedIndices.length
+ val newSize = filterIndices.length
val newValues = new ArrayBuilder.ofDouble
val newIndices = new ArrayBuilder.ofInt
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
- while (i < indices.length && j < orderedIndices.length) {
+ while (i < indices.length && j < filterIndices.length) {
indicesIdx = indices(i)
- filterIndicesIdx = orderedIndices(j)
+ filterIndicesIdx = filterIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
- Vectors.dense(orderedIndices.map(i => values(i)))
+ Vectors.dense(filterIndices.map(i => values(i)))
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val chiSqTestResult = Statistics.chiSqTest(data)
- .zipWithIndex.sortBy { case (res, _) => -res.statistic }
val features = selectorType match {
- case ChiSqSelector.KBest => chiSqTestResult
- .take(numTopFeatures)
- case ChiSqSelector.Percentile => chiSqTestResult
- .take((chiSqTestResult.length * percentile).toInt)
- case ChiSqSelector.FPR => chiSqTestResult
- .filter{ case (res, _) => res.pValue < alpha }
+ case ChiSqSelector.KBest =>
+ chiSqTestResult.zipWithIndex
+ .sortBy { case (res, _) => -res.statistic }
+ .take(numTopFeatures)
+ case ChiSqSelector.Percentile =>
+ chiSqTestResult.zipWithIndex
+ .sortBy { case (res, _) => -res.statistic }
+ .take((chiSqTestResult.length * percentile).toInt)
+ case ChiSqSelector.FPR =>
+ chiSqTestResult.zipWithIndex
+ .filter{ case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
- val indices = features.map { case (_, indices) => indices }
+ val indices = features.map { case (_, indices) => indices }.sorted
new ChiSqSelectorModel(indices)
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8024fbd21b..4db3edb733 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -818,9 +818,6 @@ object MimaExcludes {
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
) ++ Seq(
- // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
- ) ++ Seq(
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
) ++ Seq(