author     Aaron Staple <aaron.staple@gmail.com>      2014-09-16 11:45:35 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-09-16 11:45:35 -0700
commit     8e7ae477ba40a064d27cf149aa211ff6108fe239 (patch)
tree       3f30546913a7e1ca882e65fcbac721ed0ad36258 /python
parent     30f288ae34a67307aa45b7aecbd0d02a0a14fe69 (diff)
[SPARK-2314][SQL] Override collect and take in python library, and count in java library, with optimized versions.
SchemaRDD overrides RDD functions, including collect, count, and take, with optimized versions that make use of the query optimizer. The Java and Python interface classes wrapping SchemaRDD need to ensure the optimized versions are called as well. This patch overrides the relevant calls in the Python and Java interfaces with optimized versions.

It also adds a new Row serialization pathway between Python and Java, based on JList[Array[Byte]] rather than the existing RDD[Array[Byte]]. I wasn't overjoyed about doing this, but I noticed that some QueryPlans implement optimizations in executeCollect(), which outputs an Array[Row] rather than the typical RDD[Row] that can be shipped to Python using the existing serialization code. To me it made sense to ship the Array[Row] over to Python directly instead of converting it back to an RDD[Row] just for the purpose of sending the rows to Python with the existing serialization code.

Author: Aaron Staple <aaron.staple@gmail.com>

Closes #1592 from staple/SPARK-2314 and squashes the following commits:

89ff550 [Aaron Staple] Merge with master.
6bb7b6c [Aaron Staple] Fix typo.
b56d0ac [Aaron Staple] [SPARK-2314][SQL] Override count in JavaSchemaRDD, forwarding to SchemaRDD's count.
0fc9d40 [Aaron Staple] Fix comment typos.
f03cdfa [Aaron Staple] [SPARK-2314][SQL] Override collect and take in sql.py, forwarding to SchemaRDD's collect.
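For orientation, a minimal PySpark usage sketch of the methods this patch overrides, mirroring the doctests added in the diff below. It assumes a Spark 1.1-era pyspark shell where sc is an active SparkContext; the data and variable names are illustrative only.

    # Illustrative only: exercises the overridden SchemaRDD methods from Python.
    from pyspark.sql import SQLContext, Row

    sqlCtx = SQLContext(sc)  # assumes `sc` is the SparkContext from the pyspark shell
    rdd = sc.parallelize([Row(field1=1, field2=u'row1'),
                          Row(field1=2, field2=u'row2'),
                          Row(field1=3, field2=u'row3')])
    srdd = sqlCtx.inferSchema(rdd)

    srdd.collect()           # ships rows via the optimized collectToPython() path on the JVM side
    srdd.take(2)             # now implemented as limit(2).collect(), going through the query plan
    srdd.limit(0).collect()  # returns []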
Diffstat (limited to 'python')
-rw-r--r--   python/pyspark/sql.py   47
1 file changed, 42 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index eac55cbe15..621a556ec6 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -30,6 +30,7 @@ from operator import itemgetter
from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
from itertools import chain, ifilter, imap
@@ -1550,6 +1551,18 @@ class SchemaRDD(RDD):
self._id = self._jrdd.id()
return self._id
+ def limit(self, num):
+ """Limit the result count to the number specified.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.limit(2).collect()
+ [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+ >>> srdd.limit(0).collect()
+ []
+ """
+ rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
+ return SchemaRDD(rdd, self.sql_ctx)
+
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
@@ -1626,15 +1639,39 @@ class SchemaRDD(RDD):
return self._jschema_rdd.count()
def collect(self):
- """
- Return a list that contains all of the rows in this RDD.
+ """Return a list that contains all of the rows in this RDD.
- Each object in the list is on Row, the fields can be accessed as
+ Each object in the list is a Row, the fields can be accessed as
attributes.
+
+ Unlike the base RDD implementation of collect, this implementation
+ leverages the query optimizer to perform a collect on the SchemaRDD,
+ which supports features such as filter pushdown.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.collect()
+ [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
- rows = RDD.collect(self)
+ with SCCallSiteSync(self.context) as css:
+ bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
cls = _create_cls(self.schema())
- return map(cls, rows)
+ return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+ def take(self, num):
+ """Take the first num rows of the RDD.
+
+ Each object in the list is a Row, the fields can be accessed as
+ attributes.
+
+ Unlike the base RDD implementation of take, this implementation
+ leverages the query optimizer to perform a collect on a SchemaRDD,
+ which supports features such as filter pushdown.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.take(2)
+ [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+ """
+ return self.limit(num).collect()
# Convert each object in the RDD to a Row with the right class
# for this SchemaRDD, so that fields can be accessed as attributes.
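As a closing note on the JList[Array[Byte]] pathway used by collect() above: the JVM side hands Python a list of pickled byte batches rather than an RDD of them. The following is a hypothetical Python 2 sketch of that deserialization idea, not PySpark's actual internals (which go through _collect_iterator_through_file and the batched pickle serializer); the function and variable names here are made up for illustration.

    # Hypothetical illustration, not PySpark's real deserialization code.
    import cPickle

    def rows_from_byte_batches(byte_batches):
        """Yield rows from a list of pickled batches, as might arrive from the JVM."""
        for batch in byte_batches:            # each element: one pickled batch (a byte string)
            for row in cPickle.loads(batch):  # each batch unpickles to a list of rows
                yield row

    # Example with locally pickled batches standing in for JVM output:
    batches = [cPickle.dumps([(1, u'row1'), (2, u'row2')]),
               cPickle.dumps([(3, u'row3')])]
    print list(rows_from_byte_batches(batches))  # [(1, u'row1'), (2, u'row2'), (3, u'row3')]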