aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/stat/_statistics.py6
-rw-r--r--python/pyspark/mllib/tests.py6
2 files changed, 12 insertions, 0 deletions
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
index 218ac148ca..1d83e9d483 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -49,6 +49,12 @@ class MultivariateStatisticalSummary(JavaModelWrapper):
def min(self):
return self.call("min").toArray()
+ def normL1(self):
+ return self.call("normL1").toArray()
+
+ def normL2(self):
+ return self.call("normL2").toArray()
+
class Statistics(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index dd3b66ce67..47dad7d12e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -357,6 +357,12 @@ class StatTests(PySparkTestCase):
summary = Statistics.colStats(data)
self.assertEqual(10, summary.count())
+ def test_col_norms(self):
+ data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, len(summary.normL1()))
+ self.assertEqual(10, len(summary.normL2()))
+
class VectorUDTTests(PySparkTestCase):