-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                     2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala               6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                  71
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                           2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                                4
8 files changed, 64 insertions, 35 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 0c851c2ee2..8de87594c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -181,7 +181,7 @@ class SqlParser extends StandardTokenParsers {
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
- val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder)
+ val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
withLimit
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 9d16189dee..b39c2b32cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -130,7 +130,7 @@ case class Aggregate(
def references = child.references
}
-case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
def references = limit.references
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 69bbbdc894..f4bf00f4cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -145,7 +145,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val sparkContext = self.sparkContext
val strategies: Seq[Strategy] =
- TopK ::
+ TakeOrdered ::
PartialAggregation ::
HashJoin ::
ParquetOperations ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 915f551fb2..d8e1b970c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -32,7 +32,13 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
kryo.setRegistrationRequired(false)
kryo.register(classOf[MutablePair[_, _]])
kryo.register(classOf[Array[Any]])
+ // This is kinda hacky...
kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
+ kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
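The block of Map registrations above exists because Scala's immutable Map does not have a single runtime class: sizes one through four get specialized classes (Map$Map1 through Map$Map4) and larger maps use a HashMap implementation, while Kryo registration is keyed on the exact concrete class. A minimal sketch, purely illustrative and not part of the patch, showing the different runtime classes:

    object SmallMapClasses {
      def main(args: Array[String]): Unit = {
        println(Map("a" -> 1).getClass.getName)                        // scala.collection.immutable.Map$Map1
        println(Map("a" -> 1, "b" -> 2).getClass.getName)              // scala.collection.immutable.Map$Map2
        println(Map("a" -> 1, "b" -> 2, "c" -> 3).getClass.getName)    // scala.collection.immutable.Map$Map3
        println((1 to 5).map(i => i -> i).toMap.getClass.getName)      // a HashMap class for larger maps
      }
    }

Registering the generic Map[_, _] and scala.collection.Map[_, _] traits as well is a belt-and-braces measure, which matches the "kinda hacky" comment in the patch.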
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e35ac0b6ca..b3e51fdf75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case other => other
}
- object TopK extends Strategy {
+ object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+ case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
+ execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
case _ => Nil
}
}
@@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
sparkContext.parallelize(data.map(r =>
new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
execution.ExistingRdd(output, dataAsRdd) :: Nil
- case logical.StopAfter(IntegerLiteral(limit), child) =>
- execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil
+ case logical.Limit(IntegerLiteral(limit), child) =>
+ execution.Limit(limit, planLater(child))(sparkContext) :: Nil
case Unions(unionChildren) =>
execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
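To see which queries each case handles, an illustrative sketch (assuming an existing SQLContext named sqlContext with a registered table people; neither is part of the patch):

    // ORDER BY followed by LIMIT is matched by the TakeOrdered strategy above
    // and becomes a single execution.TakeOrdered operator:
    sqlContext.sql("SELECT name FROM people ORDER BY age LIMIT 10")

    // A bare LIMIT falls through to the second hunk's case and becomes execution.Limit:
    sqlContext.sql("SELECT name FROM people LIMIT 10")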
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 65cb8f8bec..524e5022ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -19,27 +19,28 @@ package org.apache.spark.sql.execution
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
- def output = projectList.map(_.toAttribute)
+ override def output = projectList.map(_.toAttribute)
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
@transient val reusableProjection = new MutableProjection(projectList)
iter.map(reusableProjection)
}
}
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
- def output = child.output
+ override def output = child.output
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
iter.filter(condition.apply(_).asInstanceOf[Boolean])
}
}
@@ -47,37 +48,59 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
extends UnaryNode {
- def output = child.output
+ override def output = child.output
// TODO: How to pick seed?
- def execute() = child.execute().sample(withReplacement, fraction, seed)
+ override def execute() = child.execute().sample(withReplacement, fraction, seed)
}
case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
- def output = children.head.output
- def execute() = sc.union(children.map(_.execute()))
+ override def output = children.head.output
+ override def execute() = sc.union(children.map(_.execute()))
override def otherCopyArgs = sc :: Nil
}
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+ // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+ // partition local limit -> exchange into one partition -> partition local limit again
+
override def otherCopyArgs = sc :: Nil
- def output = child.output
+ override def output = child.output
override def executeCollect() = child.execute().map(_.copy()).take(limit)
- // TODO: Terminal split should be implemented differently from non-terminal split.
- // TODO: Pick num splits based on |limit|.
- def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = {
+ val rdd = child.execute().mapPartitions { iter =>
+ val mutablePair = new MutablePair[Boolean, Row]()
+ iter.take(limit).map(row => mutablePair.update(false, row))
+ }
+ val part = new HashPartitioner(1)
+ val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+ shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.mapPartitions(_.take(limit).map(_._2))
+ }
}
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+ (@transient sc: SparkContext) extends UnaryNode {
override def otherCopyArgs = sc :: Nil
- def output = child.output
+ override def output = child.output
@transient
lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = sc.makeRDD(executeCollect(), 1)
}
@@ -101,7 +124,7 @@ case class Sort(
@transient
lazy val ordering = new RowOrdering(sortOrder)
- def execute() = attachTree(this, "sort") {
+ override def execute() = attachTree(this, "sort") {
// TODO: Optimize sorting operation?
child.execute()
.mapPartitions(
@@ -109,7 +132,7 @@ case class Sort(
preservesPartitioning = true)
}
- def output = child.output
+ override def output = child.output
}
object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
}
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- def execute() = rdd
+ override def execute() = rdd
}
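To make the new Limit.execute() concrete, here is a minimal standalone sketch using the plain RDD API (the object and value names are illustrative, not part of the patch). It mirrors the two-phase approach described in the new scaladoc: take the limit within each partition, shuffle the survivors into a single partition, then take the limit again. The patch itself builds a ShuffledRDD keyed on a dummy Boolean with a HashPartitioner(1) and reuses a MutablePair to cut per-row allocation; the sketch uses repartition(1) for brevity.

    import org.apache.spark.{SparkConf, SparkContext}

    object TwoPhaseLimitSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("limit-sketch"))
        val rows  = sc.parallelize(1 to 1000, 8)  // stand-in for child.execute()
        val limit = 10

        val limited = rows
          .mapPartitions(_.take(limit))   // phase 1: at most `limit` rows survive per partition
          .repartition(1)                 // shuffle the (at most 8 * limit) survivors to one partition
          .mapPartitions(_.take(limit))   // phase 2: global limit within the single partition

        println(limited.count())          // 10
        sc.stop()
      }
    }

TakeOrdered, by contrast, still collects on the driver and materializes the result with makeRDD; at the RDD level the analogous primitive is rdd.takeOrdered(n)(ordering), which is where the new name comes from (Spark's top returns the largest elements under an ordering, takeOrdered the smallest).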
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 197b557cba..46febbfad0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -188,7 +188,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val hiveContext = self
override val strategies: Seq[Strategy] = Seq(
- TopK,
+ TakeOrdered,
ParquetOperations,
HiveTableScans,
DataSinks,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 490a592a58..b2b03bc790 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -529,7 +529,7 @@ object HiveQl {
val withLimit =
limitClause.map(l => nodeToExpr(l.getChildren.head))
- .map(StopAfter(_, withSort))
+ .map(Limit(_, withSort))
.getOrElse(withSort)
// TOK_INSERT_INTO means to add files to the table.
@@ -602,7 +602,7 @@ object HiveQl {
case Token("TOK_TABLESPLITSAMPLE",
Token("TOK_ROWCOUNT", Nil) ::
Token(count, Nil) :: Nil) =>
- StopAfter(Literal(count.toInt), relation)
+ Limit(Literal(count.toInt), relation)
case Token("TOK_TABLESPLITSAMPLE",
Token("TOK_PERCENT", Nil) ::
Token(fraction, Nil) :: Nil) =>
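For the second HiveQl hunk: TOK_TABLESPLITSAMPLE with TOK_ROWCOUNT is Hive's row-count table sampling, so after the rename both an ordinary LIMIT clause and a row-count TABLESAMPLE map onto the same logical Limit node. An illustrative pair of queries (assuming a HiveContext named hiveContext and a table src; the hql call and table name are assumptions, not part of the patch):

    hiveContext.hql("SELECT key FROM src LIMIT 5")               // LIMIT clause      -> logical.Limit
    hiveContext.hql("SELECT key FROM src TABLESAMPLE (5 ROWS)")  // row-count sample  -> Limit(Literal(5), relation)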