aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2015-04-05 16:13:31 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-05 16:13:31 -0700
commitacffc43455d7b3e4000be4ff0175b8ea19cd280b (patch)
tree4f133b4e622c8e9531f93c0cecce11ef1289badc
parentf15806a8f8ca34288ddb2d74b9ff1972c8374b59 (diff)
downloadspark-acffc43455d7b3e4000be4ff0175b8ea19cd280b.tar.gz
spark-acffc43455d7b3e4000be4ff0175b8ea19cd280b.tar.bz2
spark-acffc43455d7b3e4000be4ff0175b8ea19cd280b.zip
[SPARK-6262][MLLIB]Implement missing methods for MultivariateStatisticalSummary
Add below methods in pyspark for MultivariateStatisticalSummary - normL1 - normL2 Author: lewuathe <lewuathe@me.com> Closes #5359 from Lewuathe/SPARK-6262 and squashes the following commits: cbe439e [lewuathe] Implement missing methods for MultivariateStatisticalSummary
-rw-r--r--python/pyspark/mllib/stat/_statistics.py6
-rw-r--r--python/pyspark/mllib/tests.py6
2 files changed, 12 insertions, 0 deletions
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
index 218ac148ca..1d83e9d483 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -49,6 +49,12 @@ class MultivariateStatisticalSummary(JavaModelWrapper):
def min(self):
return self.call("min").toArray()
+ def normL1(self):
+ return self.call("normL1").toArray()
+
+ def normL2(self):
+ return self.call("normL2").toArray()
+
class Statistics(object):
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index dd3b66ce67..47dad7d12e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -357,6 +357,12 @@ class StatTests(PySparkTestCase):
summary = Statistics.colStats(data)
self.assertEqual(10, summary.count())
+ def test_col_norms(self):
+ data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, len(summary.normL1()))
+ self.assertEqual(10, len(summary.normL2()))
+
class VectorUDTTests(PySparkTestCase):