about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
authoranimesh <animesh@apache.spark>2015-06-03 11:28:18 -0700
committerReynold Xin <rxin@databricks.com>2015-06-03 11:28:18 -0700
commitd053a31be93d789e3f26cf55d747ecf6ca386c29 (patch)
treed864eda0ced1151c00404197173afbf49acf0fb6 /python
parent2c4d550eda0e6f33d2d575825c3faef4c9217067 (diff)
downloadspark-d053a31be93d789e3f26cf55d747ecf6ca386c29.tar.gz
spark-d053a31be93d789e3f26cf55d747ecf6ca386c29.tar.bz2
spark-d053a31be93d789e3f26cf55d747ecf6ca386c29.zip
[SPARK-7980] [SQL] Support SQLContext.range(end)
1. range() overloaded in SQLContext.scala
2. range() modified in python sql context.py
3. Tests added accordingly in DataFrameSuite.scala and python sql tests.py

Author: animesh <animesh@apache.spark>

Closes #6609 from animeshbaranawal/SPARK-7980 and squashes the following commits:

935899c [animesh] SPARK-7980: python+scala changes
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/context.py  12
-rw-r--r--  python/pyspark/sql/tests.py    2
2 files changed, 12 insertions, 2 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 9fdf43c3e6..1bebfc4837 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -131,7 +131,7 @@ class SQLContext(object):
return UDFRegistration(self)
@since(1.4)
- def range(self, start, end, step=1, numPartitions=None):
+ def range(self, start, end=None, step=1, numPartitions=None):
"""
Create a :class:`DataFrame` with single LongType column named `id`,
containing elements in a range from `start` to `end` (exclusive) with
@@ -145,10 +145,18 @@ class SQLContext(object):
>>> sqlContext.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]
+
+ >>> sqlContext.range(3).collect()
+ [Row(id=0), Row(id=1), Row(id=2)]
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
- jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
+ if end is None:
+ jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
+ else:
+ jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
return DataFrame(jdf, self)
@ignore_unicode_prefix
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6e498f0af0..a6fce50c76 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -131,6 +131,8 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+ self.assertEqual(self.sqlCtx.range(-2).count(), 0)
+ self.assertEqual(self.sqlCtx.range(3).count(), 3)
def test_explode(self):
from pyspark.sql.functions import explode