diff options
author | Davies Liu <davies@databricks.com> | 2014-10-21 09:29:45 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-10-21 09:29:45 -0700 |
commit | 85708168341a9406c451df20af3374c0850ce166 (patch) | |
tree | 186ef9febf3e61c1c9a6825dc0b6378f203d7e8f /python/pyspark/mllib/stat.py | |
parent | 5a8f64f33632fbf89d16cade2e0e66c5ed60760b (diff) | |
download | spark-85708168341a9406c451df20af3374c0850ce166.tar.gz spark-85708168341a9406c451df20af3374c0850ce166.tar.bz2 spark-85708168341a9406c451df20af3374c0850ce166.zip |
[SPARK-4023] [MLlib] [PySpark] convert rdd into RDD of Vector
Convert the input rdd to RDD of Vector.
cc mengxr
Author: Davies Liu <davies@databricks.com>
Closes #2870 from davies/fix4023 and squashes the following commits:
1eac767 [Davies Liu] address comments
0871576 [Davies Liu] convert rdd into RDD of Vector
Diffstat (limited to 'python/pyspark/mllib/stat.py')
-rw-r--r-- | python/pyspark/mllib/stat.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index a6019dadf7..84baf12b90 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,7 +22,7 @@ Python package for statistical functions in MLlib.
 from functools import wraps
 
 from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _to_java_object_rdd
+from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
 
 
 __all__ = ['MultivariateStatisticalSummary', 'Statistics']
@@ -107,7 +107,7 @@ class Statistics(object):
         array([ 2.,  0.,  0., -2.])
         """
         sc = rdd.ctx
-        jrdd = _to_java_object_rdd(rdd)
+        jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
         cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
         return MultivariateStatisticalSummary(sc, cStats)
 
@@ -163,14 +163,15 @@ class Statistics(object):
         if type(y) == str:
             raise TypeError("Use 'method=' to specify method name.")
 
-        jx = _to_java_object_rdd(x)
         if not y:
+            jx = _to_java_object_rdd(x.map(_convert_to_vector))
             resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
             bytes = sc._jvm.SerDe.dumps(resultMat)
             ser = PickleSerializer()
             return ser.loads(str(bytes)).toArray()
         else:
-            jy = _to_java_object_rdd(y)
+            jx = _to_java_object_rdd(x.map(float))
+            jy = _to_java_object_rdd(y.map(float))
             return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
 
 