aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2017-03-16 17:10:15 -0700
committerJoseph K. Bradley <joseph@databricks.com>2017-03-16 17:10:15 -0700
commit4c3200546c5c55e671988a957011417ba76a0600 (patch)
tree6d5b0aafc2fb302d0829c8da5b039e45646cf332 /mllib/src
parent2ea214dd05da929840c15891e908384cfa695ca8 (diff)
downloadspark-4c3200546c5c55e671988a957011417ba76a0600.tar.gz
spark-4c3200546c5c55e671988a957011417ba76a0600.tar.bz2
spark-4c3200546c5c55e671988a957011417ba76a0600.zip
[SPARK-19635][ML] DataFrame-based API for chi square test
## What changes were proposed in this pull request? Wrapper taking and returning a DataFrame. ## How was this patch tested? Copied unit tests from the RDD-based API. Author: Joseph K. Bradley <joseph@databricks.com> Closes #17110 from jkbradley/df-hypotests.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala81
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala98
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala11
4 files changed, 192 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala
new file mode 100644
index 0000000000..c3865ce6a9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquare.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
+import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
+import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.col
+
+
+/**
+ * :: Experimental ::
+ *
+ * Chi-square hypothesis testing for categorical data.
+ *
+ * See <a href="http://en.wikipedia.org/wiki/Chi-squared_test">Wikipedia</a> for more information
+ * on the Chi-squared test.
+ */
+@Experimental
+@Since("2.2.0")
+object ChiSquare {
+
+ /** Row type of the single-row result DataFrame built by `test` */
+ private case class ChiSquareResult(
+ pValues: Vector,
+ degreesOfFreedom: Array[Int],
+ statistics: Vector)
+
+ /**
+ * Conduct Pearson's independence test for every feature against the label in the dataset.
+ * For each feature, the (feature, label) pairs are converted into a contingency matrix for which
+ * the Chi-squared statistic is computed. All label and feature values must be categorical.
+ *
+ * The null hypothesis is that the occurrence of the outcomes is statistically independent.
+ *
+ * @param dataset DataFrame of categorical labels and categorical features.
+ * Real-valued features will be treated as categorical for each distinct value.
+ * @param featuresCol Name of features column in dataset, of type `Vector` (`VectorUDT`)
+ * @param labelCol Name of label column in dataset, of any numerical type (cast to double)
+ * @return DataFrame containing the test result for every feature against the label.
+ * This DataFrame will contain a single Row with the following fields:
+ * - `pValues: Vector`
+ * - `degreesOfFreedom: Array[Int]`
+ * - `statistics: Vector`
+ * Each of these fields has one value per feature.
+ */
+ @Since("2.2.0")
+ def test(dataset: DataFrame, featuresCol: String, labelCol: String): DataFrame = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
+ // Check column types first, then convert rows to old-API LabeledPoints for the RDD-based test.
+ SchemaUtils.checkColumnType(dataset.schema, featuresCol, new VectorUDT)
+ SchemaUtils.checkNumericType(dataset.schema, labelCol)
+ val rdd = dataset.select(col(labelCol).cast("double"), col(featuresCol)).as[(Double, Vector)]
+ .rdd.map { case (label, features) => OldLabeledPoint(label, OldVectors.fromML(features)) }
+ val testResults = OldStatistics.chiSqTest(rdd)
+ val pValues: Vector = Vectors.dense(testResults.map(_.pValue))
+ val degreesOfFreedom: Array[Int] = testResults.map(_.degreesOfFreedom)
+ val statistics: Vector = Vectors.dense(testResults.map(_.statistic))
+ spark.createDataFrame(Seq(ChiSquareResult(pValues, degreesOfFreedom, statistics)))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index 9a63b8a5d6..ee51248e53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -41,7 +41,7 @@ import org.apache.spark.rdd.RDD
*
* More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test
*/
-private[stat] object ChiSqTest extends Logging {
+private[spark] object ChiSqTest extends Logging {
/**
* @param name String name for the method.
@@ -71,6 +71,11 @@ private[stat] object ChiSqTest extends Logging {
}
/**
+ * Max number of categories when indexing labels and features
+ */
+ private[spark] val maxCategories: Int = 10000
+
+ /**
* Conduct Pearson's independence test for each feature against the label across the input RDD.
* The contingency table is constructed from the raw (feature, label) pairs and used to conduct
* the independence test.
@@ -78,7 +83,6 @@ private[stat] object ChiSqTest extends Logging {
*/
def chiSquaredFeatures(data: RDD[LabeledPoint],
methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
- val maxCategories = 10000
val numCols = data.first().features.size
val results = new Array[ChiSqTestResult](numCols)
var labels: Map[Double, Int] = null
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala
new file mode 100644
index 0000000000..b4bed82e4d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/ChiSquareSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat
+
+import java.util.Random
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.stat.test.ChiSqTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class ChiSquareSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ test("test DataFrame of labeled points") {
+ // labels: 1.0 (2 / 6), 0.0 (4 / 6)
+ // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6)
+ // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6)
+ val data = Seq(
+ LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
+ LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
+ LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
+ LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
+ LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
+ LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
+ for (numParts <- List(2, 4, 6, 8)) {
+ val df = spark.createDataFrame(sc.parallelize(data, numParts))
+ val chi = ChiSquare.test(df, "features", "label")
+ val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) =
+ chi.select("pValues", "degreesOfFreedom", "statistics")
+ .as[(Vector, Array[Int], Vector)].head()
+ assert(pValues ~== Vectors.dense(0.6873, 0.6823) relTol 1e-4)
+ assert(degreesOfFreedom === Array(2, 3))
+ assert(statistics ~== Vectors.dense(0.75, 1.5) relTol 1e-4)
+ }
+ }
+
+ test("large number of features (SPARK-3087)") {
+ // Test that one result is returned per feature column, including the last column
+ val numCols = 1001
+ val sparseData = Array(
+ LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
+ LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0)))))
+ val df = spark.createDataFrame(sparseData)
+ val chi = ChiSquare.test(df, "features", "label")
+ val (pValues: Vector, degreesOfFreedom: Array[Int], statistics: Vector) =
+ chi.select("pValues", "degreesOfFreedom", "statistics")
+ .as[(Vector, Array[Int], Vector)].head()
+ assert(pValues.size === numCols)
+ assert(degreesOfFreedom.length === numCols)
+ assert(statistics.size === numCols)
+ assert(pValues(1000) !== null) // SPARK-3087. NOTE(review): a Double element; likely always true
+ }
+
+ test("fail on continuous features or labels") {
+ val tooManyCategories: Int = 100000
+ assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " +
+ "tooManyCategories be large enough to cause ChiSqTest to throw an exception.")
+ // Seeded generator so the random labels and features are reproducible across runs.
+ val random = new Random(11L)
+ val continuousLabel = Seq.fill(tooManyCategories)(
+ LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
+ withClue("ChiSquare should throw an exception when given a continuous-valued label") {
+ intercept[SparkException] {
+ val df = spark.createDataFrame(continuousLabel)
+ ChiSquare.test(df, "features", "label")
+ }
+ }
+ val continuousFeature = Seq.fill(tooManyCategories)(
+ LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
+ withClue("ChiSquare should throw an exception when given continuous-valued features") {
+ intercept[SparkException] {
+ val df = spark.createDataFrame(continuousFeature)
+ ChiSquare.test(df, "features", "label")
+ }
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index 46fcebe132..992b876561 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -145,14 +145,17 @@ class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(chi(1000) != null) // SPARK-3087
// Detect continuous features or labels
+ val tooManyCategories: Int = 100000
+ assert(tooManyCategories > ChiSqTest.maxCategories, "This unit test requires that " +
+ "tooManyCategories be large enough to cause ChiSqTest to throw an exception.")
val random = new Random(11L)
- val continuousLabel =
- Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
+ val continuousLabel = Seq.fill(tooManyCategories)(
+ LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
intercept[SparkException] {
Statistics.chiSqTest(sc.parallelize(continuousLabel, 2))
}
- val continuousFeature =
- Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
+ val continuousFeature = Seq.fill(tooManyCategories)(
+ LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
intercept[SparkException] {
Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
}