path: root/sql/core/src
author     Michael Armbrust <michael@databricks.com>   2014-07-29 20:58:05 -0700
committer  Michael Armbrust <michael@databricks.com>   2014-07-29 20:58:05 -0700
commit     84467468d466dadf4708a7d6a808471305149713 (patch)
tree       ab229d72541e2e0162ded0045694d5d1a09e2a08 /sql/core/src
parent     22649b6cde8e18f043f122bce46f446174d00f6c (diff)
[SPARK-2054][SQL] Code Generation for Expression Evaluation
Adds a new method for evaluating expressions using code that is generated through Scala reflection. This functionality is configured by the SQLConf option `spark.sql.codegen` and is currently turned off by default.

Evaluation can be done in several specialized ways:

 - *Projection* - Given an input row, produce a new row from a set of expressions that define each column in terms of the input row. This can either produce a new Row object or perform the projection in-place on an existing Row (MutableProjection).
 - *Ordering* - Compares two rows based on a list of `SortOrder` expressions.
 - *Condition* - Returns `true` or `false` given an input row.

For each of the above operations there is both a Generated and an Interpreted version. When generation for a given expression type is undefined, the code generator falls back on calling the `eval` function of the expression class. Even without custom code there is still a potential speed up, as loops are unrolled and code can still be inlined by the JIT.

This PR also contains a new type of Aggregation operator, `GeneratedAggregate`, that performs aggregation by using generated `Projection` code. Currently the required expression rewriting only works for simple aggregations like `SUM` and `COUNT`. This functionality will be extended in a future PR.

This PR also performs several clean ups that simplified the implementation:

 - The notion of `Binding` all expressions in a tree automatically before query execution has been removed. Instead it is the responsibility of an operator to provide the input schema when creating one of the specialized evaluators defined above. In cases where the standard eval method is going to be called, binding can still be done manually using `BindReferences`. There are a few reasons for this change: First, there were many operators where it just didn't work before, for example operators with more than one child and operators like aggregation that do significant rewriting of the expression. Second, the semantics of equality with `BoundReferences` are broken; specifically, we have had a few bugs where partitioning breaks because of the binding.
 - A copy of the current `SQLContext` is automatically propagated to all `SparkPlan` nodes by the query planner. Previously this was done ad hoc for the nodes that needed it, which required a lot of boilerplate as one had to always remember to make the field `transient` and to modify `otherCopyArgs`.

Author: Michael Armbrust <michael@databricks.com>

Closes #993 from marmbrus/newCodeGen and squashes the following commits:

96ef82c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen
f34122d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen
67b1c48 [Michael Armbrust] Use conf variable in SQLConf object
4bdc42c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
41a40c9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
de22aac [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
fed3634 [Michael Armbrust] Inspectors are not serializable.
ef8d42b [Michael Armbrust] comments
533fdfd [Michael Armbrust] More logging of expression rewriting for GeneratedAggregate.
3cd773e [Michael Armbrust] Allow codegen for Generate.
64b2ee1 [Michael Armbrust] Implement copy
3587460 [Michael Armbrust] Drop unused string builder function.
9cce346 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
1a61293 [Michael Armbrust] Address review comments.
0672e8a [Michael Armbrust] Address comments.
1ec2d6e [Michael Armbrust] Address comments
033abc6 [Michael Armbrust] off by default
4771fab [Michael Armbrust] Docs, more test coverage.
d30fee2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
d2ad5c5 [Michael Armbrust] Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases.
be2cd6b [Michael Armbrust] WIP: Remove old method for reference binding, more work on configuration.
bc88ecd [Michael Armbrust] Style
6cc97ca [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
4220f1e [Michael Armbrust] Better config, docs, etc.
ca6cc6b [Michael Armbrust] WIP
9d67d85 [Michael Armbrust] Fix hive planner
fc522d5 [Michael Armbrust] Hook generated aggregation in to the planner.
e742640 [Michael Armbrust] Remove unneeded changes and code.
675e679 [Michael Armbrust] Upgrade paradise.
0093376 [Michael Armbrust] Comment / indenting cleanup.
d81f998 [Michael Armbrust] include schema for binding.
0e889e8 [Michael Armbrust] Use typeOf instead tq
f623ffd [Michael Armbrust] Quiet logging from test suite.
efad14f [Michael Armbrust] Remove some half finished functions.
92e74a4 [Michael Armbrust] add overrides
a2b5408 [Michael Armbrust] WIP: Code generation with scala reflection.
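For orientation (not part of the patch), the following sketch shows how the new flag might be toggled from user code. It assumes SQLConf's string-valued `set(key, value)` and the 1.0/1.1-era SchemaRDD API; the `Record` case class and `records` table are made up:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    // Hypothetical schema for the sketch.
    case class Record(key: Int, value: Long)

    object CodegenToggleSketch extends App {
      val sc = new SparkContext("local", "codegen-demo")
      val sqlContext = new SQLContext(sc)
      import sqlContext.createSchemaRDD              // implicit RDD[Product] -> SchemaRDD

      // Enable the experimental generated-evaluation path (off by default).
      sqlContext.set("spark.sql.codegen", "true")

      sc.parallelize(1 to 100).map(i => Record(i % 10, i.toLong)).registerAsTable("records")

      // Simple SUM / COUNT aggregates like this one can be planned as GeneratedAggregate.
      sqlContext.sql("SELECT key, SUM(value), COUNT(value) FROM records GROUP BY key").collect()
    }

With the flag left at its default of `false`, the same query runs through the interpreted Projection/Aggregate path described above.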
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                         |  19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                      |  25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala         |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala             |  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala              |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala              |  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala    | 200
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala             |  81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala       | 138
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala        |  44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala         |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala                 |  44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala         |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala  |  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala         |   9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                       |   1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala          |   8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala              |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala       |   5
19 files changed, 474 insertions, 180 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 5d85a0fd4e..2d407077be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -24,8 +24,11 @@ import scala.collection.JavaConverters._
object SQLConf {
val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
- val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
+ val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
+ val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
+ val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
+ val CODEGEN_ENABLED = "spark.sql.codegen"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -57,6 +60,18 @@ trait SQLConf {
private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
/**
+ * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
+ * that evaluates expressions found in queries. In general this custom code runs much faster
+ * than interpreted evaluation, but there are significant start-up costs due to compilation.
+ * As a result codegen is only beneficial when queries run for a long time, or when the same
+ * expressions are used multiple times.
+ *
+ * Defaults to false as this feature is currently experimental.
+ */
+ private[spark] def codegenEnabled: Boolean =
+ if (get(CODEGEN_ENABLED, "false") == "true") true else false
+
+ /**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
* effectively disables auto conversion.
@@ -111,5 +126,5 @@ trait SQLConf {
private[spark] def clear() {
settings.clear()
}
-
}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c2bdef7323..e4b6810180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def parquetFile(path: String): SchemaRDD =
- new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration)))
+ new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
/**
* Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]].
@@ -160,7 +160,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
conf: Configuration = new Configuration()): SchemaRDD = {
new SchemaRDD(
this,
- ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf))
+ ParquetRelation.createEmpty(
+ path, ScalaReflection.attributesFor[A], allowExisting, conf, this))
}
/**
@@ -228,12 +229,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
val sqlContext: SQLContext = self
+ def codegenEnabled = self.codegenEnabled
+
def numPartitions = self.numShufflePartitions
val strategies: Seq[Strategy] =
CommandStrategy(self) ::
TakeOrdered ::
- PartialAggregation ::
+ HashAggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
@@ -291,27 +294,30 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1)
/**
- * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and
- * inserting shuffle operations as needed.
+ * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
*/
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
- Batch("Add exchange", Once, AddExchange(self)) ::
- Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil
+ Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
/**
+ * :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
* access to the intermediate phases of query execution for developers.
*/
+ @DeveloperApi
protected abstract class QueryExecution {
def logical: LogicalPlan
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
- lazy val sparkPlan = planner(optimizedPlan).next()
+ lazy val sparkPlan = {
+ SparkPlan.currentContext.set(self)
+ planner(optimizedPlan).next()
+ }
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
@@ -331,6 +337,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
|${stringOrError(optimizedPlan)}
|== Physical Plan ==
|${stringOrError(executedPlan)}
+ |Code Generation: ${executedPlan.codegenEnabled}
+ |== RDD ==
+ |${stringOrError(toRdd.toDebugString)}
""".stripMargin.trim
}
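The `SparkPlan.currentContext` hand-off introduced here is easy to miss in diff form; below is a minimal, dependency-free sketch of the same ThreadLocal pattern, with made-up names standing in for SQLContext and SparkPlan:

    object ThreadLocalContextSketch extends App {
      object PlanContext {
        // Stands in for SparkPlan.currentContext.
        val current = new ThreadLocal[String]()
      }

      abstract class Operator {
        // Captured once, at construction time on the driver; @transient keeps it out of
        // the serialized plan, mirroring SparkPlan.sqlContext in this patch.
        @transient protected[this] val context: String = PlanContext.current.get()
        def contextName: String = context
      }

      case class Scan(table: String) extends Operator

      // What QueryExecution.sparkPlan does with `self` before invoking the planner:
      PlanContext.current.set("my-sql-context")
      println(Scan("records").contextName)   // prints: my-sql-context
    }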
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 806097c917..85726bae54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -72,7 +72,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
conf: Configuration = new Configuration()): JavaSchemaRDD = {
new JavaSchemaRDD(
sqlContext,
- ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf))
+ ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext))
}
/**
@@ -101,7 +101,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
def parquetFile(path: String): JavaSchemaRDD =
new JavaSchemaRDD(
sqlContext,
- ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration)))
+ ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext))
/**
* Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index c1ced8bfa4..463a1d32d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -42,8 +42,8 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: SparkPlan)(@transient sqlContext: SQLContext)
- extends UnaryNode with NoBind {
+ child: SparkPlan)
+ extends UnaryNode {
override def requiredChildDistribution =
if (partial) {
@@ -56,8 +56,6 @@ case class Aggregate(
}
}
- override def otherCopyArgs = sqlContext :: Nil
-
// HACK: Generators don't correctly preserve their output through serializations so we grab
// our child's output attributes statically here.
private[this] val childOutput = child.output
@@ -138,7 +136,7 @@ case class Aggregate(
i += 1
}
}
- val resultProjection = new Projection(resultExpressions, computedSchema)
+ val resultProjection = new InterpretedProjection(resultExpressions, computedSchema)
val aggregateResults = new GenericMutableRow(computedAggregates.length)
var i = 0
@@ -152,7 +150,7 @@ case class Aggregate(
} else {
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
- val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+ val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput)
var currentRow: Row = null
while (iter.hasNext) {
@@ -175,7 +173,8 @@ case class Aggregate(
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
- new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+ new InterpretedMutableProjection(
+ resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val joinedRow = new JoinedRow
override final def hasNext: Boolean = hashTableIter.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 00010ef6e7..392a7f3be3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind {
+case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
override def outputPartitioning = newPartitioning
@@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
val rdd = child.execute().mapPartitions { iter =>
- val hashExpressions = new MutableProjection(expressions, child.output)
+ @transient val hashExpressions =
+ newMutableProjection(expressions, child.output)()
+
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 47b3d00262..c386fd121c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -47,23 +47,26 @@ case class Generate(
}
}
- override def output =
+ // This must be a val since the generator output expr ids are not preserved by serialization.
+ override val output =
if (join) child.output ++ generatorOutput else generatorOutput
+ val boundGenerator = BindReferences.bindReference(generator, child.output)
+
override def execute() = {
if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
// Used to produce rows with no matches when outer = true.
val outerProjection =
- new Projection(child.output ++ nullValues, child.output)
+ newProjection(child.output ++ nullValues, child.output)
val joinProjection =
- new Projection(child.output ++ generator.output, child.output ++ generator.output)
+ newProjection(child.output ++ generator.output, child.output ++ generator.output)
val joinedRow = new JoinedRow
iter.flatMap {row =>
- val outputRows = generator.eval(row)
+ val outputRows = boundGenerator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
@@ -72,7 +75,7 @@ case class Generate(
}
}
} else {
- child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
+ child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
new file mode 100644
index 0000000000..4a26934c49
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types._
+
+case class AggregateEvaluation(
+ schema: Seq[Attribute],
+ initialValues: Seq[Expression],
+ update: Seq[Expression],
+ result: Expression)
+
+/**
+ * :: DeveloperApi ::
+ * Alternate version of aggregation that leverages projection and thus code generation.
+ * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
+ * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
+ *
+ * @param partial if true then aggregation is done partially on local data without shuffling to
+ * ensure all values where `groupingExpressions` are equal are present.
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param child the input data source.
+ */
+@DeveloperApi
+case class GeneratedAggregate(
+ partial: Boolean,
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def requiredChildDistribution =
+ if (partial) {
+ UnspecifiedDistribution :: Nil
+ } else {
+ if (groupingExpressions == Nil) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingExpressions) :: Nil
+ }
+ }
+
+ override def output = aggregateExpressions.map(_.toAttribute)
+
+ override def execute() = {
+ val aggregatesToCompute = aggregateExpressions.flatMap { a =>
+ a.collect { case agg: AggregateExpression => agg}
+ }
+
+ val computeFunctions = aggregatesToCompute.map {
+ case c @ Count(expr) =>
+ val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+ val initialValue = Literal(0L)
+ val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val result = currentCount
+
+ AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+ case Sum(expr) =>
+ val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val initialValue = Cast(Literal(0L), expr.dataType)
+
+ // Coalesce avoids double calculation...
+ // but really, common sub expression elimination would be better....
+ val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
+ val result = currentSum
+
+ AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+ case a @ Average(expr) =>
+ val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+ val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val initialCount = Literal(0L)
+ val initialSum = Cast(Literal(0L), expr.dataType)
+ val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
+
+ val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType))
+
+ AggregateEvaluation(
+ currentCount :: currentSum :: Nil,
+ initialCount :: initialSum :: Nil,
+ updateCount :: updateSum :: Nil,
+ result
+ )
+ }
+
+ val computationSchema = computeFunctions.flatMap(_.schema)
+
+ val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
+ case (agg, func) => agg.id -> func.result
+ }.toMap
+
+ val namedGroups = groupingExpressions.zipWithIndex.map {
+ case (ne: NamedExpression, _) => (ne, ne)
+ case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+ }
+
+ val groupMap: Map[Expression, Attribute] =
+ namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
+
+ // The set of expressions that produce the final output given the aggregation buffer and the
+ // grouping expressions.
+ val resultExpressions = aggregateExpressions.map(_.transform {
+ case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
+ case e: Expression if groupMap.contains(e) => groupMap(e)
+ })
+
+ child.execute().mapPartitions { iter =>
+ // Builds a new custom class for holding the results of aggregation for a group.
+ val initialValues = computeFunctions.flatMap(_.initialValues)
+ val newAggregationBuffer = newProjection(initialValues, child.output)
+ log.info(s"Initial values: ${initialValues.mkString(",")}")
+
+ // A projection that computes the group given an input tuple.
+ val groupProjection = newProjection(groupingExpressions, child.output)
+ log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
+
+ // A projection that is used to update the aggregate values for a group given a new tuple.
+ // This projection should be targeted at the current values for the group and then applied
+ // to a joined row of the current values with the new input row.
+ val updateExpressions = computeFunctions.flatMap(_.update)
+ val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
+ val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
+ log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
+
+ // A projection that produces the final result, given a computation.
+ val resultProjectionBuilder =
+ newMutableProjection(
+ resultExpressions,
+ (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+ log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
+
+ val joinedRow = new JoinedRow
+
+ if (groupingExpressions.isEmpty) {
+ // TODO: Codegening anything other than the updateProjection is probably overkill.
+ val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+ var currentRow: Row = null
+ updateProjection.target(buffer)
+
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ updateProjection(joinedRow(buffer, currentRow))
+ }
+
+ val resultProjection = resultProjectionBuilder()
+ Iterator(resultProjection(buffer))
+ } else {
+ val buffers = new java.util.HashMap[Row, MutableRow]()
+
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ val currentGroup = groupProjection(currentRow)
+ var currentBuffer = buffers.get(currentGroup)
+ if (currentBuffer == null) {
+ currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+ buffers.put(currentGroup, currentBuffer)
+ }
+ // Target the projection at the current aggregation buffer and then project the updated
+ // values.
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
+ }
+
+ new Iterator[Row] {
+ private[this] val resultIterator = buffers.entrySet.iterator()
+ private[this] val resultProjection = resultProjectionBuilder()
+
+ def hasNext = resultIterator.hasNext
+
+ def next() = {
+ val currentGroup = resultIterator.next()
+ resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
+ }
+ }
+ }
+ }
+ }
+}
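A dependency-free sketch of the update semantics the `Sum` case above encodes: each group keeps a running sum, and a null input leaves the buffer unchanged, which is exactly what `Coalesce(Add(expr, currentSum) :: currentSum :: Nil)` evaluates to (the group/value data here is made up):

    object SumBufferSketch extends App {
      val input: Seq[(String, Option[Long])] =
        Seq("a" -> Some(1L), "a" -> None, "b" -> Some(5L), "a" -> Some(2L))

      val buffers = scala.collection.mutable.HashMap.empty[String, Long]
      for ((group, value) <- input) {
        val current = buffers.getOrElse(group, 0L)                   // initialValue
        buffers(group) = value.map(_ + current).getOrElse(current)   // Coalesce(expr + sum, sum)
      }
      println(buffers)   // Map(a -> 3, b -> 5)
    }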
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 77c874d031..21cbbc9772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -18,22 +18,55 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Logging, Row, SQLContext}
+
+
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
+
+object SparkPlan {
+ protected[sql] val currentContext = new ThreadLocal[SQLContext]()
+}
+
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
+abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
self: Product =>
+ /**
+ * A handle to the SQL Context that was used to create this plan. Since many operators need
+ * access to the sqlContext for RDD operations or configuration this field is automatically
+ * populated by the query planning infrastructure.
+ */
+ @transient
+ protected val sqlContext = SparkPlan.currentContext.get()
+
+ protected def sparkContext = sqlContext.sparkContext
+
+ // sqlContext will be null when we are being deserialized on the slaves. In this instance
+ // the value of codegenEnabled will be set by the deserializer after the constructor has run.
+ val codegenEnabled: Boolean = if (sqlContext != null) {
+ sqlContext.codegenEnabled
+ } else {
+ false
+ }
+
+ /** Overridden makeCopy also propagates sqlContext to the copied plan. */
+ override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+ SparkPlan.currentContext.set(sqlContext)
+ super.makeCopy(newArgs)
+ }
+
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
*/
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
- protected def buildRow(values: Seq[Any]): Row =
- new GenericRow(values.toArray)
+ protected def newProjection(
+ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+ log.debug(
+ s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ GenerateProjection(expressions, inputSchema)
+ } else {
+ new InterpretedProjection(expressions, inputSchema)
+ }
+ }
+
+ protected def newMutableProjection(
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): () => MutableProjection = {
+ log.debug(
+ s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if(codegenEnabled) {
+ GenerateMutableProjection(expressions, inputSchema)
+ } else {
+ () => new InterpretedMutableProjection(expressions, inputSchema)
+ }
+ }
+
+
+ protected def newPredicate(
+ expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
+ if (codegenEnabled) {
+ GeneratePredicate(expression, inputSchema)
+ } else {
+ InterpretedPredicate(expression, inputSchema)
+ }
+ }
+
+ protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
+ if (codegenEnabled) {
+ GenerateOrdering(order, inputSchema)
+ } else {
+ new RowOrdering(order, inputSchema)
+ }
+ }
}
/**
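A sketch (not part of the patch) of how a custom physical operator would use the new factory methods. It assumes 1.1-era Spark SQL internals on the classpath and, since `UnaryNode` is package-private, that the file lives in `org.apache.spark.sql.execution`; `SimpleProject` is a made-up operator modeled on the `Project` operator further down in this diff:

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.catalyst.expressions.NamedExpression

    // Made-up operator showing the intended use of the factory methods added above: the
    // builder returned by newMutableProjection is invoked inside mapPartitions, so the
    // (possibly generated) projection is constructed per partition on the executor
    // rather than being serialized from the driver.
    case class SimpleProject(projectList: Seq[NamedExpression], child: SparkPlan)
      extends UnaryNode {

      override def output = projectList.map(_.toAttribute)

      override def execute() = child.execute().mapPartitions { iter =>
        val project = newMutableProjection(projectList, child.output)()
        iter.map(project)
      }
    }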
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 404d48ae05..5f1fe99f75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution
-import scala.util.Try
-
import org.apache.spark.sql.{SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
@@ -41,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
- planLater(left), planLater(right), condition)(sqlContext) :: Nil
+ planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -60,6 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* will instead be used to decide the build side in a [[execution.ShuffledHashJoin]].
*/
object HashJoin extends Strategy with PredicateHelper {
+
private[this] def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@@ -68,24 +67,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition: Option[Expression],
side: BuildSide) = {
val broadcastHashJoin = execution.BroadcastHashJoin(
- leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
+ leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if Try(sqlContext.autoBroadcastJoinThreshold > 0 &&
- right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) =>
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if Try(sqlContext.autoBroadcastJoinThreshold > 0 &&
- left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) =>
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
- if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) {
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
BuildRight
} else {
BuildLeft
@@ -99,65 +98,65 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object PartialAggregation extends Strategy {
+ object HashAggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
- // Collect all aggregate expressions.
- val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
- // Collect all aggregate expressions that can be computed partially.
- val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })
-
- // Only do partial aggregation if supported by all aggregate expressions.
- if (allAggregates.size == partialAggregates.size) {
- // Create a map of expressions to their partial evaluations for all aggregate expressions.
- val partialEvaluations: Map[Long, SplitEvaluation] =
- partialAggregates.map(a => (a.id, a.asPartial)).toMap
-
- // We need to pass all grouping expressions though so the grouping can happen a second
- // time. However some of them might be unnamed so we alias them allowing them to be
- // referenced in the second aggregation.
- val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
- case n: NamedExpression => (n, n)
- case other => (other, Alias(other, "PartialGroup")())
- }.toMap
+ // Aggregations that can be performed in two phases, before and after the shuffle.
- // Replace aggregations with a new expression that computes the result from the already
- // computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
- case e: Expression if partialEvaluations.contains(e.id) =>
- partialEvaluations(e.id).finalEvaluation
- case e: Expression if namedGroupingExpressions.contains(e) =>
- namedGroupingExpressions(e).toAttribute
- }).asInstanceOf[Seq[NamedExpression]]
-
- val partialComputation =
- (namedGroupingExpressions.values ++
- partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
-
- // Construct two phased aggregation.
- execution.Aggregate(
+ // Cases where all aggregates can be codegened.
+ case PartialAggregation(
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child)
+ if canBeCodeGened(
+ allAggregates(partialComputation) ++
+ allAggregates(rewrittenAggregateExpressions)) &&
+ codegenEnabled =>
+ execution.GeneratedAggregate(
partial = false,
- namedGroupingExpressions.values.map(_.toAttribute).toSeq,
+ namedGroupingAttributes,
rewrittenAggregateExpressions,
- execution.Aggregate(
+ execution.GeneratedAggregate(
partial = true,
groupingExpressions,
partialComputation,
- planLater(child))(sqlContext))(sqlContext) :: Nil
- } else {
- Nil
- }
+ planLater(child))) :: Nil
+
+ // Cases where some aggregate can not be codegened
+ case PartialAggregation(
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child) =>
+ execution.Aggregate(
+ partial = false,
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ execution.Aggregate(
+ partial = true,
+ groupingExpressions,
+ partialComputation,
+ planLater(child))) :: Nil
+
case _ => Nil
}
+
+ def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
+ case _: Sum | _: Count => false
+ case _ => true
+ }
+
+ def allAggregates(exprs: Seq[Expression]) =
+ exprs.flatMap(_.collect { case a: AggregateExpression => a })
}
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
+ planLater(left), planLater(right), joinType, condition) :: Nil
case _ => Nil
}
}
@@ -176,16 +175,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
- def convertToCatalyst(a: Any): Any = a match {
- case s: Seq[Any] => s.map(convertToCatalyst)
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case other => other
- }
-
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
+ execution.TakeOrdered(limit, order, planLater(child)) :: Nil
case _ => Nil
}
}
@@ -195,11 +188,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// TODO: need to support writing to other types of files. Unify the below code paths.
case logical.WriteToFile(path, child) =>
val relation =
- ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
+ ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
- InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
+ InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
- InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
+ InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -228,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
- ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
+ ParquetTableScan(_, relation, filters)) :: Nil
case _ => Nil
}
@@ -266,20 +259,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
- execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
+ execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
- val dataAsRdd =
- sparkContext.parallelize(data.map(r =>
- new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
- execution.ExistingRdd(output, dataAsRdd) :: Nil
+ ExistingRdd(
+ output,
+ ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
- execution.Limit(limit, planLater(child))(sqlContext) :: Nil
+ execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
- execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
- case logical.Except(left,right) =>
- execution.Except(planLater(left),planLater(right)) :: Nil
+ execution.Union(unionChildren.map(planLater)) :: Nil
+ case logical.Except(left, right) =>
+ execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
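To make the `partial = true` / `partial = false` pairing planned by HashAggregation concrete, here is a small dependency-free sketch of two-phase aggregation over already-partitioned data (all names and data are made up; GeneratedAggregate does the same thing with Catalyst expressions and projections):

    object TwoPhaseSumSketch extends App {
      // Two input partitions, as they might look before any shuffle.
      val partitions: Seq[Seq[(String, Long)]] =
        Seq(Seq("a" -> 1L, "b" -> 2L, "a" -> 3L), Seq("a" -> 10L, "b" -> 20L))

      // Phase 1 (partial = true): aggregate locally within each partition, no shuffle.
      val partial: Seq[Map[String, Long]] =
        partitions.map(_.groupBy(_._1).mapValues(_.map(_._2).sum).toMap)

      // Shuffle stand-in, then phase 2 (partial = false): merge partial sums per key.
      val merged: Map[String, Long] =
        partial.flatten.groupBy(_._1).mapValues(_.map(_._2).sum).toMap

      println(merged)   // Map(a -> 14, b -> 22)
    }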
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 966d8f95fc..174eda8f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
override def output = projectList.map(_.toAttribute)
- override def execute() = child.execute().mapPartitions { iter =>
- @transient val reusableProjection = new MutableProjection(projectList)
- iter.map(reusableProjection)
+ @transient lazy val buildProjection = newMutableProjection(projectList, child.output)
+
+ def execute() = child.execute().mapPartitions { iter =>
+ val reusableProjection = buildProjection()
+ iter.map(reusableProjection)
}
}
@@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def output = child.output
- override def execute() = child.execute().mapPartitions { iter =>
- iter.filter(condition.eval(_).asInstanceOf[Boolean])
+ @transient lazy val conditionEvaluator = newPredicate(condition, child.output)
+
+ def execute() = child.execute().mapPartitions { iter =>
+ iter.filter(conditionEvaluator)
}
}
@@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan]) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output = children.head.output
- override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
-
- override def otherCopyArgs = sqlContext :: Nil
+ override def execute() = sparkContext.union(children.map(_.execute()))
}
/**
@@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex
* repartition all the data to a single partition to compute the global limit.
*/
@DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+case class Limit(limit: Int, child: SparkPlan)
extends UnaryNode {
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again
- override def otherCopyArgs = sqlContext :: Nil
-
override def output = child.output
/**
@@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
* Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
*/
@DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sqlContext: SQLContext) extends UnaryNode {
- override def otherCopyArgs = sqlContext :: Nil
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
override def output = child.output
- @transient
- lazy val ordering = new RowOrdering(sortOrder)
+ val ordering = new RowOrdering(sortOrder, child.output)
+ // TODO: Is this copying for no reason?
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
+ override def execute() = sparkContext.makeRDD(executeCollect(), 1)
}
/**
@@ -189,15 +187,13 @@ case class Sort(
override def requiredChildDistribution =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- @transient
- lazy val ordering = new RowOrdering(sortOrder)
override def execute() = attachTree(this, "sort") {
- // TODO: Optimize sorting operation?
child.execute()
- .mapPartitions(
- iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
- preservesPartitioning = true)
+ .mapPartitions( { iterator =>
+ val ordering = newOrdering(sortOrder, child.output)
+ iterator.map(_.copy()).toArray.sorted(ordering).iterator
+ }, preservesPartitioning = true)
}
override def output = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index c6fbd6d2f6..5ef46c32d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -41,13 +41,13 @@ package object debug {
*/
@DeveloperApi
implicit class DebugQuery(query: SchemaRDD) {
- def debug(implicit sc: SparkContext): Unit = {
+ def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[Long]()
val debugPlan = plan transform {
case s: SparkPlan if !visited.contains(s.id) =>
visited += s.id
- DebugNode(sc, s)
+ DebugNode(s)
}
println(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
@@ -57,9 +57,7 @@ package object debug {
}
}
- private[sql] case class DebugNode(
- @transient sparkContext: SparkContext,
- child: SparkPlan) extends UnaryNode {
+ private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
def references = Set.empty
def output = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 7d1f11caae..2750ddbce8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide
trait HashJoin {
+ self: SparkPlan =>
+
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
@@ -56,9 +58,9 @@ trait HashJoin {
def output = left.output ++ right.output
- @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+ @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
- () => new MutableProjection(streamedKeys, streamedPlan.output)
+ newMutableProjection(streamedKeys, streamedPlan.output)
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation.
@@ -217,9 +219,8 @@ case class BroadcastHashJoin(
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
- right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {
+ right: SparkPlan) extends BinaryNode with HashJoin {
- override def otherCopyArgs = sqlContext :: Nil
override def outputPartitioning: Partitioning = left.outputPartitioning
@@ -228,7 +229,7 @@ case class BroadcastHashJoin(
@transient
lazy val broadcastFuture = future {
- sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
+ sparkContext.broadcast(buildPlan.executeCollect())
}
def execute() = {
@@ -248,14 +249,11 @@ case class BroadcastHashJoin(
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sqlContext :: Nil
-
def output = left.output
/** The Streamed Relation */
@@ -271,7 +269,7 @@ case class LeftSemiJoinBNL(
def execute() = {
val broadcastedRelation =
- sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
@@ -300,8 +298,14 @@ case class LeftSemiJoinBNL(
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
def output = left.output ++ right.output
- def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map {
- case (l: Row, r: Row) => buildRow(l ++ r)
+ def execute() = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.cartesian(rightResults).mapPartitions { iter =>
+ val joinedRow = new JoinedRow
+ iter.map(r => joinedRow(r._1, r._2))
+ }
}
}
@@ -311,14 +315,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
@DeveloperApi
case class BroadcastNestedLoopJoin(
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
- (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sqlContext :: Nil
-
override def output = {
joinType match {
case LeftOuter =>
@@ -345,13 +346,14 @@ case class BroadcastNestedLoopJoin(
def execute() = {
val broadcastedRelation =
- sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
+ val rightNulls = new GenericMutableRow(right.output.size)
streamedIter.foreach { streamedRow =>
var i = 0
@@ -361,7 +363,7 @@ case class BroadcastNestedLoopJoin(
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
- matchedRows += buildRow(streamedRow ++ broadcastedRow)
+ matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
matched = true
includedBroadcastTuples += i
}
@@ -369,7 +371,7 @@ case class BroadcastNestedLoopJoin(
}
if (!matched && (joinType == LeftOuter || joinType == FullOuter)) {
- matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null))
+ matchedRows += joinedRow(streamedRow, rightNulls).copy()
}
}
Iterator((matchedRows, includedBroadcastTuples))
@@ -383,20 +385,20 @@ case class BroadcastNestedLoopJoin(
streamedPlusMatches.map(_._2).reduce(_ ++ _)
}
+ val leftNulls = new GenericMutableRow(left.output.size)
val rightOuterMatches: Seq[Row] =
if (joinType == RightOuter || joinType == FullOuter) {
broadcastedRelation.value.zipWithIndex.filter {
case (row, i) => !allIncludedBroadcastTuples.contains(i)
}.map {
- // TODO: Use projection.
- case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row)
+ case (row, _) => new JoinedRow(leftNulls, row)
}
} else {
Vector()
}
// TODO: Breaks lineage.
- sqlContext.sparkContext.union(
- streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
+ sparkContext.union(
+ streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 8c7dbd5eb4..b3bae5db0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
*/
private[sql] case class ParquetRelation(
path: String,
- @transient conf: Option[Configuration] = None)
+ @transient conf: Option[Configuration],
+ @transient sqlContext: SQLContext)
extends LeafNode with MultiInstanceRelation {
self: Product =>
@@ -61,7 +62,7 @@ private[sql] case class ParquetRelation(
/** Attributes */
override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf)
- override def newInstance = ParquetRelation(path).asInstanceOf[this.type]
+ override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
// Equals must also take into account the output attributes so that we can distinguish between
// different instances of the same relation,
@@ -70,6 +71,9 @@ private[sql] case class ParquetRelation(
p.path == path && p.output == output
case _ => false
}
+
+ // TODO: Use data from the footers.
+ override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes)
}
private[sql] object ParquetRelation {
@@ -106,13 +110,14 @@ private[sql] object ParquetRelation {
*/
def create(pathString: String,
child: LogicalPlan,
- conf: Configuration): ParquetRelation = {
+ conf: Configuration,
+ sqlContext: SQLContext): ParquetRelation = {
if (!child.resolved) {
throw new UnresolvedException[LogicalPlan](
child,
"Attempt to create Parquet table from unresolved child (when schema is not available)")
}
- createEmpty(pathString, child.output, false, conf)
+ createEmpty(pathString, child.output, false, conf, sqlContext)
}
/**
@@ -127,14 +132,15 @@ private[sql] object ParquetRelation {
def createEmpty(pathString: String,
attributes: Seq[Attribute],
allowExisting: Boolean,
- conf: Configuration): ParquetRelation = {
+ conf: Configuration,
+ sqlContext: SQLContext): ParquetRelation = {
val path = checkPath(pathString, allowExisting, conf)
if (conf.get(ParquetOutputFormat.COMPRESSION) == null) {
conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name())
}
ParquetRelation.enableLogForwarding()
ParquetTypesConverter.writeMetaData(attributes, path, conf)
- new ParquetRelation(path.toString, Some(conf)) {
+ new ParquetRelation(path.toString, Some(conf), sqlContext) {
override val output = attributes
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index ea74320d06..912a9f002b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -55,8 +55,7 @@ case class ParquetTableScan(
// https://issues.apache.org/jira/browse/SPARK-1367
output: Seq[Attribute],
relation: ParquetRelation,
- columnPruningPred: Seq[Expression])(
- @transient val sqlContext: SQLContext)
+ columnPruningPred: Seq[Expression])
extends LeafNode {
override def execute(): RDD[Row] = {
@@ -99,8 +98,6 @@ case class ParquetTableScan(
.filter(_ != null) // Parquet's record filters may produce null values
}
- override def otherCopyArgs = sqlContext :: Nil
-
/**
* Applies a (candidate) projection.
*
@@ -110,7 +107,7 @@ case class ParquetTableScan(
def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
val success = validateProjection(prunedAttributes)
if (success) {
- ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
+ ParquetTableScan(prunedAttributes, relation, columnPruningPred)
} else {
sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
this
@@ -150,8 +147,7 @@ case class ParquetTableScan(
case class InsertIntoParquetTable(
relation: ParquetRelation,
child: SparkPlan,
- overwrite: Boolean = false)(
- @transient val sqlContext: SQLContext)
+ overwrite: Boolean = false)
extends UnaryNode with SparkHadoopMapReduceUtil {
/**
@@ -171,7 +167,7 @@ case class InsertIntoParquetTable(
val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive)) {
- logger.debug("Initializing MutableRowWriteSupport")
+ log.debug("Initializing MutableRowWriteSupport")
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
} else {
classOf[org.apache.spark.sql.parquet.RowWriteSupport]
@@ -203,8 +199,6 @@ case class InsertIntoParquetTable(
override def output = child.output
- override def otherCopyArgs = sqlContext :: Nil
-
/**
* Stores the given Row RDD as a Hadoop file.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index d4599da711..837ea7695d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
+import org.apache.spark.sql.test.TestSQLContext
import parquet.example.data.{GroupWriter, Group}
import parquet.example.data.simple.SimpleGroup
@@ -103,7 +104,7 @@ private[sql] object ParquetTestData {
val testDir = Utils.createTempDir()
val testFilterDir = Utils.createTempDir()
- lazy val testData = new ParquetRelation(testDir.toURI.toString)
+ lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext)
val testNestedSchema1 =
// based on blogpost example, source:
@@ -202,8 +203,10 @@ private[sql] object ParquetTestData {
val testNestedDir3 = Utils.createTempDir()
val testNestedDir4 = Utils.createTempDir()
- lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString)
- lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString)
+ lazy val testNestedData1 =
+ new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext)
+ lazy val testNestedData2 =
+ new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext)
def writeFile() = {
testDir.delete()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8e1e1971d9..1fd8d27b34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -45,6 +45,7 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution}
|== Exception ==
|$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 215618e852..76b1724471 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite {
test("count is partially aggregated") {
val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
- val planned = PartialAggregation(query).head
- val aggregations = planned.collect { case a: Aggregate => a }
+ val planned = HashAggregation(query).head
+ val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggregations.size === 2)
}
test("count distinct is not partially aggregated") {
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
- val planned = PartialAggregation(query)
+ val planned = HashAggregation(query)
assert(planned.isEmpty)
}
test("mixed aggregates are not partially aggregated") {
val query =
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
- val planned = PartialAggregation(query)
+ val planned = HashAggregation(query)
assert(planned.isEmpty)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
index e55648b8ed..2cab5e0c44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._
* Note: this is only a rough example of how TGFs can be expressed, the final version will likely
* involve a lot more sugar for cleaner use in Scala/Java/etc.
*/
-case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator {
+case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator {
def children = input
protected def makeOutput() = 'nameAndAge.string :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 3c911e9a4e..561f5b4a49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
+
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.util.Utils
@@ -207,10 +209,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("Projection of simple Parquet file") {
+ SparkPlan.currentContext.set(TestSQLContext)
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
- Seq())(TestSQLContext)
+ Seq())
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))