Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala    31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  40
2 files changed, 71 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ac1a800219..316ef7d588 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -685,6 +685,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
/**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with a step value of 1.
+ *
+ * @since 1.4.0
+ * @group dataframe
+ */
+ @Experimental
+ def range(start: Long, end: Long): DataFrame = {
+ createDataFrame(
+ sparkContext.range(start, end).map(Row(_)),
+ StructType(StructField("id", LongType, nullable = false) :: Nil))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with the given step value, using the
+ * specified number of partitions.
+ *
+ * @since 1.4.0
+ * @group dataframe
+ */
+ @Experimental
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
+ createDataFrame(
+ sparkContext.range(start, end, step, numPartitions).map(Row(_)),
+ StructType(StructField("id", LongType, nullable = false) :: Nil))
+ }
+
+ /**
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 054b23dba8..f05d059d44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -532,4 +532,44 @@ class DataFrameSuite extends QueryTest {
val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
assert(!p.child.isInstanceOf[Project])
}
+
+ test("SPARK-7150 range api") {
+ // numSlices is greater than the length of the range
+ val res1 = TestSQLContext.range(0, 10, 1, 15).select("id")
+ assert(res1.count == 10)
+ assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+ val res2 = TestSQLContext.range(3, 15, 3, 2).select("id")
+ assert(res2.count == 4)
+ assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+ val res3 = TestSQLContext.range(1, -2).select("id")
+ assert(res3.count == 0)
+
+ // start is positive, end is negative, step is negative
+ val res4 = TestSQLContext.range(1, -2, -2, 6).select("id")
+ assert(res4.count == 2)
+ assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+ // start, end, step are negative
+ val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id")
+ assert(res5.count == 3)
+ assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+ // start, end are negative, step is positive
+ val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id")
+ assert(res6.count == 2)
+ assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+ val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id")
+ assert(res7.count == 0)
+
+ val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ assert(res8.count == 3)
+ assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+ val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ assert(res9.count == 2)
+ assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+ }
}
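The Long.MinValue/Long.MaxValue cases above probe overflow: for example, range(Long.MinValue, Long.MaxValue, Long.MaxValue) yields the three ids Long.MinValue, -1, and Long.MaxValue - 1, which sum to -3. A rough sketch of the element-count arithmetic those assertions encode, done in BigInt to sidestep Long overflow (illustrative only, not the implementation used by SparkContext.range):

    // Expected number of elements in range(start, end, step):
    // the exact rational ceiling of (end - start) / step, clamped at 0.
    def expectedCount(start: Long, end: Long, step: Long): BigInt = {
      require(step != 0, "step cannot be 0")
      val span = BigInt(end) - BigInt(start)
      val s = BigInt(step)
      val q = span / s   // BigInt division truncates toward zero
      val r = span % s
      // same signs: truncation acted as a floor, so bump up on a nonzero remainder;
      // opposite signs: truncation toward zero is already the ceiling
      val len = if (r != 0 && span.signum == s.signum) q + 1 else q
      len.max(BigInt(0)) // an empty range, never a negative count
    }

    expectedCount(Long.MinValue, Long.MaxValue, Long.MaxValue) // 3, as in res8
    expectedCount(Long.MaxValue, Long.MinValue, Long.MinValue) // 2, as in res9
    expectedCount(-10, -9, -20)                                // 0, as in res7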