From d053a31be93d789e3f26cf55d747ecf6ca386c29 Mon Sep 17 00:00:00 2001 From: animesh Date: Wed, 3 Jun 2015 11:28:18 -0700 Subject: [SPARK-7980] [SQL] Support SQLContext.range(end) 1. range() overloaded in SQLContext.scala 2. range() modified in python sql context.py 3. Tests added accordingly in DataFrameSuite.scala and python sql tests.py Author: animesh Closes #6609 from animeshbaranawal/SPARK-7980 and squashes the following commits: 935899c [animesh] SPARK-7980:python+scala changes --- python/pyspark/sql/context.py | 12 ++++++++++-- python/pyspark/sql/tests.py | 2 ++ 2 files changed, 12 insertions(+), 2 deletions(-) (limited to 'python') diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9fdf43c3e6..1bebfc4837 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -131,7 +131,7 @@ class SQLContext(object): return UDFRegistration(self) @since(1.4) - def range(self, start, end, step=1, numPartitions=None): + def range(self, start, end=None, step=1, numPartitions=None): """ Create a :class:`DataFrame` with single LongType column named `id`, containing elements in a range from `start` to `end` (exclusive) with @@ -145,10 +145,18 @@ class SQLContext(object): >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] + + >>> sqlContext.range(3).collect() + [Row(id=0), Row(id=1), Row(id=2)] """ if numPartitions is None: numPartitions = self._sc.defaultParallelism - jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + + if end is None: + jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) + else: + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + return DataFrame(jdf, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6e498f0af0..a6fce50c76 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -131,6 +131,8 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.sqlCtx.range(-2).count(), 0) + self.assertEqual(self.sqlCtx.range(3).count(), 3) def test_explode(self): from pyspark.sql.functions import explode -- cgit v1.2.3