diff options
3 files changed, 47 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 68f3867ba6..9d6de9b6e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -30,7 +30,7 @@ object Statistics { /** * Compute the Pearson correlation matrix for the input RDD of Vectors. - * Returns NaN if either vector has 0 variance. + * Columns with 0 covariance produce NaN entries in the correlation matrix. * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. @@ -39,7 +39,7 @@ object Statistics { /** * Compute the correlation matrix for the input RDD of Vectors using the specified method. - * Methods currently supported: `pearson` (default), `spearman` + * Methods currently supported: `pearson` (default), `spearman`. * * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], @@ -55,20 +55,26 @@ object Statistics { /** * Compute the Pearson correlation for the input RDDs. - * Columns with 0 covariance produce NaN entries in the correlation matrix. + * Returns NaN if either vector has 0 variance. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** * Compute the correlation for the input RDDs using the specified method. - * Methods currently supported: pearson (default), spearman + * Methods currently supported: `pearson` (default), `spearman`. + * + * Note: the two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. * - * @param x RDD[Double] of the same cardinality as y - * @param y RDD[Double] of the same cardinality as x + * @param x RDD[Double] of the same cardinality as y. + * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` *@return A Double containing the correlation between the two input RDD[Double]s using the diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 1f7de630e7..9bd0c2cd05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -89,20 +89,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => // add an extra element to signify the end of the list so that flatMap can flush the last // batch of duplicates - val padded = iter ++ - Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) - var lastVal = 0.0 - var firstRank = 0.0 - val idBuffer = new ArrayBuffer[Long]() + val end = -1L + val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) + val firstEntry = padded.next() + var lastVal = firstEntry._1._1 + var firstRank = firstEntry._2.toDouble + val idBuffer = ArrayBuffer(firstEntry._1._2) padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != Long.MinValue) { + if (v == lastVal && id != end) { idBuffer += id Iterator.empty } else { - val entries = if (idBuffer.size == 0) { - // edge case for the first value matching the initial value of lastVal - Iterator.empty - } else if (idBuffer.size == 1) { + val entries = if (idBuffer.size == 1) { Iterator((idBuffer(0), firstRank)) } else { val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index bce4251426..a3f76f77a5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -31,6 +31,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) val yData = Array(4.0, 5.0, 3.0) + val zeros = new Array[Double](3) val data = Seq( Vectors.dense(1.0, 0.0, 0.0, -2.0), Vectors.dense(4.0, 5.0, 0.0, 3.0), @@ -46,6 +47,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val p1 = Statistics.corr(x, y, "pearson") assert(approxEqual(expected, default)) assert(approxEqual(expected, p1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val p2 = Statistics.corr(x1, y1) + assert(approxEqual(expected, p2)) + } + + // RDD of zero variance + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z).isNaN()) } test("corr(x, y) spearman") { @@ -54,6 +67,18 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { val expected = 0.5 val s1 = Statistics.corr(x, y, "spearman") assert(approxEqual(expected, s1)) + + // numPartitions >= size for input RDDs + for (numParts <- List(xData.size, xData.size * 2)) { + val x1 = sc.parallelize(xData, numParts) + val y1 = sc.parallelize(yData, numParts) + val s2 = Statistics.corr(x1, y1, "spearman") + assert(approxEqual(expected, s2)) + } + + // RDD of zero variance => zero variance in ranks + val z = sc.parallelize(zeros) + assert(Statistics.corr(x, z, "spearman").isNaN()) } test("corr(X) default, pearson") { |