aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala | 37
1 files changed, 28 insertions, 9 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index 5bd0521298..6de3840b3f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -17,8 +17,11 @@
package org.apache.spark.mllib.stat
+import java.util.Random
+
import org.scalatest.FunSuite
+import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.test.ChiSqTest
@@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
// labels: 1.0 (2 / 6), 0.0 (4 / 6)
// feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6)
// feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6)
- val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
- new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
- new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
- new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
- new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
- new LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
+ val data = Seq(
+ LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
+ LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
+ LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
+ LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
+ LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
+ LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
for (numParts <- List(2, 4, 6, 8)) {
val chi = Statistics.chiSqTest(sc.parallelize(data, numParts))
val feature1 = chi(0)
@@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
}
// Test that the right number of results is returned
- val numCols = 321
- val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
- new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0)))))
+ val numCols = 1001
+ val sparseData = Array(
+ new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
+ new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0)))))
val chi = Statistics.chiSqTest(sc.parallelize(sparseData))
assert(chi.size === numCols)
+ assert(chi(1000) != null) // SPARK-3087
+
+ // Detect continuous features or labels
+ val random = new Random(11L)
+ val continuousLabel =
+ Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
+ intercept[SparkException] {
+ Statistics.chiSqTest(sc.parallelize(continuousLabel, 2))
+ }
+ val continuousFeature =
+ Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
+ intercept[SparkException] {
+ Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
+ }
}
}