aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-04-04 13:31:44 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-04 13:31:44 -0700
commitcc70f174169f45c85d459126a68bbe43c0bec328 (patch)
treea455d9fab09fb11f6d78265243a1a301a6e5563b /python
parent7143904700435265975d36f073cce2833467e121 (diff)
downloadspark-cc70f174169f45c85d459126a68bbe43c0bec328.tar.gz
spark-cc70f174169f45c85d459126a68bbe43c0bec328.tar.bz2
spark-cc70f174169f45c85d459126a68bbe43c0bec328.zip
[SPARK-14334] [SQL] add toLocalIterator for Dataset/DataFrame
## What changes were proposed in this pull request? RDD.toLocalIterator() could be used to fetch one partition at a time to reduce the memory usage. Right now, for Dataset/Dataframe we have to use df.rdd.toLocalIterator, which is super slow also requires lots of memory (because of the Java serializer or even Kyro serializer). This PR introduce an optimized toLocalIterator for Dataset/DataFrame, which is much faster and requires much less memory. For a partition with 5 millions rows, `df.rdd.toIterator` took about 100 seconds, but df.toIterator took less than 7 seconds. For 10 millions row, rdd.toIterator will crash (not enough memory) with 4G heap, but df.toLocalIterator could finished in 12 seconds. The JDBC server has been updated to use DataFrame.toIterator. ## How was this patch tested? Existing tests. Author: Davies Liu <davies@databricks.com> Closes #12114 from davies/local_iterator.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py8
-rw-r--r--python/pyspark/sql/dataframe.py14
2 files changed, 18 insertions, 4 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 37574cea0b..cd1f64e8aa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2299,14 +2299,14 @@ class RDD(object):
"""
Return an iterator that contains all of the elements in this RDD.
The iterator will consume as much memory as the largest partition in this RDD.
+
>>> rdd = sc.parallelize(range(10))
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
- for partition in range(self.getNumPartitions()):
- rows = self.context.runJob(self, lambda x: x, [partition])
- for row in rows:
- yield row
+ with SCCallSiteSync(self.context) as css:
+ port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(port, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7a69c4c70c..d473d6b534 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -241,6 +241,20 @@ class DataFrame(object):
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
+ @since(2.0)
+ def toLocalIterator(self):
+ """
+ Returns an iterator that contains all of the rows in this :class:`DataFrame`.
+ The iterator will consume as much memory as the largest partition in this DataFrame.
+
+ >>> list(df.toLocalIterator())
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ """
+ with SCCallSiteSync(self._sc) as css:
+ port = self._jdf.toPythonIterator()
+ return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+
+ @ignore_unicode_prefix
@since(1.3)
def limit(self, num):
"""Limits the result count to the number specified.