diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/stat/_statistics.py | 6 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 6 |
2 files changed, 12 insertions, 0 deletions
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 218ac148ca..1d83e9d483 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -49,6 +49,12 @@ class MultivariateStatisticalSummary(JavaModelWrapper): def min(self): return self.call("min").toArray() + def normL1(self): + return self.call("normL1").toArray() + + def normL2(self): + return self.call("normL2").toArray() + class Statistics(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index dd3b66ce67..47dad7d12e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -357,6 +357,12 @@ class StatTests(PySparkTestCase): summary = Statistics.colStats(data) self.assertEqual(10, summary.count()) + def test_col_norms(self): + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(10, len(summary.normL1())) + self.assertEqual(10, len(summary.normL2())) + class VectorUDTTests(PySparkTestCase): |