diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 25 +++++++++++++++++++++++++
-rw-r--r-- | python/pyspark/sql/tests.py     |  9 +++++++++
2 files changed, 34 insertions, 0 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 22762c5bbb..f30a92dfc8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -931,6 +931,26 @@ class DataFrame(object):
             raise ValueError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
 
+    def crosstab(self, col1, col2):
+        """
+        Computes a pair-wise frequency table of the given columns. Also known as a contingency
+        table. The number of distinct values for each column should be less than 1e4. The first
+        column of each row will be the distinct values of `col1` and the column names will be the
+        distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that
+        have no occurrences will have `null` as their counts.
+        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
+
+        :param col1: The name of the first column. Distinct items will make the first item of
+            each row.
+        :param col2: The name of the second column. Distinct items will make the column names
+            of the DataFrame.
+        """
+        if not isinstance(col1, str):
+            raise ValueError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise ValueError("col2 should be a string.")
+        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
+
     def freqItems(self, cols, support=None):
         """
         Finding frequent items for columns, possibly with false positives.
@@ -1423,6 +1443,11 @@ class DataFrameStatFunctions(object):
 
     cov.__doc__ = DataFrame.cov.__doc__
 
+    def crosstab(self, col1, col2):
+        return self.df.crosstab(col1, col2)
+
+    crosstab.__doc__ = DataFrame.crosstab.__doc__
+
     def freqItems(self, cols, support=None):
         return self.df.freqItems(cols, support)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d652c302a5..7ea6656d31 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -405,6 +405,15 @@ class SQLTests(ReusedPySparkTestCase):
         cov = df.stat.cov("a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
+    def test_crosstab(self):
+        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
+        ct = df.stat.crosstab("a", "b").collect()
+        ct = sorted(ct, key=lambda x: x[0])
+        for i, row in enumerate(ct):
+            self.assertEqual(row[0], str(i))
+            self.assertTrue(row[1], 1)
+            self.assertTrue(row[2], 1)
+
     def test_math_functions(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         from pyspark.sql import mathfunctions as functions