Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py | 47
1 file changed, 42 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index eac55cbe15..621a556ec6 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -30,6 +30,7 @@ from operator import itemgetter
 from pyspark.rdd import RDD, PipelinedRDD
 from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
+from pyspark.traceback_utils import SCCallSiteSync
 from itertools import chain, ifilter, imap
@@ -1550,6 +1551,18 @@ class SchemaRDD(RDD):
             self._id = self._jrdd.id()
         return self._id
 
+    def limit(self, num):
+        """Limit the result count to the number specified.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.limit(2).collect()
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        >>> srdd.limit(0).collect()
+        []
+        """
+        rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
+        return SchemaRDD(rdd, self.sql_ctx)
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
@@ -1626,15 +1639,39 @@ class SchemaRDD(RDD):
         return self._jschema_rdd.count()
 
     def collect(self):
-        """
-        Return a list that contains all of the rows in this RDD.
+        """Return a list that contains all of the rows in this RDD.
 
-        Each object in the list is on Row, the fields can be accessed as
+        Each object in the list is a Row; the fields can be accessed as
         attributes.
+
+        Unlike the base RDD implementation of collect, this implementation
+        leverages the query optimizer to perform a collect on the SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()
+        [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
-        rows = RDD.collect(self)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
         cls = _create_cls(self.schema())
-        return map(cls, rows)
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+    def take(self, num):
+        """Take the first num rows of the RDD.
+
+        Each object in the list is a Row; the fields can be accessed as
+        attributes.
+
+        Unlike the base RDD implementation of take, this implementation
+        leverages the query optimizer to perform a collect on a SchemaRDD,
+        which supports features such as filter pushdown.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.take(2)
+        [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+        """
+        return self.limit(num).collect()
 
     # Convert each object in the RDD to a Row with the right class
    # for this SchemaRDD, so that fields can be accessed as attributes.
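To round out the picture, a companion sketch (same assumptions and fixtures as the limit() example above) showing the two reworked read paths this hunk introduces:

    # collect() now asks the JVM to run the optimized plan and serialize the
    # rows back via collectToPython(); SCCallSiteSync keeps the JVM-side call
    # site pointing at the Python caller. Optimizer features such as filter
    # pushdown therefore still apply to a plain collect.
    rows = srdd.collect()
    print rows                      # [Row(field1=1, ...), Row(field1=2, ...), Row(field1=3, ...)]

    # take(num) is simply limit(num).collect(): only num rows are computed
    # and shipped, instead of the base RDD.take() strategy of pulling
    # partition after partition to the driver.
    print srdd.take(2)              # [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]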