author     Bago Amirbekian <bago@databricks.com>      2017-03-28 19:19:16 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2017-03-28 19:19:16 -0700
commit     a5c87707eaec5cacdfb703eb396dfc264bc54cda
tree       123bcc6b24d6d03553faec689b183e7d7b43a7b3
parent     7d432af8f3c47973550ea253dae0c23cd2961bde
[SPARK-20040][ML][PYTHON] pyspark wrapper for ChiSquareTest
## What changes were proposed in this pull request?

A pyspark wrapper for spark.ml.stat.ChiSquareTest.

## How was this patch tested?

unit tests
doctests

Author: Bago Amirbekian <bago@databricks.com>

Closes #17421 from MrBago/chiSquareTestWrapper.
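For context, here is a minimal usage sketch of the new Python API this patch adds (the session setup is assumed; the call mirrors the doctest included in stat.py below):

```python
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import ChiSquareTest

spark = SparkSession.builder.appName("chisq-example").getOrCreate()

# One categorical label plus a vector of categorical features per row.
df = spark.createDataFrame(
    [(0, Vectors.dense([0, 0, 1])),
     (0, Vectors.dense([1, 0, 1])),
     (1, Vectors.dense([2, 1, 1])),
     (1, Vectors.dense([3, 1, 1]))],
    ["label", "features"])

# Returns a single-row DataFrame with pValues, degreesOfFreedom, statistics.
result = ChiSquareTest.test(df, "features", "label").head()
print(result.degreesOfFreedom)  # [3, 1, 0]
```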
Diffstat:
-rw-r--r--  dev/sparktestsupport/modules.py                                      1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala    6
-rw-r--r--  python/docs/pyspark.ml.rst                                           8
-rw-r--r--  python/pyspark/ml/stat.py                                           93
-rwxr-xr-x  python/pyspark/ml/tests.py                                          31
5 files changed, 127 insertions(+), 12 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index eaf1f3a1db..246f5188a5 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -431,6 +431,7 @@ pyspark_ml = Module(
"pyspark.ml.linalg.__init__",
"pyspark.ml.recommendation",
"pyspark.ml.regression",
+ "pyspark.ml.stat",
"pyspark.ml.tuning",
"pyspark.ml.tests",
],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
index 21eba9a498..5b38ca73e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
@@ -46,9 +46,9 @@ object ChiSquareTest {
statistics: Vector)
/**
- * Conduct Pearson's independence test for every feature against the label across the input RDD.
- * For each feature, the (feature, label) pairs are converted into a contingency matrix for which
- * the Chi-squared statistic is computed. All label and feature values must be categorical.
+ * Conduct Pearson's independence test for every feature against the label. For each feature, the
+ * (feature, label) pairs are converted into a contingency matrix for which the Chi-squared
+ * statistic is computed. All label and feature values must be categorical.
*
* The null hypothesis is that the occurrence of the outcomes is statistically independent.
*
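To make the contingency-matrix wording above concrete, here is an illustrative sketch of what the test computes for a single feature (using scipy, which is not part of this patch; the counts are made up for the example):

```python
from scipy.stats import chi2_contingency

# The (feature, label) pairs for one feature are tallied into a contingency
# matrix: rows are distinct feature values, columns are distinct label values.
observed = [[2, 0],   # feature value 0: seen twice with label 0, never with label 1
            [0, 2]]   # feature value 1: never with label 0, twice with label 1

# correction=False gives the plain Pearson chi-squared statistic.
statistic, p_value, dof, _ = chi2_contingency(observed, correction=False)
print(statistic, dof, p_value)  # 4.0, 1, ~0.0455
```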
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
index a68183445d..930646de9c 100644
--- a/python/docs/pyspark.ml.rst
+++ b/python/docs/pyspark.ml.rst
@@ -65,6 +65,14 @@ pyspark.ml.regression module
:undoc-members:
:inherited-members:
+pyspark.ml.stat module
+----------------------
+
+.. automodule:: pyspark.ml.stat
+ :members:
+ :undoc-members:
+ :inherited-members:
+
pyspark.ml.tuning module
------------------------
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
new file mode 100644
index 0000000000..db043ff68f
--- /dev/null
+++ b/python/pyspark/ml/stat.py
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import since, SparkContext
+from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.wrapper import _jvm
+
+
+class ChiSquareTest(object):
+ """
+ .. note:: Experimental
+
+ Conduct Pearson's independence test for every feature against the label. For each feature,
+ the (feature, label) pairs are converted into a contingency matrix for which the Chi-squared
+ statistic is computed. All label and feature values must be categorical.
+
+ The null hypothesis is that the occurrence of the outcomes is statistically independent.
+
+ :param dataset:
+ DataFrame of categorical labels and categorical features.
+ Real-valued features will be treated as categorical for each distinct value.
+ :param featuresCol:
+ Name of features column in dataset, of type `Vector` (`VectorUDT`).
+ :param labelCol:
+ Name of label column in dataset, of any numerical type.
+ :return:
+ DataFrame containing the test result for every feature against the label.
+ This DataFrame will contain a single Row with the following fields:
+ - `pValues: Vector`
+ - `degreesOfFreedom: Array[Int]`
+ - `statistics: Vector`
+ Each of these fields has one value per feature.
+
+ >>> from pyspark.ml.linalg import Vectors
+ >>> from pyspark.ml.stat import ChiSquareTest
+ >>> dataset = [[0, Vectors.dense([0, 0, 1])],
+ ... [0, Vectors.dense([1, 0, 1])],
+ ... [1, Vectors.dense([2, 1, 1])],
+ ... [1, Vectors.dense([3, 1, 1])]]
+ >>> dataset = spark.createDataFrame(dataset, ["label", "features"])
+ >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
+ >>> chiSqResult.select("degreesOfFreedom").collect()[0]
+ Row(degreesOfFreedom=[3, 1, 0])
+
+ .. versionadded:: 2.2.0
+
+ """
+ @staticmethod
+ @since("2.2.0")
+ def test(dataset, featuresCol, labelCol):
+ """
+ Perform Pearson's independence test on the given dataset.
+ """
+ sc = SparkContext._active_spark_context
+ javaTestObj = _jvm().org.apache.spark.ml.stat.ChiSquareTest
+ args = [_py2java(sc, arg) for arg in (dataset, featuresCol, labelCol)]
+ return _java2py(sc, javaTestObj.test(*args))
+
+
+if __name__ == "__main__":
+ import doctest
+ import pyspark.ml.stat
+ from pyspark.sql import SparkSession
+
+ globs = pyspark.ml.stat.__dict__.copy()
+ # Set up a local SparkSession for the doctest run:
+ spark = SparkSession.builder \
+ .master("local[2]") \
+ .appName("ml.stat tests") \
+ .getOrCreate()
+ sc = spark.sparkContext
+ globs['sc'] = sc
+ globs['spark'] = spark
+
+ failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ spark.stop()
+ if failure_count:
+ exit(-1)
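The implementation above follows the standard pyspark JVM-bridge idiom: convert the Python arguments with _py2java, call the Scala object through the py4j gateway, and convert the returned Java DataFrame back with _java2py. A minimal sketch of the same idiom for a hypothetical JVM helper (`MyHelper` and its `run` method are illustrative, not part of this patch):

```python
from pyspark import SparkContext
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.wrapper import _jvm

def call_run(df, col_name):
    # The active SparkContext carries the py4j gateway used for conversion.
    sc = SparkContext._active_spark_context
    java_obj = _jvm().org.example.MyHelper  # hypothetical JVM object
    java_args = [_py2java(sc, a) for a in (df, col_name)]
    # _java2py turns the returned Java DataFrame back into a Python DataFrame.
    return _java2py(sc, java_obj.run(*java_args))
```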
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 527db9b667..571ac4bc1c 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -41,9 +41,7 @@ from shutil import rmtree
import tempfile
import array as pyarray
import numpy as np
-from numpy import (
- abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros)
-from numpy import sum as array_sum
+from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
import inspect
from pyspark import keyword_only, SparkContext
@@ -54,20 +52,19 @@ from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.fpm import FPGrowth, FPGrowthModel
-from pyspark.ml.linalg import (
- DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT,
- SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector)
+from pyspark.ml.linalg import DenseMatrix, DenseVector, Matrices, MatrixUDT, \
+ SparseMatrix, SparseVector, Vector, VectorUDT, Vectors
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
from pyspark.ml.recommendation import ALS
-from pyspark.ml.regression import (
- DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression)
+from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
+ LinearRegression
+from pyspark.ml.stat import ChiSquareTest
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams, JavaWrapper
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
-from pyspark.sql.utils import IllegalArgumentException
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -1741,6 +1738,22 @@ class WrapperTests(MLlibTestCase):
self.assertEqual(_java2py(self.sc, java_array), [])
+class ChiSquareTestTests(SparkSessionTestCase):
+
+ def test_chisquaretest(self):
+ data = [[0, Vectors.dense([0, 1, 2])],
+ [1, Vectors.dense([1, 1, 1])],
+ [2, Vectors.dense([2, 1, 0])]]
+ df = self.spark.createDataFrame(data, ['label', 'feat'])
+ res = ChiSquareTest.test(df, 'feat', 'label')
+ # This line is hitting the collect bug described in #17218, commented for now.
+ # degreesOfFreedom = res.select("degreesOfFreedom").collect()
+ self.assertIsInstance(res, DataFrame)
+ fieldNames = set(field.name for field in res.schema.fields)
+ expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
+ self.assertTrue(all(field in fieldNames for field in expectedFields))
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner: