author     gatorsmile <gatorsmile@gmail.com>          2015-12-21 13:46:58 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-12-21 13:46:58 -0800
commit     4883a5087d481d4de5d3beabbd709853de01399a (patch)
tree       63cd07b7cf6447ab0b2ca77503789ef7a7b06d2f /sql
parent     7634fe9511e1a8fb94979624b1b617b495b48ad3 (diff)
[SPARK-12374][SPARK-12150][SQL] Adding logical/physical operators for Range
Based on the suggestions from marmbrus, added logical/physical operators for Range to improve performance. Also added another API to resolve JIRA SPARK-12150. Could you take a look at my implementation, marmbrus? If it is not good, I can rework it. :) Thank you very much!

Author: gatorsmile <gatorsmile@gmail.com>

Closes #10335 from gatorsmile/rangeOperators.
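A minimal usage sketch of the new three-argument range overload (assuming an existing SQLContext named sqlContext; the expected counts mirror the DataFrameSuite case added in this patch):

import org.apache.spark.sql.functions.sum

// Planned through the new logical/physical Range operators instead of
// sparkContext.range(...).map(Row(_)).
val df = sqlContext.range(3, 15, 3)       // one LongType column "id": 3, 6, 9, 12
df.count()                                // 4
df.agg(sum("id")).head().getLong(0)       // 30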
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 62
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala | 1
6 files changed, 118 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ec42b763f1..64ef4d7996 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -210,6 +210,38 @@ case class Sort(
override def output: Seq[Attribute] = child.output
}
+/** Factory for constructing new `Range` nodes. */
+object Range {
+ def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+ val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
+ new Range(start, end, step, numSlices, output)
+ }
+}
+
+case class Range(
+ start: Long,
+ end: Long,
+ step: Long,
+ numSlices: Int,
+ output: Seq[Attribute]) extends LeafNode {
+ require(step != 0, "step cannot be 0")
+ val numElements: BigInt = {
+ val safeStart = BigInt(start)
+ val safeEnd = BigInt(end)
+ if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) {
+ (safeEnd - safeStart) / step
+ } else {
+ // the remainder has the same sign as the range direction, so there is one more element
+ (safeEnd - safeStart) / step + 1
+ }
+ }
+
+ override def statistics: Statistics = {
+ val sizeInBytes = LongType.defaultSize * numElements
+ Statistics(sizeInBytes = sizeInBytes)
+ }
+}
+
case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
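As a standalone illustration of the numElements arithmetic added above, the sketch below restates it outside of Catalyst (countElements is a hypothetical helper, not part of this patch):

// Restates the logical Range node's element-count formula for a few sample inputs.
def countElements(start: Long, end: Long, step: Long): BigInt = {
  require(step != 0, "step cannot be 0")
  val delta = BigInt(end) - BigInt(start)
  if (delta % step == 0 || (delta > 0) != (step > 0)) {
    delta / step        // divides evenly, or the step points away from end
  } else {
    delta / step + 1    // a partial final step still contributes one element
  }
}

countElements(3, 15, 3)    // 4 -> 3, 6, 9, 12
countElements(0, 10, 3)    // 4 -> 0, 3, 6, 9
countElements(10, 0, -3)   // 4 -> 10, 7, 4, 1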
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index db286ea870..eadf5cba6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
@@ -785,9 +785,20 @@ class SQLContext private[sql](
*/
@Experimental
def range(start: Long, end: Long): DataFrame = {
- createDataFrame(
- sparkContext.range(start, end).map(Row(_)),
- StructType(StructField("id", LongType, nullable = false) :: Nil))
+ range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with a step value.
+ *
+ * @since 2.0.0
+ * @group dataframe
+ */
+ @Experimental
+ def range(start: Long, end: Long, step: Long): DataFrame = {
+ range(start, end, step, numPartitions = sparkContext.defaultParallelism)
}
/**
@@ -801,9 +812,7 @@ class SQLContext private[sql](
*/
@Experimental
def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
- createDataFrame(
- sparkContext.range(start, end, step, numPartitions).map(Row(_)),
- StructType(StructField("id", LongType, nullable = false) :: Nil))
+ DataFrame(this, Range(start, end, step, numPartitions))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 688555cf13..183d9b6502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -358,6 +358,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil
+ case r @ logical.Range(start, end, step, numSlices, output) =>
+ execution.Range(start, step, numSlices, r.numElements, output) :: Nil
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
execution.Exchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
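As an informal sanity check that this strategy fires (illustrative only; the exact explain() text is not reproduced here):

// After this change, the physical plan for a range query should contain the new
// execution.Range operator rather than a PhysicalRDD built from sparkContext.range(...).
sqlContext.range(0, 1000, 2, 4).explain()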
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b3e4688557..21325beb1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.LongType
import org.apache.spark.util.MutablePair
import org.apache.spark.util.random.PoissonSampler
import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -126,6 +127,67 @@ case class Sample(
}
}
+case class Range(
+ start: Long,
+ step: Long,
+ numSlices: Int,
+ numElements: BigInt,
+ output: Seq[Attribute])
+ extends LeafNode {
+
+ override def outputsUnsafeRows: Boolean = true
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ sqlContext
+ .sparkContext
+ .parallelize(0 until numSlices, numSlices)
+ .mapPartitionsWithIndex((i, _) => {
+ val partitionStart = (i * numElements) / numSlices * step + start
+ val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
+ def getSafeMargin(bi: BigInt): Long =
+ if (bi.isValidLong) {
+ bi.toLong
+ } else if (bi > 0) {
+ Long.MaxValue
+ } else {
+ Long.MinValue
+ }
+ val safePartitionStart = getSafeMargin(partitionStart)
+ val safePartitionEnd = getSafeMargin(partitionEnd)
+ val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
+ val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+
+ new Iterator[InternalRow] {
+ private[this] var number: Long = safePartitionStart
+ private[this] var overflow: Boolean = false
+
+ override def hasNext =
+ if (!overflow) {
+ if (step > 0) {
+ number < safePartitionEnd
+ } else {
+ number > safePartitionEnd
+ }
+ } else false
+
+ override def next() = {
+ val ret = number
+ number += step
+ if (number < ret ^ step < 0) {
+ // Long.MaxValue + Long.MaxValue wraps to a value below Long.MaxValue, and
+ // Long.MinValue + Long.MinValue wraps to a value above Long.MinValue, so if adding
+ // the step moved the value backwards, the addition overflowed.
+ overflow = true
+ }
+
+ unsafeRow.setLong(0, ret)
+ unsafeRow
+ }
+ }
+ })
+ }
+}
+
/**
* Union two plans, without a distinct. This is UNION ALL in SQL.
*/
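The per-partition bounds computed in doExecute() above can be traced in isolation; sliceBounds below is a hypothetical helper mirroring that arithmetic, not part of the patch:

// Mirrors how doExecute() splits the range into [partitionStart, partitionEnd) slices,
// clamping out-of-Long-range BigInt bounds to Long.MaxValue / Long.MinValue.
def sliceBounds(start: Long, step: Long, numElements: BigInt, numSlices: Int): Seq[(Long, Long)] = {
  def clamp(bi: BigInt): Long =
    if (bi.isValidLong) bi.toLong
    else if (bi > 0) Long.MaxValue
    else Long.MinValue
  (0 until numSlices).map { i =>
    val lo = (i * numElements) / numSlices * step + start
    val hi = ((i + 1) * numElements) / numSlices * step + start
    (clamp(lo), clamp(hi))
  }
}

sliceBounds(start = 0, step = 3, numElements = 10, numSlices = 3)
// Vector((0,9), (9,18), (18,30)): rows 0,3,6 | 9,12,15 | 18,21,24,27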
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1a0f1b61cb..ad478b0511 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -769,6 +769,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val res11 = sqlContext.range(-1).select("id")
assert(res11.count == 0)
+
+ // using the default slice number
+ val res12 = sqlContext.range(3, 15, 3).select("id")
+ assert(res12.count == 4)
+ assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
}
test("SPARK-8621: support empty string column name") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 180050bdac..101cf50d80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -260,6 +260,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
.set("spark.driver.allowMultipleContexts", "true")
.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+ .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
.set(
SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key,
targetNumPostShufflePartitions.toString)