aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2016-10-01 16:10:39 -0400
committerSean Owen <sowen@cloudera.com>2016-10-01 16:10:39 -0400
commitb88cb63da39786c07cb4bfa70afed32ec5eb3286 (patch)
tree1fcb0a85213238e47d0c98e2d79f792a75bad13e
parentaf6ece33d39cf305bd4a211d08a2f8e910c69bc1 (diff)
downloadspark-b88cb63da39786c07cb4bfa70afed32ec5eb3286.tar.gz
spark-b88cb63da39786c07cb4bfa70afed32ec5eb3286.tar.bz2
spark-b88cb63da39786c07cb4bfa70afed32ec5eb3286.zip
[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.
## What changes were proposed in this pull request? Partial revert of #15277 to instead sort and store input to model rather than require sorted input ## How was this patch tested? Existing tests. Author: Sean Owen <sowen@cloudera.com> Closes #15299 from srowen/SPARK-17704.2.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala22
-rwxr-xr-xpython/pyspark/ml/feature.py2
3 files changed, 13 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 9c131a4185..d0385e220e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (
import ChiSqSelectorModel._
- /** list of indices to select (filter). Must be ordered asc */
+ /** list of indices to select (filter). */
@Since("1.6.0")
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 706ce78f26..c305b36278 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
- * @param selectedFeatures list of indices to select (filter). Must be ordered asc
+ * @param selectedFeatures list of indices to select (filter).
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
- require(isSorted(selectedFeatures), "Array has to be sorted asc")
+ private val filterIndices = selectedFeatures.sorted
+ @deprecated("not intended for subclasses to use", "2.1.0")
protected def isSorted(array: Array[Int]): Boolean = {
var i = 1
val len = array.length
@@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
*/
@Since("1.3.0")
override def transform(vector: Vector): Vector = {
- compress(vector, selectedFeatures)
+ compress(vector)
}
/**
@@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
- * @param filterIndices indices of features to filter, must be ordered asc
*/
- private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
+ private def compress(features: Vector): Vector = {
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
@@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
*/
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
- val chiSqTestResult = Statistics.chiSqTest(data)
+ val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
val features = selectorType match {
case ChiSqSelector.KBest =>
- chiSqTestResult.zipWithIndex
+ chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
case ChiSqSelector.Percentile =>
- chiSqTestResult.zipWithIndex
+ chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
- chiSqTestResult.zipWithIndex
- .filter{ case (res, _) => res.pValue < alpha }
+ chiSqTestResult
+ .filter { case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
- val indices = features.map { case (_, indices) => indices }.sorted
+ val indices = features.map { case (_, index) => index }
new ChiSqSelectorModel(indices)
}
}
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 12a13849dc..64b21caa61 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
@since("2.0.0")
def selectedFeatures(self):
"""
- List of indices to select (filter). Must be ordered asc.
+ List of indices to select (filter).
"""
return self._call_java("selectedFeatures")