aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-09-14 10:10:01 -0700
committerDavies Liu <davies.liu@gmail.com>2016-09-14 10:10:01 -0700
commit6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9 (patch)
treecee2c6043fc889682ec3827f10818ecb85502af0 /python/pyspark/sql
parent52738d4e099a19466ef909b77c24cab109548706 (diff)
downloadspark-6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9.tar.gz
spark-6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9.tar.bz2
spark-6d06ff6f7e2dd72ba8fe96cd875e83eda6ebb2a9.zip
[SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python
## What changes were proposed in this pull request? In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing. The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd.<some-pyspark-conversions>.toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan; this, in turn, ends up being executed inefficiently because limits in the middle of plans are now implemented by repartitioning to a single task rather than by running a `take()` job on the driver (this was done in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.). In order to fix this performance problem I think that we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch modifies `DataFrame.collect()` to first collect all results to the driver and then pass them to Python, allowing this query to be planned using Spark's `CollectLimit` optimizations. ## How was this patch tested? Added a regression test in `sql/tests.py` which asserts that the expected number of jobs, stages, and tasks are run for both queries. Author: Josh Rosen <joshrosen@databricks.com> Closes #15068 from JoshRosen/pyspark-collect-limit.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--python/pyspark/sql/dataframe.py5
-rw-r--r--python/pyspark/sql/tests.py18
2 files changed, 19 insertions, 4 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e5eac918a9..0f7d8fba3b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -357,10 +357,7 @@ class DataFrame(object):
>>> df.take(2)
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
- with SCCallSiteSync(self._sc) as css:
- port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
- self._jdf, num)
- return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+ return self.limit(num).collect()
@since(1.3)
def foreach(self, f):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 769e454072..1be0b72304 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1862,6 +1862,24 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
["1", "2", "2", "2"])
+ def test_limit_and_take(self):
+ df = self.spark.range(1, 1000, numPartitions=10)
+
+ def assert_runs_only_one_job_stage_and_task(job_group_name, f):
+ tracker = self.sc.statusTracker()
+ self.sc.setJobGroup(job_group_name, description="")
+ f()
+ jobs = tracker.getJobIdsForGroup(job_group_name)
+ self.assertEqual(1, len(jobs))
+ stages = tracker.getJobInfo(jobs[0]).stageIds
+ self.assertEqual(1, len(stages))
+ self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)
+
+ # Regression test for SPARK-10731: take should delegate to Scala implementation
+ assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
+ # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n)
+ assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
+
if __name__ == "__main__":
from pyspark.sql.tests import *