aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala | 22
1 files changed, 15 insertions, 7 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index eaa819c2e6..700f803490 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -22,6 +22,7 @@ import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.random.RandomRDDs
import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
SpearmanCorrelation}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -42,10 +43,10 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
test("corr(x, y) pearson, 1 value in data") {
val x = sc.parallelize(Array(1.0))
val y = sc.parallelize(Array(4.0))
- intercept[RuntimeException] {
+ intercept[IllegalArgumentException] {
Statistics.corr(x, y, "pearson")
}
- intercept[RuntimeException] {
+ intercept[IllegalArgumentException] {
Statistics.corr(x, y, "spearman")
}
}
@@ -127,15 +128,22 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
assert(Correlations.getCorrelationFromName("pearson") === pearson)
assert(Correlations.getCorrelationFromName("spearman") === spearman)
- // Should throw IllegalArgumentException
- try {
+ intercept[IllegalArgumentException] {
Correlations.getCorrelationFromName("kendall")
- assert(false)
- } catch {
- case ie: IllegalArgumentException =>
}
}
+ ignore("Pearson correlation of very large uncorrelated values (SPARK-14533)") {
+ // The two RDDs are random, so their correlation should be approximately 0;
+ // shifting both by any constant amount should leave that correlation unchanged.
+ // In practice a large shift produces very large values, which can reveal
+ // round-off problems.
+ val a = RandomRDDs.normalRDD(sc, 100000, 10).map(_ + 1000000000.0)
+ val b = RandomRDDs.normalRDD(sc, 100000, 10).map(_ + 1000000000.0)
+ val p = Statistics.corr(a, b, method = "pearson")
+ assert(approxEqual(p, 0.0, 0.01))
+ }
+
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
if (v1.isNaN) {
v2.isNaN