diff options
author | Daoyuan Wang <daoyuan.wang@intel.com> | 2015-05-18 21:43:12 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-05-18 21:43:25 -0700 |
commit | 7fcbb2ccaf50d7cb1dc68ff0c271737a3a59253e (patch) | |
tree | 37f9f07786ecc9436672614bce0b70d7fed81878 /sql/core | |
parent | 9d0b7fb714a8ce3437dbc6d19b9d7b38f1db3c73 (diff) | |
download | spark-7fcbb2ccaf50d7cb1dc68ff0c271737a3a59253e.tar.gz spark-7fcbb2ccaf50d7cb1dc68ff0c271737a3a59253e.tar.bz2 spark-7fcbb2ccaf50d7cb1dc68ff0c271737a3a59253e.zip |
[SPARK-7150] SparkContext.range() and SQLContext.range()
This PR is based on #6081, thanks adrian-wang.
Closes #6081
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Author: Davies Liu <davies@databricks.com>
Closes #6230 from davies/range and squashes the following commits:
d3ce5fe [Davies Liu] add tests
789eda5 [Davies Liu] add range() in Python
4590208 [Davies Liu] Merge commit 'refs/pull/6081/head' of github.com:apache/spark into range
cbf5200 [Daoyuan Wang] let's add python support in a separate PR
f45e3b2 [Daoyuan Wang] remove redundant toLong
617da76 [Daoyuan Wang] fix safe margin for corner cases
867c417 [Daoyuan Wang] fix
13dbe84 [Daoyuan Wang] update
bd998ba [Daoyuan Wang] update comments
d3a0c1b [Daoyuan Wang] add range api()
(cherry picked from commit c2437de1899e09894df4ec27adfaa7fac158fd3a)
Signed-off-by: Reynold Xin <rxin@databricks.com>
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 31 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 40 |
2 files changed, 71 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ac1a800219..316ef7d588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -685,6 +685,37 @@ class SQLContext(@transient val sparkContext: SparkContext) } /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from `start` to `end` (exclusive) with step value 1. + * + * @since 1.4.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long): DataFrame = { + createDataFrame( + sparkContext.range(start, end).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from `start` to `end` (exclusive) with the given step value and the given number of + * partitions. + * + * @since 1.4.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { + createDataFrame( + sparkContext.range(start, end, step, numPartitions).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) + } + + /** * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 054b23dba8..f05d059d44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -532,4 +532,44 @@ class DataFrameSuite extends QueryTest { val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] assert(!p.child.isInstanceOf[Project]) } + + test("SPARK-7150 range api") { + // numSlice is greater than length + val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + assert(res1.count == 10) + assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + assert(res2.count == 4) + assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) + + val res3 = TestSQLContext.range(1, -2).select("id") + assert(res3.count == 0) + + // start is positive, end is negative, step is negative + val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + assert(res4.count == 2) + assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) + + // start, end, step are negative + val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + assert(res5.count == 3) + assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) + + // start, end are negative, step is positive + val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + assert(res6.count == 2) + assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) + + val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + assert(res7.count == 0) + + val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + assert(res8.count == 3) + assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) + + val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + assert(res9.count == 2) 
+ assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + } } |