aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-08-17 20:53:18 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-17 20:53:18 -0700
commitc77f40668fbb5b8bca9a9b25c039895cb7a4a80c (patch)
tree19e891a3e28fca074a9c438281340cbf69fe5e03 /mllib/src/main
parent95470a03ae85d7d37d75f73435425a0e22918bc9 (diff)
downloadspark-c77f40668fbb5b8bca9a9b25c039895cb7a4a80c.tar.gz
spark-c77f40668fbb5b8bca9a9b25c039895cb7a4a80c.tar.bz2
spark-c77f40668fbb5b8bca9a9b25c039895cb7a4a80c.zip
[SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for number of distinct values
There is a bug determining the column index. dorx Author: Xiangrui Meng <meng@databricks.com> Closes #1997 from mengxr/chisq-index and squashes the following commits: 8fc2ab2 [Xiangrui Meng] fix col indexing bug and add a check for number of distinct values
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala37
2 files changed, 31 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index 3cf1028fbc..3cf4e807b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -155,7 +155,7 @@ object Statistics {
* :: Experimental ::
* Conduct Pearson's independence test for every feature against the label across the input RDD.
* For each feature, the (feature, label) pairs are converted into a contingency matrix for which
- * the chi-squared statistic is computed.
+ * the chi-squared statistic is computed. All label and feature values must be categorical.
*
* @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
* Real-valued features will be treated as categorical for each distinct value.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index 215de95db5..0089419c2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test
import breeze.linalg.{DenseMatrix => BDM}
import cern.jet.stat.Probability.chiSquareComplemented
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
+import scala.collection.mutable
+
/**
* Conduct the chi-squared test for the input RDDs using the specified method.
* Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted
@@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging {
*/
def chiSquaredFeatures(data: RDD[LabeledPoint],
methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
+ val maxCategories = 10000
val numCols = data.first().features.size
val results = new Array[ChiSqTestResult](numCols)
var labels: Map[Double, Int] = null
- // At most 100 columns at a time
- val batchSize = 100
+ // at most 1000 columns at a time
+ val batchSize = 1000
var batch = 0
while (batch * batchSize < numCols) {
// The following block of code can be cleaned up and made public as
// chiSquared(data: RDD[(V1, V2)])
val startCol = batch * batchSize
val endCol = startCol + math.min(batchSize, numCols - startCol)
- val pairCounts = data.flatMap { p =>
- // assume dense vectors
- p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) =>
- (col, feature, p.label)
+ val pairCounts = data.mapPartitions { iter =>
+ val distinctLabels = mutable.HashSet.empty[Double]
+ val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] =
+ Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*)
+ var i = 1
+ iter.flatMap { case LabeledPoint(label, features) =>
+ if (i % 1000 == 0) {
+ if (distinctLabels.size > maxCategories) {
+ throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ + s"found more than $maxCategories distinct label values.")
+ }
+ allDistinctFeatures.foreach { case (col, distinctFeatures) =>
+ if (distinctFeatures.size > maxCategories) {
+ throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ + s"found more than $maxCategories distinct values in column $col.")
+ }
+ }
+ }
+ i += 1
+ distinctLabels += label
+ features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) =>
+ allDistinctFeatures(col) += feature
+ (col, feature, label)
+ }
}
}.countByValue()