about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/mllib/stat.py
diff options
context:
space:
mode:
author: Davies Liu <davies@databricks.com> 2014-10-21 09:29:45 -0700
committer: Xiangrui Meng <meng@databricks.com> 2014-10-21 09:29:45 -0700
commit: 85708168341a9406c451df20af3374c0850ce166 (patch)
tree: 186ef9febf3e61c1c9a6825dc0b6378f203d7e8f /python/pyspark/mllib/stat.py
parent: 5a8f64f33632fbf89d16cade2e0e66c5ed60760b (diff)
download: spark-85708168341a9406c451df20af3374c0850ce166.tar.gz
          spark-85708168341a9406c451df20af3374c0850ce166.tar.bz2
          spark-85708168341a9406c451df20af3374c0850ce166.zip
[SPARK-4023] [MLlib] [PySpark] convert rdd into RDD of Vector
Convert the input rdd to RDD of Vector.

cc mengxr

Author: Davies Liu <davies@databricks.com>

Closes #2870 from davies/fix4023 and squashes the following commits:

1eac767 [Davies Liu] address comments
0871576 [Davies Liu] convert rdd into RDD of Vector
Diffstat (limited to 'python/pyspark/mllib/stat.py')
-rw-r--r-- python/pyspark/mllib/stat.py 9
1 files changed, 5 insertions, 4 deletions
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index a6019dadf7..84baf12b90 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,7 +22,7 @@ Python package for statistical functions in MLlib.
from functools import wraps
from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _to_java_object_rdd
+from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
__all__ = ['MultivariateStatisticalSummary', 'Statistics']
@@ -107,7 +107,7 @@ class Statistics(object):
array([ 2., 0., 0., -2.])
"""
sc = rdd.ctx
- jrdd = _to_java_object_rdd(rdd)
+ jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
return MultivariateStatisticalSummary(sc, cStats)
@@ -163,14 +163,15 @@ class Statistics(object):
if type(y) == str:
raise TypeError("Use 'method=' to specify method name.")
- jx = _to_java_object_rdd(x)
if not y:
+ jx = _to_java_object_rdd(x.map(_convert_to_vector))
resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
bytes = sc._jvm.SerDe.dumps(resultMat)
ser = PickleSerializer()
return ser.loads(str(bytes)).toArray()
else:
- jy = _to_java_object_rdd(y)
+ jx = _to_java_object_rdd(x.map(float))
+ jy = _to_java_object_rdd(y.map(float))
return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)