Diffstat (limited to 'python/pyspark/mllib/util.py')
-rw-r--r--  python/pyspark/mllib/util.py  69
1 file changed, 50 insertions(+), 19 deletions(-)
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 0e5f4520b9..e24c144f45 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -19,7 +19,10 @@ import numpy as np
from pyspark.mllib.linalg import Vectors, SparseVector
from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib._common import _convert_vector
+from pyspark.mllib._common import _convert_vector, _deserialize_labeled_point
+from pyspark.rdd import RDD
+from pyspark.serializers import NoOpSerializer
+
class MLUtils:
@@ -105,24 +108,18 @@ class MLUtils:
>>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
>>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect()
>>> tempFile.close()
- >>> examples[0].label
- 1.0
- >>> examples[0].features.size
- 6
- >>> print examples[0].features
- [0: 1.0, 2: 2.0, 4: 3.0]
- >>> examples[1].label
- 0.0
- >>> examples[1].features.size
- 6
- >>> print examples[1].features
- []
- >>> examples[2].label
- 0.0
- >>> examples[2].features.size
- 6
- >>> print examples[2].features
- [1: 4.0, 3: 5.0, 5: 6.0]
+ >>> type(examples[0]) == LabeledPoint
+ True
+ >>> print examples[0]
+ (1.0,(6,[0,2,4],[1.0,2.0,3.0]))
+ >>> type(examples[1]) == LabeledPoint
+ True
+ >>> print examples[1]
+ (0.0,(6,[],[]))
+ >>> type(examples[2]) == LabeledPoint
+ True
+ >>> print examples[2]
+ (0.0,(6,[1,3,5],[4.0,5.0,6.0]))
>>> multiclass_examples[1].label
-1.0
"""
@@ -158,6 +155,40 @@ class MLUtils:
lines.saveAsTextFile(dir)
+ @staticmethod
+ def loadLabeledPoints(sc, path, minPartitions=None):
+ """
+ Load labeled points saved using RDD.saveAsTextFile.
+
+ @param sc: Spark context
+ @param path: file or directory path in any Hadoop-supported file
+ system URI
+ @param minPartitions: min number of partitions
+ @return: labeled data stored as an RDD of LabeledPoint
+
+ >>> from tempfile import NamedTemporaryFile
+ >>> from pyspark.mllib.util import MLUtils
+ >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \
+ LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
+ >>> tempFile = NamedTemporaryFile(delete=True)
+ >>> tempFile.close()
+ >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name)
+ >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
+ >>> type(loaded[0]) == LabeledPoint
+ True
+ >>> print loaded[0]
+ (1.1,(3,[0,2],[-1.23,4.56e-07]))
+ >>> type(loaded[1]) == LabeledPoint
+ True
+ >>> print loaded[1]
+ (0.0,[1.01,2.02,3.03])
+ """
+ minPartitions = minPartitions or min(sc.defaultParallelism, 2)
+ jSerialized = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions)
+ serialized = RDD(jSerialized, sc, NoOpSerializer())
+ return serialized.map(lambda bytes: _deserialize_labeled_point(bytearray(bytes)))
+
+
def _test():
import doctest
from pyspark.context import SparkContext
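
Outside the doctest, a minimal usage sketch of the new loadLabeledPoints method (assuming a live SparkContext sc and a hypothetical output path; as the doctest shows, the on-disk format is simply one str(LabeledPoint) per line written by RDD.saveAsTextFile):

    from pyspark.mllib.util import MLUtils
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.linalg import Vectors

    # Write a small RDD of LabeledPoints as text, one str(LabeledPoint) per line.
    points = [LabeledPoint(1.0, Vectors.dense([0.5, 1.5])),
              LabeledPoint(0.0, Vectors.sparse(2, [(1, 3.0)]))]
    sc.parallelize(points, 1).saveAsTextFile("/tmp/points")  # hypothetical path

    # Read them back; each element is deserialized into a LabeledPoint again.
    loaded = MLUtils.loadLabeledPoints(sc, "/tmp/points")
    p = loaded.first()
    assert isinstance(p, LabeledPoint)
    print p.label, p.features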