Diffstat (limited to 'python/pyspark')
 python/pyspark/context.py     | 16 ++++++++++++++++
 python/pyspark/sql/context.py | 20 ++++++++++++++++++++
 python/pyspark/sql/tests.py   |  5 +++++
 python/pyspark/tests.py       |  5 +++++
 4 files changed, 46 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d25ee85523..1f2b40b29f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -319,6 +319,22 @@ class SparkContext(object):
         with SparkContext._lock:
             SparkContext._active_spark_context = None
 
+    def range(self, start, end, step=1, numSlices=None):
+        """
+        Create a new RDD of int containing elements from `start` to `end`
+        (exclusive), incremented by `step` for each element.
+
+        :param start: the start value
+        :param end: the end value (exclusive)
+        :param step: the incremental step (default: 1)
+        :param numSlices: the number of partitions of the new RDD
+        :return: An RDD of int
+
+        >>> sc.range(1, 7, 2).collect()
+        [1, 3, 5]
+        """
+        return self.parallelize(xrange(start, end, step), numSlices)
+
     def parallelize(self, c, numSlices=None):
         """
         Distribute a local Python collection to form an RDD. Using xrange
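
For context, a minimal usage sketch of the new SparkContext.range (not part of the patch; it assumes a live SparkContext bound to `sc`, as in the pyspark shell, and Python 2, where xrange is available):

    # sc.range(start, end, step) is a thin wrapper over parallelize(xrange(...)),
    # so these two RDDs should hold identical elements:
    r1 = sc.range(1, 7, 2)
    r2 = sc.parallelize(xrange(1, 7, 2))
    assert r1.collect() == r2.collect() == [1, 3, 5]

    # numSlices controls partitioning, exactly as it does for parallelize:
    assert sc.range(0, 10, numSlices=4).getNumPartitions() == 4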
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 0bde719124..9f26d13235 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -122,6 +122,26 @@ class SQLContext(object):
         """Returns a :class:`UDFRegistration` for UDF registration."""
         return UDFRegistration(self)
 
+    def range(self, start, end, step=1, numPartitions=None):
+        """
+        Create a :class:`DataFrame` with a single LongType column named `id`,
+        containing elements in a range from `start` to `end` (exclusive) with
+        a step value of `step`.
+
+        :param start: the start value
+        :param end: the end value (exclusive)
+        :param step: the incremental step (default: 1)
+        :param numPartitions: the number of partitions of the DataFrame
+        :return: A new DataFrame
+
+        >>> sqlContext.range(1, 7, 2).collect()
+        [Row(id=1), Row(id=3), Row(id=5)]
+        """
+        if numPartitions is None:
+            numPartitions = self._sc.defaultParallelism
+        jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+        return DataFrame(jdf, self)
+
     @ignore_unicode_prefix
     def registerFunction(self, name, f, returnType=StringType()):
         """Registers a lambda function as a UDF so it can be used in SQL statements.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d37c5dbed7..84ae36f2fd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,11 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_range(self):
+        self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
+        self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
+        self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+
     def test_explode(self):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
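
The three assertions map onto plain Python range arithmetic; a sketch of why each count holds:

    # count == max(0, number of steps from start toward end)
    assert len(range(1, 1)) == 0        # empty interval
    assert len(range(1, 0, -1)) == 1    # negative step still yields [1]
    # 1 << 40 does not fit in 32 bits, which is exactly why the column is
    # LongType; stepping by 1 << 39 over [0, 1 << 40) gives two rows, 0 and 2**39.
    assert len(range(0, 1 << 40, 1 << 39)) == 2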
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 5e023f6c53..d8e319994c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -444,6 +444,11 @@ class AddFileTests(PySparkTestCase):
 
 class RDDTests(ReusedPySparkTestCase):
 
+    def test_range(self):
+        self.assertEqual(self.sc.range(1, 1).count(), 0)
+        self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
+        self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
+
     def test_id(self):
         rdd = self.sc.parallelize(range(10))
         id = rdd.id()
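
One portability note on the RDD-side tests (an observation, not something the patch itself asserts): since sc.range delegates to xrange on Python 2, its arguments must fit in a C long, so the 1 << 40 case relies on a 64-bit build:

    assert sc.range(1, 1).count() == 0
    assert sc.range(1, 0, -1).collect() == [1]
    # Fine on 64-bit Python 2; a 32-bit build would raise OverflowError in xrange.
    assert sc.range(0, 1 << 40, 1 << 39).collect() == [0, 1 << 39]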