From c2437de1899e09894df4ec27adfaa7fac158fd3a Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Mon, 18 May 2015 21:43:12 -0700
Subject: [SPARK-7150] SparkContext.range() and SQLContext.range()

This PR is based on #6081, thanks adrian-wang.

Closes #6081

Author: Daoyuan Wang
Author: Davies Liu

Closes #6230 from davies/range and squashes the following commits:

d3ce5fe [Davies Liu] add tests
789eda5 [Davies Liu] add range() in Python
4590208 [Davies Liu] Merge commit 'refs/pull/6081/head' of github.com:apache/spark into range
cbf5200 [Daoyuan Wang] let's add python support in a separate PR
f45e3b2 [Daoyuan Wang] remove redundant toLong
617da76 [Daoyuan Wang] fix safe marge for corner cases
867c417 [Daoyuan Wang] fix
13dbe84 [Daoyuan Wang] update
bd998ba [Daoyuan Wang] update comments
d3a0c1b [Daoyuan Wang] add range api()
---
 python/pyspark/context.py     | 16 ++++++++++++++++
 python/pyspark/sql/context.py | 20 ++++++++++++++++++++
 python/pyspark/sql/tests.py   |  5 +++++
 python/pyspark/tests.py       |  5 +++++
 4 files changed, 46 insertions(+)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d25ee85523..1f2b40b29f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -319,6 +319,22 @@ class SparkContext(object):
         with SparkContext._lock:
             SparkContext._active_spark_context = None
 
+    def range(self, start, end, step=1, numSlices=None):
+        """
+        Create a new RDD of int containing elements from `start` to `end`
+        (exclusive), increased by `step` every element.
+
+        :param start: the start value
+        :param end: the end value (exclusive)
+        :param step: the incremental step (default: 1)
+        :param numSlices: the number of partitions of the new RDD
+        :return: An RDD of int
+
+        >>> sc.range(1, 7, 2).collect()
+        [1, 3, 5]
+        """
+        return self.parallelize(xrange(start, end, step), numSlices)
+
     def parallelize(self, c, numSlices=None):
         """
         Distribute a local Python collection to form an RDD. Using xrange
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 0bde719124..9f26d13235 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -122,6 +122,26 @@ class SQLContext(object):
         """Returns a :class:`UDFRegistration` for UDF registration."""
         return UDFRegistration(self)
 
+    def range(self, start, end, step=1, numPartitions=None):
+        """
+        Create a :class:`DataFrame` with single LongType column named `id`,
+        containing elements in a range from `start` to `end` (exclusive) with
+        step value `step`.
+
+        :param start: the start value
+        :param end: the end value (exclusive)
+        :param step: the incremental step (default: 1)
+        :param numPartitions: the number of partitions of the DataFrame
+        :return: A new DataFrame
+
+        >>> sqlContext.range(1, 7, 2).collect()
+        [Row(id=1), Row(id=3), Row(id=5)]
+        """
+        if numPartitions is None:
+            numPartitions = self._sc.defaultParallelism
+        jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+        return DataFrame(jdf, self)
+
     @ignore_unicode_prefix
     def registerFunction(self, name, f, returnType=StringType()):
         """Registers a lambda function as a UDF so it can be used in SQL statements.
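A minimal usage sketch of the two new APIs, in the same doctest style as the
docstrings above; it assumes an interactive PySpark shell where `sc` and
`sqlContext` are predefined, and is not a verbatim transcript from the patch:

    >>> sc.range(0, 10, 2).collect()          # RDD of plain ints, xrange semantics
    [0, 2, 4, 6, 8]
    >>> sqlContext.range(0, 10, 2).collect()  # DataFrame with a LongType `id` column
    [Row(id=0), Row(id=2), Row(id=4), Row(id=6), Row(id=8)]
    >>> sc.range(5, 0, -1).collect()          # a negative step counts down
    [5, 4, 3, 2, 1]

The tests for both entry points follow.
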
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d37c5dbed7..84ae36f2fd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,11 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_range(self):
+        self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
+        self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
+        self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+
     def test_explode(self):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 5e023f6c53..d8e319994c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -444,6 +444,11 @@ class AddFileTests(PySparkTestCase):
 
 class RDDTests(ReusedPySparkTestCase):
 
+    def test_range(self):
+        self.assertEqual(self.sc.range(1, 1).count(), 0)
+        self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
+        self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
+
     def test_id(self):
         rdd = self.sc.parallelize(range(10))
         id = rdd.id()
-- 
cgit v1.2.3
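The new tests pin down three corner cases: an empty range, a negative step,
and bounds past 32 bits. A plain-Python sketch of the same semantics, for
intuition only; the builtin `range` stands in for the xrange-backed RDD and
the JVM-backed DataFrame:

    # Empty range: start == end yields no elements.
    assert list(range(1, 1)) == []
    # A negative step counts down, so range(1, 0, -1) is just [1].
    assert list(range(1, 0, -1)) == [1]
    # 0 and 2**39 are the only multiples of 2**39 below 2**40, so the
    # count is 2; the bounds themselves do not fit in a 32-bit int,
    # which is what this case guards against.
    assert list(range(0, 1 << 40, 1 << 39)) == [0, 1 << 39]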