path: root/sql
author      Michael Armbrust <michael@databricks.com>    2014-07-29 20:58:05 -0700
committer   Michael Armbrust <michael@databricks.com>    2014-07-29 20:58:05 -0700
commit      84467468d466dadf4708a7d6a808471305149713 (patch)
tree        ab229d72541e2e0162ded0045694d5d1a09e2a08 /sql
parent      22649b6cde8e18f043f122bce46f446174d00f6c (diff)
download    spark-84467468d466dadf4708a7d6a808471305149713.tar.gz
            spark-84467468d466dadf4708a7d6a808471305149713.tar.bz2
            spark-84467468d466dadf4708a7d6a808471305149713.zip
[SPARK-2054][SQL] Code Generation for Expression Evaluation
Adds a new method for evaluating expressions using code that is generated through Scala reflection. This functionality is configured by the SQLConf option `spark.sql.codegen` and is currently turned off by default.

Evaluation can be done in several specialized ways:
 - *Projection* - Given an input row, produce a new row from a set of expressions that define each column in terms of the input row. This can either produce a new Row object or perform the projection in-place on an existing Row (MutableProjection).
 - *Ordering* - Compares two rows based on a list of `SortOrder` expressions.
 - *Condition* - Returns `true` or `false` given an input row.

For each of the above operations there is both a Generated and an Interpreted version. When generation for a given expression type is undefined, the code generator falls back on calling the `eval` function of the expression class. Even without custom code, there is still a potential speed up, as loops are unrolled and code can still be inlined by the JIT.

This PR also contains a new type of Aggregation operator, `GeneratedAggregate`, that performs aggregation by using generated `Projection` code. Currently the required expression rewriting only works for simple aggregations like `SUM` and `COUNT`. This functionality will be extended in a future PR.

This PR also performs several clean ups that simplified the implementation:
 - The notion of `Binding` all expressions in a tree automatically before query execution has been removed. Instead it is the responsibility of an operator to provide the input schema when creating one of the specialized evaluators defined above. In cases when the standard eval method is going to be called, binding can still be done manually using `BindReferences`. There are a few reasons for this change: First, there were many operators where it just didn't work before. For example, operators with more than one child, and operators like aggregation that do significant rewriting of the expression. Second, the semantics of equality with `BoundReferences` are broken. Specifically, we have had a few bugs where partitioning breaks because of the binding.
 - A copy of the current `SQLContext` is automatically propagated to all `SparkPlan` nodes by the query planner. Before, this was done ad-hoc for the nodes that needed it. However, this required a lot of boilerplate as one had to always remember to make it `transient` and also had to modify the `otherCopyArgs`.

Author: Michael Armbrust <michael@databricks.com>

Closes #993 from marmbrus/newCodeGen and squashes the following commits:

96ef82c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen
f34122d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into newCodeGen
67b1c48 [Michael Armbrust] Use conf variable in SQLConf object
4bdc42c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
41a40c9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
de22aac [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
fed3634 [Michael Armbrust] Inspectors are not serializable.
ef8d42b [Michael Armbrust] comments
533fdfd [Michael Armbrust] More logging of expression rewriting for GeneratedAggregate.
3cd773e [Michael Armbrust] Allow codegen for Generate.
64b2ee1 [Michael Armbrust] Implement copy
3587460 [Michael Armbrust] Drop unused string builder function.
9cce346 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
1a61293 [Michael Armbrust] Address review comments.
0672e8a [Michael Armbrust] Address comments.
1ec2d6e [Michael Armbrust] Address comments
033abc6 [Michael Armbrust] off by default
4771fab [Michael Armbrust] Docs, more test coverage.
d30fee2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
d2ad5c5 [Michael Armbrust] Refactor putting SQLContext into SparkPlan. Fix ordering, other test cases.
be2cd6b [Michael Armbrust] WIP: Remove old method for reference binding, more work on configuration.
bc88ecd [Michael Armbrust] Style
6cc97ca [Michael Armbrust] Merge remote-tracking branch 'origin/master' into newCodeGen
4220f1e [Michael Armbrust] Better config, docs, etc.
ca6cc6b [Michael Armbrust] WIP
9d67d85 [Michael Armbrust] Fix hive planner
fc522d5 [Michael Armbrust] Hook generated aggregation in to the planner.
e742640 [Michael Armbrust] Remove unneeded changes and code.
675e679 [Michael Armbrust] Upgrade paradise.
0093376 [Michael Armbrust] Comment / indenting cleanup.
d81f998 [Michael Armbrust] include schema for binding.
0e889e8 [Michael Armbrust] Use typeOf instead tq
f623ffd [Michael Armbrust] Quiet logging from test suite.
efad14f [Michael Armbrust] Remove some half finished functions.
92e74a4 [Michael Armbrust] add overrides
a2b5408 [Michael Armbrust] WIP: Code generation with scala reflection.
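For reference, a minimal sketch of how the three specialized evaluators added here can be driven directly from Catalyst expressions. In normal operation the planner decides between the generated and interpreted versions based on `spark.sql.codegen`; the attribute name, literal values, and input row below are illustrative only:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.types._

    // Input schema: one nullable Int column named "a" (illustrative).
    val a = AttributeReference("a", IntegerType, nullable = true)()
    val schema = Seq(a)

    // Projection: compute a + 1 for every input row.
    val project = GenerateProjection(Seq(Add(a, Literal(1, IntegerType))), schema)

    // Condition: a > 0, compiled to a Row => Boolean.
    val predicate = GeneratePredicate(GreaterThan(a, Literal(0, IntegerType)), schema)

    // Ordering: compare rows by "a" ascending.
    val ordering = GenerateOrdering(Seq(SortOrder(a, Ascending)), schema)

    val row = new GenericRow(Array[Any](41))
    project(row)    // a Row containing 42
    predicate(row)  // true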
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/pom.xml | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 39
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 468
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 76
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala | 98
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala | 48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 219
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala | 80
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala | 28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 71
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala | 18
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 55
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala | 69
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 61
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 200
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 138
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 6
-rw-r--r--  sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d | 1
-rw-r--r--  sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 | 1
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 11
51 files changed, 1871 insertions, 294 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 531bfddbf2..54fa96baa1 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -38,9 +38,18 @@
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
+ <artifactId>scala-compiler</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>
<dependency>
+ <groupId>org.scalamacros</groupId>
+ <artifactId>quasiquotes_${scala.binary.version}</artifactId>
+ <version>${scala.macros.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5c8c810d91..f44521d638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -202,7 +202,7 @@ package object dsl {
// Protobuf terminology
def required = a.withNullability(false)
- def at(ordinal: Int) = BoundReference(ordinal, a)
+ def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 9ce1f01056..a3ebec8082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.trees
+
import org.apache.spark.sql.Logging
/**
@@ -28,61 +30,27 @@ import org.apache.spark.sql.Logging
* to be retrieved more efficiently. However, since operations like column pruning can change
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
*/
-case class BoundReference(ordinal: Int, baseReference: Attribute)
- extends Attribute with trees.LeafNode[Expression] {
+case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
+ extends Expression with trees.LeafNode[Expression] {
type EvaluatedType = Any
- override def nullable = baseReference.nullable
- override def dataType = baseReference.dataType
- override def exprId = baseReference.exprId
- override def qualifiers = baseReference.qualifiers
- override def name = baseReference.name
+ override def references = Set.empty
- override def newInstance = BoundReference(ordinal, baseReference.newInstance)
- override def withNullability(newNullability: Boolean) =
- BoundReference(ordinal, baseReference.withNullability(newNullability))
- override def withQualifiers(newQualifiers: Seq[String]) =
- BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
-
- override def toString = s"$baseReference:$ordinal"
+ override def toString = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
}
-/**
- * Used to denote operators that do their own binding of attributes internally.
- */
-trait NoBind { self: trees.TreeNode[_] => }
-
-class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
- import BindReferences._
-
- def apply(plan: TreeNode): TreeNode = {
- plan.transform {
- case n: NoBind => n.asInstanceOf[TreeNode]
- case leafNode if leafNode.children.isEmpty => leafNode
- case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
- bindReference(e, unaryNode.children.head.output)
- }
- }
- }
-}
-
object BindReferences extends Logging {
def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId)
if (ordinal == -1) {
- // TODO: This fallback is required because some operators (such as ScriptTransform)
- // produce new attributes that can't be bound. Likely the right thing to do is remove
- // this rule and require all operators to explicitly bind to the input schema that
- // they specify.
- logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
- a
+ sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
} else {
- BoundReference(ordinal, a)
+ BoundReference(ordinal, a.dataType, a.nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
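With the automatic binding rule gone, an operator now binds its own expressions against its child's output before calling `eval`. A minimal sketch of that contract, using a hypothetical schema and condition:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical operator-side code: the operator knows its child's output schema
    // and binds its condition against that schema before evaluating it.
    val childOutput: Seq[Attribute] =
      Seq(AttributeReference("key", IntegerType)(), AttributeReference("value", StringType)())
    val condition: Expression = GreaterThan(childOutput.head, Literal(10, IntegerType))

    // bindReference rewrites AttributeReferences into ordinal-based BoundReferences.
    val bound = BindReferences.bindReference(condition, childOutput)
    bound.eval(new GenericRow(Array[Any](42, "x")))  // true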
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 2c71d2c7b3..8fc5896974 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.catalyst.expressions
+
/**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of the
- * new row. If the schema of the input row is specified, then the given expression will be bound to
- * that schema.
+ * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ * @param expressions a sequence of expressions that determine the value of each column of the
+ * output row.
*/
-class Projection(expressions: Seq[Expression]) extends (Row => Row) {
+class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
@@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
}
/**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of th
- * new row. If the schema of the input row is specified, then the given expression will be bound to
- * that schema.
- *
- * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
- * each time an input row is added. This significantly reduces the cost of calculating the
- * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
- * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
- * and hold on to the returned [[Row]] before calling `next()`.
+ * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
+ * expressions.
+ * @param expressions a sequence of expressions that determine the value of each column of the
+ * output row.
*/
-case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
+case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
private[this] val exprArray = expressions.toArray
- private[this] val mutableRow = new GenericMutableRow(exprArray.size)
+ private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size)
def currentValue: Row = mutableRow
- def apply(input: Row): Row = {
+ override def target(row: MutableRow): MutableProjection = {
+ mutableRow = row
+ this
+ }
+
+ override def apply(input: Row): Row = {
var i = 0
while (i < exprArray.length) {
mutableRow(i) = exprArray(i).eval(input)
@@ -76,6 +77,12 @@ class JoinedRow extends Row {
private[this] var row1: Row = _
private[this] var row2: Row = _
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
def apply(r1: Row, r2: Row): Row = {
row1 = r1
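A small sketch of the new `MutableProjection` contract used above: `target` redirects output into a caller-supplied row and `currentValue` exposes the last projected result. The single-column layout is assumed for illustration:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    // Assumed layout: one Int column "a"; project it into a + 1.
    val input = Seq(AttributeReference("a", IntegerType)())
    val project = new InterpretedMutableProjection(Seq(Add(input.head, Literal(1, IntegerType))), input)

    // target() makes the projection write into a caller-supplied MutableRow,
    // e.g. an aggregation buffer; currentValue returns the last projected row.
    val buffer = new GenericMutableRow(1)
    project.target(buffer)
    project(new GenericRow(Array[Any](41)))
    buffer.getInt(0)      // 42
    project.currentValue  // the same buffer row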
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 74ae723686..7470cb861b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -88,15 +88,6 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
-
- /**
- * Experimental
- *
- * Returns a mutable string builder for the specified column. A given row should return the
- * result of any mutations made to the returned buffer next time getString is called for the same
- * column.
- */
- def getStringBuilder(ordinal: Int): StringBuilder
}
/**
@@ -180,6 +171,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
values(i).asInstanceOf[String]
}
+ // Custom hashCode function that matches the efficient code generated version.
+ override def hashCode(): Int = {
+ var result: Int = 37
+
+ var i = 0
+ while (i < values.length) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ apply(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+
def copy() = this
}
@@ -187,8 +207,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)
- def getStringBuilder(ordinal: Int): StringBuilder = ???
-
override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 5e089f7618..acddf5e9c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
override def eval(input: Row): Any = {
children.size match {
+ case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
function.asInstanceOf[(Any, Any) => Any](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
new file mode 100644
index 0000000000..5b398695bf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -0,0 +1,468 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import com.google.common.cache.{CacheLoader, CacheBuilder}
+
+import scala.language.existentials
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * A base class for generators of byte code to perform expression evaluation. Includes a set of
+ * helpers for referring to Catalyst types and building trees that perform evaluation of individual
+ * expressions.
+ */
+abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ import scala.tools.reflect.ToolBox
+
+ protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
+
+ protected val rowType = typeOf[Row]
+ protected val mutableRowType = typeOf[MutableRow]
+ protected val genericRowType = typeOf[GenericRow]
+ protected val genericMutableRowType = typeOf[GenericMutableRow]
+
+ protected val projectionType = typeOf[Projection]
+ protected val mutableProjectionType = typeOf[MutableProjection]
+
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ private val javaSeparator = "$"
+
+ /**
+ * Generates a class for a given input expression. Called when there is not cached code
+ * already available.
+ */
+ protected def create(in: InType): OutType
+
+ /**
+ * Canonicalizes an input expression. Used to avoid double caching expressions that differ only
+ * cosmetically.
+ */
+ protected def canonicalize(in: InType): InType
+
+ /** Binds an input expression to a given input schema */
+ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
+
+ /**
+ * A cache of generated classes.
+ *
+ * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
+ * fundamental difference is that a ConcurrentMap persists all elements that are added to it until
+ * they are explicitly removed. A Cache on the other hand is generally configured to evict entries
+ * automatically, in order to constrain its memory footprint
+ */
+ protected val cache = CacheBuilder.newBuilder()
+ .maximumSize(1000)
+ .build(
+ new CacheLoader[InType, OutType]() {
+ override def load(in: InType): OutType = globalLock.synchronized {
+ create(in)
+ }
+ })
+
+ /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
+ def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType =
+ apply(bind(expressions, inputSchema))
+
+ /** Generates the requested evaluator given already bound expression(s). */
+ def apply(expressions: InType): OutType = cache.get(canonicalize(expressions))
+
+ /**
+ * Returns a term name that is unique within this instance of a `CodeGenerator`.
+ *
+ * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
+ * function.)
+ */
+ protected def freshName(prefix: String): TermName = {
+ newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
+ }
+
+ /**
+ * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
+ *
+ * @param code The sequence of statements required to evaluate the expression.
+ * @param nullTerm A term that holds a boolean value representing whether the expression evaluated
+ * to null.
+ * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
+ * valid if `nullTerm` is set to `false`.
+ * @param objectTerm A possibly boxed version of the result of evaluating this expression.
+ */
+ protected case class EvaluatedExpression(
+ code: Seq[Tree],
+ nullTerm: TermName,
+ primitiveTerm: TermName,
+ objectTerm: TermName)
+
+ /**
+ * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
+ * can be used to determine the result of evaluating the expression on an input row.
+ */
+ def expressionEvaluator(e: Expression): EvaluatedExpression = {
+ val primitiveTerm = freshName("primitiveTerm")
+ val nullTerm = freshName("nullTerm")
+ val objectTerm = freshName("objectTerm")
+
+ implicit class Evaluate1(e: Expression) {
+ def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(dataType)}
+ else
+ ${f(eval.primitiveTerm)}
+ """.children
+ }
+ }
+
+ implicit class Evaluate2(expressions: (Expression, Expression)) {
+
+ /**
+ * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
+ * the same type. If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ *
+ * @param f a function from two primitive term names to a tree that evaluates them.
+ */
+ def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+ evaluateAs(expressions._1.dataType)(f)
+
+ def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = {
+ // TODO: Right now some timestamp tests fail if we enforce this...
+ if (expressions._1.dataType != expressions._2.dataType) {
+ log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
+ }
+
+ val eval1 = expressionEvaluator(expressions._1)
+ val eval2 = expressionEvaluator(expressions._2)
+ val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
+ val $primitiveTerm: ${termForType(resultType)} =
+ if($nullTerm) {
+ ${defaultPrimitive(resultType)}
+ } else {
+ $resultCode.asInstanceOf[${termForType(resultType)}]
+ }
+ """.children : Seq[Tree]
+ }
+ }
+
+ val inputTuple = newTermName(s"i")
+
+ // TODO: Skip generation of null handling code when expression are not nullable.
+ val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+ case b @ BoundReference(ordinal, dataType, nullable) =>
+ val nullValue = q"$inputTuple.isNullAt($ordinal)"
+ q"""
+ val $nullTerm: Boolean = $nullValue
+ val $primitiveTerm: ${termForType(dataType)} =
+ if($nullTerm)
+ ${defaultPrimitive(dataType)}
+ else
+ ${getColumn(inputTuple, dataType, ordinal)}
+ """.children
+
+ case expressions.Literal(null, dataType) =>
+ q"""
+ val $nullTerm = true
+ val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}]
+ """.children
+
+ case expressions.Literal(value: Boolean, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: String, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: Int, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: Long, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case Cast(e @ BinaryType(), StringType) =>
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(StringType)}
+ else
+ new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ """.children
+
+ case Cast(child @ NumericType(), IntegerType) =>
+ child.castOrNull(c => q"$c.toInt", IntegerType)
+
+ case Cast(child @ NumericType(), LongType) =>
+ child.castOrNull(c => q"$c.toLong", LongType)
+
+ case Cast(child @ NumericType(), DoubleType) =>
+ child.castOrNull(c => q"$c.toDouble", DoubleType)
+
+ case Cast(child @ NumericType(), FloatType) =>
+ child.castOrNull(c => q"$c.toFloat", IntegerType)
+
+ // Special handling required for timestamps in hive test cases since the toString function
+ // does not match the expected output.
+ case Cast(e, StringType) if e.dataType != TimestampType =>
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(StringType)}
+ else
+ ${eval.primitiveTerm}.toString
+ """.children
+
+ case EqualTo(e1, e2) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
+
+ /* TODO: Fix null semantics.
+ case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) =>
+ val eval = expressionEvaluator(e1)
+
+ val checks = list.map {
+ case expressions.Literal(v: String, dataType) =>
+ q"if(${eval.primitiveTerm} == $v) return true"
+ case expressions.Literal(v: Int, dataType) =>
+ q"if(${eval.primitiveTerm} == $v) return true"
+ }
+
+ val funcName = newTermName(s"isIn${curId.getAndIncrement()}")
+
+ q"""
+ def $funcName: Boolean = {
+ ..${eval.code}
+ if(${eval.nullTerm}) return false
+ ..$checks
+ return false
+ }
+ val $nullTerm = false
+ val $primitiveTerm = $funcName
+ """.children
+ */
+
+ case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" }
+ case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" }
+ case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" }
+ case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" }
+
+ case And(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = false
+
+ if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) ||
+ (!${eval2.nullTerm} && !${eval2.primitiveTerm})) {
+ $nullTerm = false
+ $primitiveTerm = false
+ } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else {
+ $nullTerm = false
+ $primitiveTerm = true
+ }
+ """.children
+
+ case Or(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = false
+
+ if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) ||
+ (!${eval2.nullTerm} && ${eval2.primitiveTerm})) {
+ $nullTerm = false
+ $primitiveTerm = true
+ } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else {
+ $nullTerm = false
+ $primitiveTerm = false
+ }
+ """.children
+
+ case Not(child) =>
+ // Uh, bad function name...
+ child.castOrNull(c => q"!$c", BooleanType)
+
+ case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
+ case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
+ case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
+ case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" }
+
+ case IsNotNull(e) =>
+ val eval = expressionEvaluator(e)
+ q"""
+ ..${eval.code}
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm}
+ """.children
+
+ case IsNull(e) =>
+ val eval = expressionEvaluator(e)
+ q"""
+ ..${eval.code}
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm}
+ """.children
+
+ case c @ Coalesce(children) =>
+ q"""
+ var $nullTerm = true
+ var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)}
+ """.children ++
+ children.map { c =>
+ val eval = expressionEvaluator(c)
+ q"""
+ if($nullTerm) {
+ ..${eval.code}
+ if(!${eval.nullTerm}) {
+ $nullTerm = false
+ $primitiveTerm = ${eval.primitiveTerm}
+ }
+ }
+ """
+ }
+
+ case i @ expressions.If(condition, trueValue, falseValue) =>
+ val condEval = expressionEvaluator(condition)
+ val trueEval = expressionEvaluator(trueValue)
+ val falseEval = expressionEvaluator(falseValue)
+
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)}
+ ..${condEval.code}
+ if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
+ ..${trueEval.code}
+ $nullTerm = ${trueEval.nullTerm}
+ $primitiveTerm = ${trueEval.primitiveTerm}
+ } else {
+ ..${falseEval.code}
+ $nullTerm = ${falseEval.nullTerm}
+ $primitiveTerm = ${falseEval.primitiveTerm}
+ }
+ """.children
+ }
+
+ // If there was no match in the partial function above, we fall back on calling the interpreted
+ // expression evaluator.
+ val code: Seq[Tree] =
+ primitiveEvaluation.lift.apply(e).getOrElse {
+ log.debug(s"No rules to generate $e")
+ val tree = reify { e }
+ q"""
+ val $objectTerm = $tree.eval(i)
+ val $nullTerm = $objectTerm == null
+ val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
+ """.children
+ }
+
+ EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
+ }
+
+ protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
+ dataType match {
+ case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
+ case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
+ }
+ }
+
+ protected def setColumn(
+ destinationRow: TermName,
+ dataType: DataType,
+ ordinal: Int,
+ value: TermName) = {
+ dataType match {
+ case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
+ case _ => q"$destinationRow.update($ordinal, $value)"
+ }
+ }
+
+ protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
+ protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+
+ protected def primitiveForType(dt: DataType) = dt match {
+ case IntegerType => "Int"
+ case LongType => "Long"
+ case ShortType => "Short"
+ case ByteType => "Byte"
+ case DoubleType => "Double"
+ case FloatType => "Float"
+ case BooleanType => "Boolean"
+ case StringType => "String"
+ }
+
+ protected def defaultPrimitive(dt: DataType) = dt match {
+ case BooleanType => ru.Literal(Constant(false))
+ case FloatType => ru.Literal(Constant(-1.0.toFloat))
+ case StringType => ru.Literal(Constant("<uninit>"))
+ case ShortType => ru.Literal(Constant(-1.toShort))
+ case LongType => ru.Literal(Constant(1L))
+ case ByteType => ru.Literal(Constant(-1.toByte))
+ case DoubleType => ru.Literal(Constant(-1.toDouble))
+ case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+ case IntegerType => ru.Literal(Constant(-1))
+ case _ => ru.Literal(Constant(null))
+ }
+
+ protected def termForType(dt: DataType) = dt match {
+ case n: NativeType => n.tag
+ case _ => typeTag[Any]
+ }
+}
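When none of the cases in `primitiveEvaluation` match, the generated code simply embeds the expression and calls its interpreted `eval`, as described in the fallback branch above. A sketch of that path using `UnaryMinus`, which has no dedicated rule here (values illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.types._

    // UnaryMinus has no case in primitiveEvaluation, so the generated projection
    // embeds the expression object and calls its eval method for this column.
    val expr = UnaryMinus(BoundReference(0, IntegerType, nullable = true))
    val project = GenerateProjection(Seq(expr))
    project(new GenericRow(Array[Any](5)))  // a Row containing -5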
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
new file mode 100644
index 0000000000..a419fd7ecb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
+ * input [[Row]] for a fixed set of [[Expression Expressions]].
+ */
+object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ val mutableRowName = newTermName("mutableRow")
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer(_))
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
+ val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) =>
+ val evaluationCode = expressionEvaluator(e)
+
+ evaluationCode.code :+
+ q"""
+ if(${evaluationCode.nullTerm})
+ mutableRow.setNullAt($i)
+ else
+ ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)}
+ """
+ }
+
+ val code =
+ q"""
+ () => { new $mutableProjectionType {
+
+ private[this] var $mutableRowName: $mutableRowType =
+ new $genericMutableRowType(${expressions.size})
+
+ def target(row: $mutableRowType): $mutableProjectionType = {
+ $mutableRowName = row
+ this
+ }
+
+ /* Provide immutable access to the last projected row. */
+ def currentValue: $rowType = mutableRow
+
+ def apply(i: $rowType): $rowType = {
+ ..$projectionCode
+ mutableRow
+ }
+ } }
+ """
+
+ log.debug(s"code for ${expressions.mkString(",")}:\n$code")
+ toolBox.eval(code).asInstanceOf[() => MutableProjection]
+ }
+}
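Note that `GenerateMutableProjection` produces a factory (`() => MutableProjection`) rather than a projection; presumably this lets each consumer, for example each task, instantiate its own copy so the embedded mutable row is never shared. A usage sketch under that assumption:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.types._

    val a = AttributeReference("a", IntegerType)()
    val newProjection: () => MutableProjection =
      GenerateMutableProjection(Seq(Add(a, Literal(1, IntegerType))), Seq(a))

    // Each call to the factory builds a fresh projection with its own mutable row.
    val project = newProjection()
    project(new GenericRow(Array[Any](1))).getInt(0)  // 2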
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
new file mode 100644
index 0000000000..4211998f75
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import com.typesafe.scalalogging.slf4j.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.{StringType, NumericType}
+
+/**
+ * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
+ * [[Expression Expressions]].
+ */
+object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
+ in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder])
+
+ protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
+ val a = newTermName("a")
+ val b = newTermName("b")
+ val comparisons = ordering.zipWithIndex.map { case (order, i) =>
+ val evalA = expressionEvaluator(order.child)
+ val evalB = expressionEvaluator(order.child)
+
+ val compare = order.child.dataType match {
+ case _: NumericType =>
+ q"""
+ val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
+ if(comp != 0) {
+ return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
+ }
+ """
+ case StringType =>
+ if (order.direction == Ascending) {
+ q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
+ } else {
+ q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
+ }
+ }
+
+ q"""
+ i = $a
+ ..${evalA.code}
+ i = $b
+ ..${evalB.code}
+ if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+ // Nothing
+ } else if (${evalA.nullTerm}) {
+ return ${if (order.direction == Ascending) q"-1" else q"1"}
+ } else if (${evalB.nullTerm}) {
+ return ${if (order.direction == Ascending) q"1" else q"-1"}
+ } else {
+ $compare
+ }
+ """
+ }
+
+ val q"class $orderingName extends $orderingType { ..$body }" = reify {
+ class SpecificOrdering extends Ordering[Row] {
+ val o = ordering
+ }
+ }.tree.children.head
+
+ val code = q"""
+ class $orderingName extends $orderingType {
+ ..$body
+ def compare(a: $rowType, b: $rowType): Int = {
+ var i: $rowType = null // Holds current row being evaluated.
+ ..$comparisons
+ return 0
+ }
+ }
+ new $orderingName()
+ """
+ logger.debug(s"Generated Ordering: $code")
+ toolBox.eval(code).asInstanceOf[Ordering[Row]]
+ }
+}
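A short usage sketch of the generated ordering, sorting a handful of rows by a single assumed key column:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.types._

    val key = AttributeReference("key", IntegerType)()
    val ordering = GenerateOrdering(Seq(SortOrder(key, Descending)), Seq(key))

    val rows: Seq[Row] =
      Seq(new GenericRow(Array[Any](1)), new GenericRow(Array[Any](3)), new GenericRow(Array[Any](2)))
    rows.sorted(ordering).map(_.getInt(0))  // Seq(3, 2, 1)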
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
new file mode 100644
index 0000000000..2a0935c790
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]].
+ */
+object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in)
+
+ protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
+ BindReferences.bindReference(in, inputSchema)
+
+ protected def create(predicate: Expression): ((Row) => Boolean) = {
+ val cEval = expressionEvaluator(predicate)
+
+ val code =
+ q"""
+ (i: $rowType) => {
+ ..${cEval.code}
+ if (${cEval.nullTerm}) false else ${cEval.primitiveTerm}
+ }
+ """
+
+ log.debug(s"Generated predicate '$predicate':\n$code")
+ toolBox.eval(code).asInstanceOf[Row => Boolean]
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
new file mode 100644
index 0000000000..77fa02c13d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+
+/**
+ * Generates bytecode that produces a new [[Row]] object based on a fixed set of input
+ * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom
+ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values.
+ */
+object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer(_))
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ // Make Mutablility optional...
+ protected def create(expressions: Seq[Expression]): Projection = {
+ val tupleLength = ru.Literal(Constant(expressions.length))
+ val lengthDef = q"final val length = $tupleLength"
+
+ /* TODO: Configurable...
+ val nullFunctions =
+ q"""
+ private final val nullSet = new org.apache.spark.util.collection.BitSet(length)
+ final def setNullAt(i: Int) = nullSet.set(i)
+ final def isNullAt(i: Int) = nullSet.get(i)
+ """
+ */
+
+ val nullFunctions =
+ q"""
+ private[this] var nullBits = new Array[Boolean](${expressions.size})
+ final def setNullAt(i: Int) = { nullBits(i) = true }
+ final def isNullAt(i: Int) = nullBits(i)
+ """.children
+
+ val tupleElements = expressions.zipWithIndex.flatMap {
+ case (e, i) =>
+ val elementName = newTermName(s"c$i")
+ val evaluatedExpression = expressionEvaluator(e)
+ val iLit = ru.Literal(Constant(i))
+
+ q"""
+ var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _
+ {
+ ..${evaluatedExpression.code}
+ if(${evaluatedExpression.nullTerm})
+ setNullAt($iLit)
+ else
+ $elementName = ${evaluatedExpression.primitiveTerm}
+ }
+ """.children : Seq[Tree]
+ }
+
+ val iteratorFunction = {
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+ q"final def iterator = Iterator[Any](..$allColumns)"
+ }
+
+ val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
+ val applyFunction = {
+ val cases = (0 until expressions.size).map { i =>
+ val ordinal = ru.Literal(Constant(i))
+ val elementName = newTermName(s"c$i")
+ val iLit = ru.Literal(Constant(i))
+
+ q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }"
+ }
+ q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }"
+ }
+
+ val updateFunction = {
+ val cases = expressions.zipWithIndex.map {case (e, i) =>
+ val ordinal = ru.Literal(Constant(i))
+ val elementName = newTermName(s"c$i")
+ val iLit = ru.Literal(Constant(i))
+
+ q"""
+ if(i == $ordinal) {
+ if(value == null) {
+ setNullAt(i)
+ } else {
+ $elementName = value.asInstanceOf[${termForType(e.dataType)}]
+ return
+ }
+ }"""
+ }
+ q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
+ }
+
+ val specificAccessorFunctions = NativeType.all.map { dataType =>
+ val ifStatements = expressions.zipWithIndex.flatMap {
+ case (e, i) if e.dataType == dataType =>
+ val elementName = newTermName(s"c$i")
+ // TODO: The string of ifs gets pretty inefficient as the row grows in size.
+ // TODO: Optional null checks?
+ q"if(i == $i) return $elementName" :: Nil
+ case _ => Nil
+ }
+
+ q"""
+ final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
+
+ val specificMutatorFunctions = NativeType.all.map { dataType =>
+ val ifStatements = expressions.zipWithIndex.flatMap {
+ case (e, i) if e.dataType == dataType =>
+ val elementName = newTermName(s"c$i")
+ // TODO: The string of ifs gets pretty inefficient as the row grows in size.
+ // TODO: Optional null checks?
+ q"if(i == $i) { $elementName = value; return }" :: Nil
+ case _ => Nil
+ }
+
+ q"""
+ final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
+
+ val hashValues = expressions.zipWithIndex.map { case (e,i) =>
+ val elementName = newTermName(s"c$i")
+ val nonNull = e.dataType match {
+ case BooleanType => q"if ($elementName) 0 else 1"
+ case ByteType | ShortType | IntegerType => q"$elementName.toInt"
+ case LongType => q"($elementName ^ ($elementName >>> 32)).toInt"
+ case FloatType => q"java.lang.Float.floatToIntBits($elementName)"
+ case DoubleType =>
+ q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }"
+ case _ => q"$elementName.hashCode"
+ }
+ q"if (isNullAt($i)) 0 else $nonNull"
+ }
+
+ val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree)
+
+ val hashCodeFunction =
+ q"""
+ override def hashCode(): Int = {
+ var result: Int = 37
+ ..$hashUpdates
+ result
+ }
+ """
+
+ val columnChecks = (0 until expressions.size).map { i =>
+ val elementName = newTermName(s"c$i")
+ q"if (this.$elementName != specificType.$elementName) return false"
+ }
+
+ val equalsFunction =
+ q"""
+ override def equals(other: Any): Boolean = other match {
+ case specificType: SpecificRow =>
+ ..$columnChecks
+ return true
+ case other => super.equals(other)
+ }
+ """
+
+ val copyFunction =
+ q"""
+ final def copy() = new $genericRowType(this.toArray)
+ """
+
+ val classBody =
+ nullFunctions ++ (
+ lengthDef +:
+ iteratorFunction +:
+ applyFunction +:
+ updateFunction +:
+ equalsFunction +:
+ hashCodeFunction +:
+ copyFunction +:
+ (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
+
+ val code = q"""
+ final class SpecificRow(i: $rowType) extends $mutableRowType {
+ ..$classBody
+ }
+
+ new $projectionType { def apply(r: $rowType) = new SpecificRow(r) }
+ """
+
+ log.debug(
+ s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}")
+ toolBox.eval(code).asInstanceOf[Projection]
+ }
+}
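The hand-written `hashCode` added to `GenericRow` earlier in this patch mirrors the `hashCode` generated for `SpecificRow` above, so a generated row and an interpreted row holding equal column values should hash identically, which matters when rows serve as grouping or partitioning keys. A sketch of that expectation (values illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.types._

    val project = GenerateProjection(Seq(BoundReference(0, IntegerType, nullable = true)))

    // A generated SpecificRow and a plain GenericRow with the same value
    // should produce the same hash code.
    val generated = project(new GenericRow(Array[Any](42)))
    val generic = new GenericRow(Array[Any](42))
    generated.hashCode == generic.hashCode  // expected: true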
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
new file mode 100644
index 0000000000..80c7dfd376
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.rules
+import org.apache.spark.sql.catalyst.util
+
+/**
+ * A collection of generators that build custom bytecode at runtime for performing the evaluation
+ * of catalyst expression.
+ */
+package object codegen {
+
+ /**
+ * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala
+ * 2.10.
+ */
+ protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+ /** Canonicalizes an expression so those that differ only by names can reuse the same code. */
+ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
+ val batches =
+ Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil
+
+ object CleanExpressions extends rules.Rule[Expression] {
+ def apply(e: Expression): Expression = e transform {
+ case Alias(c, _) => c
+ }
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Dumps the bytecode from a class to the screen using javap.
+ */
+ @DeveloperApi
+ object DumpByteCode {
+ import scala.sys.process._
+ val dumpDirectory = util.getTempFilePath("sparkSqlByteCode")
+ dumpDirectory.mkdir()
+
+ def apply(obj: Any): Unit = {
+ val generatedClass = obj.getClass
+ val classLoader =
+ generatedClass
+ .getClassLoader
+ .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader]
+ val generatedBytes = classLoader.classBytes(generatedClass.getName)
+
+ val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName)
+ if (!packageDir.exists()) { packageDir.mkdir() }
+
+ val classFile =
+ new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class")
+
+ val outfile = new java.io.FileOutputStream(classFile)
+ outfile.write(generatedBytes)
+ outfile.close()
+
+ println(
+ s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!)
+ }
+ }
+}
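A small sketch of what `ExpressionCanonicalizer` buys: stripping `Alias` lets expressions that differ only by output name share one cached generated class. (`ExpressionCanonicalizer` lives in the codegen package object, so this illustrates the rule's effect rather than a public API.)

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._
    import org.apache.spark.sql.catalyst.types._

    val a = BoundReference(0, IntegerType, nullable = true)
    val x = Alias(Add(a, Literal(1, IntegerType)), "x")()
    val y = Alias(Add(a, Literal(1, IntegerType)), "y")()

    // After stripping the aliases both expressions are identical, so they map to
    // the same cache entry and the class is compiled only once.
    ExpressionCanonicalizer(x) == ExpressionCanonicalizer(y)  // true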
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index b6f2451b52..55d95991c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst
* ==Evaluation==
* The result of expressions can be evaluated using the `Expression.apply(Row)` method.
*/
-package object expressions
+package object expressions {
+
+ /**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
+ * new row. If the schema of the input row is specified, then the given expressions will be bound
+ * to that schema.
+ */
+ abstract class Projection extends (Row => Row)
+
+ /**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
+ * new row. If the schema of the input row is specified, then the given expressions will be bound
+ * to that schema.
+ *
+ * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
+ * each time an input row is added. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
+ * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
+ * and hold on to the returned [[Row]] before calling `next()`.
+ */
+ abstract class MutableProjection extends Projection {
+ def currentValue: Row
+
+ /** Uses the given row to store the output of the projection. */
+ def target(row: MutableRow): MutableProjection
+ }
+}
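Editorial note: a small sketch (not part of the patch) of the copy() contract described above, using the interpreted implementation and a single Int column; the names below are illustrative only.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val x = AttributeReference("x", IntegerType)()
    // A MutableProjection reuses one underlying output row across invocations.
    val project = new InterpretedMutableProjection(Seq(Add(x, Literal(1))), Seq(x))

    val rows = Iterator(new GenericRow(Array[Any](1)), new GenericRow(Array[Any](2)))
    // Unsafe: every element would alias the same mutable row once the iterator advances.
    //   rows.map(project).toSeq
    // Safe: copy each result before retaining it, as the scaladoc above requires.
    val results = rows.map(r => project(r).copy()).toSeq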
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 06b94a98d3..5976b0ddf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType
object InterpretedPredicate {
+ def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
+ apply(BindReferences.bindReference(expression, inputSchema))
+
def apply(expression: Expression): (Row => Boolean) = {
(r: Row) => expression.eval(r).asInstanceOf[Boolean]
}
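Editorial note: a sketch (not part of the patch) of the new overload above. Binding happens inside the factory, so a caller can hand over an expression phrased in terms of attributes plus the input schema; the attribute and values are illustrative.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val a = AttributeReference("a", IntegerType)()
    val schema = Seq(a)

    // `a` is bound to ordinal 0 before the Row => Boolean function is built.
    val pred: Row => Boolean = InterpretedPredicate(GreaterThan(a, Literal(10)), schema)

    pred(new GenericRow(Array[Any](42)))   // expected: true
    pred(new GenericRow(Array[Any](3)))    // expected: false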
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
new file mode 100644
index 0000000000..3b3e206055
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+package object catalyst {
+ /**
+ * A JVM-global lock that should be used to prevent thread safety issues when using things in
+ * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for
+ * 2.10.* builds. See SI-6240 for more details.
+ */
+ protected[catalyst] object ScalaReflectionLock
+}
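Editorial note: the intended usage pattern, sketched for illustration (the lock is only visible inside org.apache.spark.sql.catalyst): guard every runtime-reflection call with the shared lock so concurrent queries do not trip over the non-thread-safe 2.10 reflection universe.

    import scala.reflect.runtime.universe._

    // Assumed to live somewhere under org.apache.spark.sql.catalyst.
    def reflectSafely[A](body: => A): A = ScalaReflectionLock.synchronized { body }

    val stringType = reflectSafely { typeOf[String] }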
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 026692abe0..418f8686bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -105,6 +105,77 @@ object PhysicalOperation extends PredicateHelper {
}
/**
+ * Matches a logical aggregation that can be performed on distributed data in two steps. The first
+ * operates on the data in each partition, performing partial aggregation for each group. The second
+ * occurs after the shuffle and completes the aggregation.
+ *
+ * This pattern will only match if all aggregate expressions can be computed partially and will
+ * return the rewritten aggregation expressions for both phases.
+ *
+ * The returned values for this match are as follows:
+ * - Grouping attributes for the final aggregation.
+ * - Aggregates for the final aggregation.
+ * - Grouping expressions for the partial aggregation.
+ * - Partial aggregate expressions.
+ * - Input to the aggregation.
+ */
+object PartialAggregation {
+ type ReturnType =
+ (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+ case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ // Collect all aggregate expressions.
+ val allAggregates =
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ // Collect all aggregate expressions that can be computed partially.
+ val partialAggregates =
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+
+ // Only do partial aggregation if supported by all aggregate expressions.
+ if (allAggregates.size == partialAggregates.size) {
+ // Create a map of expressions to their partial evaluations for all aggregate expressions.
+ val partialEvaluations: Map[Long, SplitEvaluation] =
+ partialAggregates.map(a => (a.id, a.asPartial)).toMap
+
+ // We need to pass all grouping expressions through so the grouping can happen a second
+ // time. However, some of them might be unnamed, so we alias them, allowing them to be
+ // referenced in the second aggregation.
+ val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
+ case n: NamedExpression => (n, n)
+ case other => (other, Alias(other, "PartialGroup")())
+ }.toMap
+
+ // Replace aggregations with a new expression that computes the result from the already
+ // computed partial evaluations and grouping values.
+ val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
+ case e: Expression if partialEvaluations.contains(e.id) =>
+ partialEvaluations(e.id).finalEvaluation
+ case e: Expression if namedGroupingExpressions.contains(e) =>
+ namedGroupingExpressions(e).toAttribute
+ }).asInstanceOf[Seq[NamedExpression]]
+
+ val partialComputation =
+ (namedGroupingExpressions.values ++
+ partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+
+ val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
+
+ Some(
+ (namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child))
+ } else {
+ None
+ }
+ case _ => None
+ }
+}
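Editorial note: a sketch of how a planner strategy consumes this extractor. The shape mirrors the non-codegen branch of the HashAggregation strategy later in this patch; the strategy name is hypothetical and the body is assumed to live inside SparkStrategies so that planLater is in scope.

    object TwoPhaseAggregation extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case PartialAggregation(
            namedGroupingAttributes,        // grouping attributes for the final aggregation
            rewrittenAggregateExpressions,  // aggregates for the final aggregation
            groupingExpressions,            // grouping expressions for the partial aggregation
            partialComputation,             // partial aggregate expressions
            child) =>                       // input to the aggregation
          execution.Aggregate(
            partial = false,
            namedGroupingAttributes,
            rewrittenAggregateExpressions,
            execution.Aggregate(
              partial = true,
              groupingExpressions,
              partialComputation,
              planLater(child))) :: Nil
        case _ => Nil
      }
    }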
+
+
+/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ac85f95b52..888cb08e95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -112,7 +112,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
self: Product =>
override lazy val statistics: Statistics =
- throw new UnsupportedOperationException("default leaf nodes don't have meaningful Statistics")
+ throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
// Leaf nodes by definition cannot reference any input attributes.
override def references = Set.empty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index a357c6ffb8..481a5a4f21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -35,7 +35,7 @@ abstract class Command extends LeafNode {
*/
case class NativeCommand(cmd: String) extends Command {
override def output =
- Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)()))
+ Seq(AttributeReference("result", StringType, nullable = false)())
}
/**
@@ -43,7 +43,7 @@ case class NativeCommand(cmd: String) extends Command {
*/
case class SetCommand(key: Option[String], value: Option[String]) extends Command {
override def output = Seq(
- BoundReference(1, AttributeReference("", StringType, nullable = false)()))
+ AttributeReference("", StringType, nullable = false)())
}
/**
@@ -52,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman
*/
case class ExplainCommand(plan: LogicalPlan) extends Command {
override def output =
- Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)()))
+ Seq(AttributeReference("plan", StringType, nullable = false)())
}
/**
@@ -71,7 +71,7 @@ case class DescribeCommand(
isExtended: Boolean) extends Command {
override def output = Seq(
// Column names are based on Hive.
- BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()),
- BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()),
- BoundReference(2, AttributeReference("comment", StringType, nullable = false)()))
+ AttributeReference("col_name", StringType, nullable = false)(),
+ AttributeReference("data_type", StringType, nullable = false)(),
+ AttributeReference("comment", StringType, nullable = false)())
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index e32adb76fe..e300bdbece 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -72,7 +72,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
}
iteration += 1
if (iteration > batch.strategy.maxIterations) {
- logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}")
+ // Only log if this batch is supposed to run more than once.
+ if (iteration != 2) {
+ logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+ }
continue = false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index cd4b5e9c1b..71808f76d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -23,16 +23,13 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
import scala.util.parsing.combinator.RegexParsers
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
/**
- * A JVM-global lock that should be used to prevent thread safety issues when using things in
- * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for
- * 2.10.* builds. See SI-6240 for more details.
+ * Utility functions for working with DataTypes.
*/
-protected[catalyst] object ScalaReflectionLock
-
object DataType extends RegexParsers {
protected lazy val primitiveType: Parser[DataType] =
"StringType" ^^^ StringType |
@@ -99,6 +96,13 @@ abstract class DataType {
case object NullType extends DataType
+object NativeType {
+ def all = Seq(
+ IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+
+ def unapply(dt: DataType): Boolean = all.contains(dt)
+}
+
trait PrimitiveType extends DataType {
override def isPrimitive = true
}
@@ -149,6 +153,10 @@ abstract class NumericType extends NativeType with PrimitiveType {
val numeric: Numeric[JvmType]
}
+object NumericType {
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+}
+
/** Matcher for any expressions that evaluate to [[IntegralType]]s */
object IntegralType {
def unapply(a: Expression): Boolean = a match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 58f8c341e6..999c9fff38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
class ExpressionEvaluationSuite extends FunSuite {
test("literals") {
- assert((Literal(1) + Literal(1)).eval(null) === 2)
+ checkEvaluation(Literal(1), 1)
+ checkEvaluation(Literal(true), true)
+ checkEvaluation(Literal(0L), 0L)
+ checkEvaluation(Literal("test"), "test")
+ checkEvaluation(Literal(1) + Literal(1), 2)
}
/**
@@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite {
test("3VL Not") {
notTrueTable.foreach {
case (v, answer) =>
- val expr = ! Literal(v, BooleanType)
- val result = expr.eval(null)
- if (result != answer)
- fail(s"$expr should not evaluate to $result, expected: $answer") }
+ checkEvaluation(!Literal(v, BooleanType), answer)
+ }
}
booleanLogicTest("AND", _ && _,
@@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite {
}
}
+ test("IN") {
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
+ checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
+ checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
@@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
- checkEvaluation("23" cast DoubleType, 23)
+ checkEvaluation("23" cast DoubleType, 23d)
checkEvaluation("23" cast IntegerType, 23)
- checkEvaluation("23" cast FloatType, 23)
- checkEvaluation("23" cast DecimalType, 23)
- checkEvaluation("23" cast ByteType, 23)
- checkEvaluation("23" cast ShortType, 23)
+ checkEvaluation("23" cast FloatType, 23f)
+ checkEvaluation("23" cast DecimalType, 23: BigDecimal)
+ checkEvaluation("23" cast ByteType, 23.toByte)
+ checkEvaluation("23" cast ShortType, 23.toShort)
checkEvaluation("2012-12-11" cast DoubleType, null)
checkEvaluation(Literal(123) cast IntegerType, 123)
- checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24)
+ checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d)
checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
- checkEvaluation(Literal(23f) + Cast(true, FloatType), 24)
- checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24)
- checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24)
- checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
+ checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f)
+ checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal)
+ checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte)
+ checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort)
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
@@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite {
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)
- checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ checkEvaluation(GetItem(BoundReference(3, typeMap, true),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ checkEvaluation(GetItem(BoundReference(3, typeMap, true),
Literal(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(null, IntegerType)), null, row)
- checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
+ checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
val typeS_notNullable = StructType(
@@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite {
:: StructField("b", StringType, nullable = false) :: Nil
)
- assert(GetField(BoundReference(2,
- AttributeReference("c", typeS)()), "a").nullable === true)
- assert(GetField(BoundReference(2,
- AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false)
+ assert(GetField(BoundReference(2, typeS, nullable = true), "a").nullable === true)
+ assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
new file mode 100644
index 0000000000..245a2e1480
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Overrides our expression evaluation tests to use code generation for evaluation.
+ */
+class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ val plan = try {
+ GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)()
+ } catch {
+ case e: Throwable =>
+ val evaluated = GenerateProjection.expressionEvaluator(expression)
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code.mkString("\n")}
+ |$e
+ """.stripMargin)
+ }
+
+ val actual = plan(inputRow).apply(0)
+ if (actual != expected) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+
+
+ test("multithreaded eval") {
+ import scala.concurrent._
+ import ExecutionContext.Implicits.global
+ import scala.concurrent.duration._
+
+ val futures = (1 to 20).map { _ =>
+ future {
+ GeneratePredicate(EqualTo(Literal(1), Literal(1)))
+ GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil)
+ }
+ }
+
+ futures.foreach(Await.result(_, 10.seconds))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
new file mode 100644
index 0000000000..887aabb1d5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Overrides our expression evaluation tests to use generated code on mutable rows.
+ */
+class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
+
+ val plan = try {
+ GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil)
+ } catch {
+ case e: Throwable =>
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code.mkString("\n")}
+ |$e
+ """.stripMargin)
+ }
+
+ val actual = plan(inputRow)
+ val expectedRow = new GenericRow(Array[Any](expected))
+ if (actual.hashCode() != expectedRow.hashCode()) {
+ fail(
+ s"""
+ |Mismatched hashCodes for values: $actual, $expectedRow
+ |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
+ |${evaluated.code.mkString("\n")}
+ """.stripMargin)
+ }
+ if (actual != expectedRow) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index 4896f1b955..e2ae0d25db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
- Batch("Combine Limit", FixedPoint(2),
+ Batch("Combine Limit", FixedPoint(10),
CombineLimits) ::
- Batch("Constant Folding", FixedPoint(3),
+ Batch("Constant Folding", FixedPoint(10),
NullPropagation,
ConstantFolding,
BooleanSimplification) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 5d85a0fd4e..2d407077be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -24,8 +24,11 @@ import scala.collection.JavaConverters._
object SQLConf {
val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
- val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
+ val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
+ val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
+ val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
+ val CODEGEN_ENABLED = "spark.sql.codegen"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -57,6 +60,18 @@ trait SQLConf {
private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
/**
+ * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
+ * that evaluates expressions found in queries. In general, this custom code runs much faster
+ * than interpreted evaluation, but there are significant start-up costs due to compilation.
+ * As a result, codegen is only beneficial when queries run for a long time, or when the same
+ * expressions are used multiple times.
+ *
+ * Defaults to false as this feature is currently experimental.
+ */
+ private[spark] def codegenEnabled: Boolean =
+ get(CODEGEN_ENABLED, "false") == "true"
+
+ /**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
* effectively disables auto conversion.
@@ -111,5 +126,5 @@ trait SQLConf {
private[spark] def clear() {
settings.clear()
}
-
}
+
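Editorial note: for reference, flipping the new flag at runtime might look like the following sketch, assuming SQLConf's set(key, value) setter and a registered table named "events" (both illustrative).

    sqlContext.set("spark.sql.codegen", "true")    // opt in to generated evaluation
    sqlContext.sql("SELECT COUNT(*) FROM events").collect()

    sqlContext.set("spark.sql.codegen", "false")   // back to interpreted evaluation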
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c2bdef7323..e4b6810180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def parquetFile(path: String): SchemaRDD =
- new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration)))
+ new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
/**
* Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]].
@@ -160,7 +160,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
conf: Configuration = new Configuration()): SchemaRDD = {
new SchemaRDD(
this,
- ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf))
+ ParquetRelation.createEmpty(
+ path, ScalaReflection.attributesFor[A], allowExisting, conf, this))
}
/**
@@ -228,12 +229,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
val sqlContext: SQLContext = self
+ def codegenEnabled = self.codegenEnabled
+
def numPartitions = self.numShufflePartitions
val strategies: Seq[Strategy] =
CommandStrategy(self) ::
TakeOrdered ::
- PartialAggregation ::
+ HashAggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
@@ -291,27 +294,30 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1)
/**
- * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and
- * inserting shuffle operations as needed.
+ * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
*/
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
- Batch("Add exchange", Once, AddExchange(self)) ::
- Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil
+ Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
/**
+ * :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
* access to the intermediate phases of query execution for developers.
*/
+ @DeveloperApi
protected abstract class QueryExecution {
def logical: LogicalPlan
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
- lazy val sparkPlan = planner(optimizedPlan).next()
+ lazy val sparkPlan = {
+ SparkPlan.currentContext.set(self)
+ planner(optimizedPlan).next()
+ }
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
@@ -331,6 +337,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
|${stringOrError(optimizedPlan)}
|== Physical Plan ==
|${stringOrError(executedPlan)}
+ |Code Generation: ${executedPlan.codegenEnabled}
+ |== RDD ==
+ |${stringOrError(toRdd.toDebugString)}
""".stripMargin.trim
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 806097c917..85726bae54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -72,7 +72,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
conf: Configuration = new Configuration()): JavaSchemaRDD = {
new JavaSchemaRDD(
sqlContext,
- ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf))
+ ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext))
}
/**
@@ -101,7 +101,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
def parquetFile(path: String): JavaSchemaRDD =
new JavaSchemaRDD(
sqlContext,
- ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration)))
+ ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext))
/**
* Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index c1ced8bfa4..463a1d32d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -42,8 +42,8 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: SparkPlan)(@transient sqlContext: SQLContext)
- extends UnaryNode with NoBind {
+ child: SparkPlan)
+ extends UnaryNode {
override def requiredChildDistribution =
if (partial) {
@@ -56,8 +56,6 @@ case class Aggregate(
}
}
- override def otherCopyArgs = sqlContext :: Nil
-
// HACK: Generators don't correctly preserve their output through serializations so we grab
// out child's output attributes statically here.
private[this] val childOutput = child.output
@@ -138,7 +136,7 @@ case class Aggregate(
i += 1
}
}
- val resultProjection = new Projection(resultExpressions, computedSchema)
+ val resultProjection = new InterpretedProjection(resultExpressions, computedSchema)
val aggregateResults = new GenericMutableRow(computedAggregates.length)
var i = 0
@@ -152,7 +150,7 @@ case class Aggregate(
} else {
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
- val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+ val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput)
var currentRow: Row = null
while (iter.hasNext) {
@@ -175,7 +173,8 @@ case class Aggregate(
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
- new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+ new InterpretedMutableProjection(
+ resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val joinedRow = new JoinedRow
override final def hasNext: Boolean = hashTableIter.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 00010ef6e7..392a7f3be3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind {
+case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
override def outputPartitioning = newPartitioning
@@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
val rdd = child.execute().mapPartitions { iter =>
- val hashExpressions = new MutableProjection(expressions, child.output)
+ @transient val hashExpressions =
+ newMutableProjection(expressions, child.output)()
+
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 47b3d00262..c386fd121c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -47,23 +47,26 @@ case class Generate(
}
}
- override def output =
+ // This must be a val since the generator output expr ids are not preserved by serialization.
+ override val output =
if (join) child.output ++ generatorOutput else generatorOutput
+ val boundGenerator = BindReferences.bindReference(generator, child.output)
+
override def execute() = {
if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
// Used to produce rows with no matches when outer = true.
val outerProjection =
- new Projection(child.output ++ nullValues, child.output)
+ newProjection(child.output ++ nullValues, child.output)
val joinProjection =
- new Projection(child.output ++ generator.output, child.output ++ generator.output)
+ newProjection(child.output ++ generator.output, child.output ++ generator.output)
val joinedRow = new JoinedRow
iter.flatMap {row =>
- val outputRows = generator.eval(row)
+ val outputRows = boundGenerator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
@@ -72,7 +75,7 @@ case class Generate(
}
}
} else {
- child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
+ child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
new file mode 100644
index 0000000000..4a26934c49
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types._
+
+case class AggregateEvaluation(
+ schema: Seq[Attribute],
+ initialValues: Seq[Expression],
+ update: Seq[Expression],
+ result: Expression)
+
+/**
+ * :: DeveloperApi ::
+ * Alternate version of aggregation that leverages projection and thus code generation.
+ * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
+ * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
+ *
+ * @param partial if true, then aggregation is done partially on local data without shuffling to
+ * ensure all values where `groupingExpressions` are equal are present.
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param child the input data source.
+ */
+@DeveloperApi
+case class GeneratedAggregate(
+ partial: Boolean,
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def requiredChildDistribution =
+ if (partial) {
+ UnspecifiedDistribution :: Nil
+ } else {
+ if (groupingExpressions == Nil) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingExpressions) :: Nil
+ }
+ }
+
+ override def output = aggregateExpressions.map(_.toAttribute)
+
+ override def execute() = {
+ val aggregatesToCompute = aggregateExpressions.flatMap { a =>
+ a.collect { case agg: AggregateExpression => agg}
+ }
+
+ val computeFunctions = aggregatesToCompute.map {
+ case c @ Count(expr) =>
+ val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+ val initialValue = Literal(0L)
+ val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val result = currentCount
+
+ AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+ case Sum(expr) =>
+ val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val initialValue = Cast(Literal(0L), expr.dataType)
+
+ // Coalesce avoids double calculation...
+ // but really, common subexpression elimination would be better...
+ val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
+ val result = currentSum
+
+ AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+ case a @ Average(expr) =>
+ val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+ val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val initialCount = Literal(0L)
+ val initialSum = Cast(Literal(0L), expr.dataType)
+ val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
+
+ val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType))
+
+ AggregateEvaluation(
+ currentCount :: currentSum :: Nil,
+ initialCount :: initialSum :: Nil,
+ updateCount :: updateSum :: Nil,
+ result
+ )
+ }
+
+ val computationSchema = computeFunctions.flatMap(_.schema)
+
+ val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
+ case (agg, func) => agg.id -> func.result
+ }.toMap
+
+ val namedGroups = groupingExpressions.zipWithIndex.map {
+ case (ne: NamedExpression, _) => (ne, ne)
+ case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+ }
+
+ val groupMap: Map[Expression, Attribute] =
+ namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
+
+ // The set of expressions that produce the final output given the aggregation buffer and the
+ // grouping expressions.
+ val resultExpressions = aggregateExpressions.map(_.transform {
+ case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
+ case e: Expression if groupMap.contains(e) => groupMap(e)
+ })
+
+ child.execute().mapPartitions { iter =>
+ // Builds a new custom class for holding the results of aggregation for a group.
+ val initialValues = computeFunctions.flatMap(_.initialValues)
+ val newAggregationBuffer = newProjection(initialValues, child.output)
+ log.info(s"Initial values: ${initialValues.mkString(",")}")
+
+ // A projection that computes the group given an input tuple.
+ val groupProjection = newProjection(groupingExpressions, child.output)
+ log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
+
+ // A projection that is used to update the aggregate values for a group given a new tuple.
+ // This projection should be targeted at the current values for the group and then applied
+ // to a joined row of the current values with the new input row.
+ val updateExpressions = computeFunctions.flatMap(_.update)
+ val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
+ val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
+ log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
+
+ // A projection that produces the final result, given a computation.
+ val resultProjectionBuilder =
+ newMutableProjection(
+ resultExpressions,
+ (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+ log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
+
+ val joinedRow = new JoinedRow
+
+ if (groupingExpressions.isEmpty) {
+ // TODO: Code-generating anything other than the updateProjection is probably overkill.
+ val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+ var currentRow: Row = null
+ updateProjection.target(buffer)
+
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ updateProjection(joinedRow(buffer, currentRow))
+ }
+
+ val resultProjection = resultProjectionBuilder()
+ Iterator(resultProjection(buffer))
+ } else {
+ val buffers = new java.util.HashMap[Row, MutableRow]()
+
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ val currentGroup = groupProjection(currentRow)
+ var currentBuffer = buffers.get(currentGroup)
+ if (currentBuffer == null) {
+ currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+ buffers.put(currentGroup, currentBuffer)
+ }
+ // Target the projection at the current aggregation buffer and then project the updated
+ // values.
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
+ }
+
+ new Iterator[Row] {
+ private[this] val resultIterator = buffers.entrySet.iterator()
+ private[this] val resultProjection = resultProjectionBuilder()
+
+ def hasNext = resultIterator.hasNext
+
+ def next() = {
+ val currentGroup = resultIterator.next()
+ resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
+ }
+ }
+ }
+ }
+ }
+}
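Editorial note: a sketch of the "aggregation as projection" idea for SUM(x) over an Int column, spelled out with the same expression classes used in the Sum case above; the attribute names are illustrative.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val x = AttributeReference("x", IntegerType)()
    val currentSum = AttributeReference("currentSum", IntegerType, nullable = false)()

    val initialValue = Cast(Literal(0L), IntegerType)                 // buffer start value
    val update = Coalesce(Add(x, currentSum) :: currentSum :: Nil)    // (buffer ++ input) => buffer
    val result = currentSum                                           // final answer per group

    // With codegen enabled, these three expression lists are compiled into specialized
    // projections via newProjection / newMutableProjection rather than interpreted eval.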
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 77c874d031..21cbbc9772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -18,22 +18,55 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Logging, Row, SQLContext}
+
+
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
+
+object SparkPlan {
+ protected[sql] val currentContext = new ThreadLocal[SQLContext]()
+}
+
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
+abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
self: Product =>
+ /**
+ * A handle to the SQLContext that was used to create this plan. Since many operators need
+ * access to the sqlContext for RDD operations or configuration, this field is automatically
+ * populated by the query planning infrastructure.
+ */
+ @transient
+ protected val sqlContext = SparkPlan.currentContext.get()
+
+ protected def sparkContext = sqlContext.sparkContext
+
+ // sqlContext will be null when we are being deserialized on the slaves. In this instance
+ // the value of codegenEnabled will be set by the deserializer after the constructor has run.
+ val codegenEnabled: Boolean = if (sqlContext != null) {
+ sqlContext.codegenEnabled
+ } else {
+ false
+ }
+
+ /** Overridden makeCopy also propagates sqlContext to the copied plan. */
+ override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+ SparkPlan.currentContext.set(sqlContext)
+ super.makeCopy(newArgs)
+ }
+
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
*/
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
- protected def buildRow(values: Seq[Any]): Row =
- new GenericRow(values.toArray)
+ protected def newProjection(
+ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+ log.debug(
+ s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ GenerateProjection(expressions, inputSchema)
+ } else {
+ new InterpretedProjection(expressions, inputSchema)
+ }
+ }
+
+ protected def newMutableProjection(
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): () => MutableProjection = {
+ log.debug(
+ s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ GenerateMutableProjection(expressions, inputSchema)
+ } else {
+ () => new InterpretedMutableProjection(expressions, inputSchema)
+ }
+ }
+
+
+ protected def newPredicate(
+ expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
+ if (codegenEnabled) {
+ GeneratePredicate(expression, inputSchema)
+ } else {
+ InterpretedPredicate(expression, inputSchema)
+ }
+ }
+
+ protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
+ if (codegenEnabled) {
+ GenerateOrdering(order, inputSchema)
+ } else {
+ new RowOrdering(order, inputSchema)
+ }
+ }
}
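Editorial note: a sketch of how an operator uses the factories above; it mirrors the rewritten Project operator later in this patch, and the operator name is hypothetical (assumed to live in org.apache.spark.sql.execution).

    case class EvaluateExpressions(expressions: Seq[NamedExpression], child: SparkPlan)
      extends UnaryNode {

      override def output = expressions.map(_.toAttribute)

      // Decided on the driver: generated or interpreted, based on codegenEnabled.
      @transient lazy val buildProjection = newMutableProjection(expressions, child.output)

      override def execute() = child.execute().mapPartitions { iter =>
        val reusableProjection = buildProjection()   // one mutable projection per partition
        iter.map(reusableProjection)
      }
    }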
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 404d48ae05..5f1fe99f75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution
-import scala.util.Try
-
import org.apache.spark.sql.{SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
@@ -41,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
- planLater(left), planLater(right), condition)(sqlContext) :: Nil
+ planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -60,6 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* will instead be used to decide the build side in a [[execution.ShuffledHashJoin]].
*/
object HashJoin extends Strategy with PredicateHelper {
+
private[this] def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@@ -68,24 +67,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition: Option[Expression],
side: BuildSide) = {
val broadcastHashJoin = execution.BroadcastHashJoin(
- leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
+ leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if Try(sqlContext.autoBroadcastJoinThreshold > 0 &&
- right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) =>
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if Try(sqlContext.autoBroadcastJoinThreshold > 0 &&
- left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) =>
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
- if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) {
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
BuildRight
} else {
BuildLeft
@@ -99,65 +98,65 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object PartialAggregation extends Strategy {
+ object HashAggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
- // Collect all aggregate expressions.
- val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
- // Collect all aggregate expressions that can be computed partially.
- val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })
-
- // Only do partial aggregation if supported by all aggregate expressions.
- if (allAggregates.size == partialAggregates.size) {
- // Create a map of expressions to their partial evaluations for all aggregate expressions.
- val partialEvaluations: Map[Long, SplitEvaluation] =
- partialAggregates.map(a => (a.id, a.asPartial)).toMap
-
- // We need to pass all grouping expressions though so the grouping can happen a second
- // time. However some of them might be unnamed so we alias them allowing them to be
- // referenced in the second aggregation.
- val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
- case n: NamedExpression => (n, n)
- case other => (other, Alias(other, "PartialGroup")())
- }.toMap
+ // Aggregations that can be performed in two phases, before and after the shuffle.
- // Replace aggregations with a new expression that computes the result from the already
- // computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
- case e: Expression if partialEvaluations.contains(e.id) =>
- partialEvaluations(e.id).finalEvaluation
- case e: Expression if namedGroupingExpressions.contains(e) =>
- namedGroupingExpressions(e).toAttribute
- }).asInstanceOf[Seq[NamedExpression]]
-
- val partialComputation =
- (namedGroupingExpressions.values ++
- partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
-
- // Construct two phased aggregation.
- execution.Aggregate(
+ // Cases where all aggregates can be codegened.
+ case PartialAggregation(
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child)
+ if canBeCodeGened(
+ allAggregates(partialComputation) ++
+ allAggregates(rewrittenAggregateExpressions)) &&
+ codegenEnabled =>
+ execution.GeneratedAggregate(
partial = false,
- namedGroupingExpressions.values.map(_.toAttribute).toSeq,
+ namedGroupingAttributes,
rewrittenAggregateExpressions,
- execution.Aggregate(
+ execution.GeneratedAggregate(
partial = true,
groupingExpressions,
partialComputation,
- planLater(child))(sqlContext))(sqlContext) :: Nil
- } else {
- Nil
- }
+ planLater(child))) :: Nil
+
+ // Cases where some aggregates cannot be codegened.
+ case PartialAggregation(
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child) =>
+ execution.Aggregate(
+ partial = false,
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ execution.Aggregate(
+ partial = true,
+ groupingExpressions,
+ partialComputation,
+ planLater(child))) :: Nil
+
case _ => Nil
}
+
+ def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
+ case _: Sum | _: Count => false
+ case _ => true
+ }
+
+ def allAggregates(exprs: Seq[Expression]) =
+ exprs.flatMap(_.collect { case a: AggregateExpression => a })
}
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
+ planLater(left), planLater(right), joinType, condition) :: Nil
case _ => Nil
}
}
@@ -176,16 +175,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
- def convertToCatalyst(a: Any): Any = a match {
- case s: Seq[Any] => s.map(convertToCatalyst)
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case other => other
- }
-
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
+ execution.TakeOrdered(limit, order, planLater(child)) :: Nil
case _ => Nil
}
}
@@ -195,11 +188,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// TODO: need to support writing to other types of files. Unify the below code paths.
case logical.WriteToFile(path, child) =>
val relation =
- ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
+ ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
- InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
+ InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
- InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
+ InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -228,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
- ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
+ ParquetTableScan(_, relation, filters)) :: Nil
case _ => Nil
}
@@ -266,20 +259,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
- execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
+ execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
- val dataAsRdd =
- sparkContext.parallelize(data.map(r =>
- new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
- execution.ExistingRdd(output, dataAsRdd) :: Nil
+ ExistingRdd(
+ output,
+ ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
- execution.Limit(limit, planLater(child))(sqlContext) :: Nil
+ execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
- execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
- case logical.Except(left,right) =>
- execution.Except(planLater(left),planLater(right)) :: Nil
+ execution.Union(unionChildren.map(planLater)) :: Nil
+ case logical.Except(left, right) =>
+ execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
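The two HashAggregation cases above emit either GeneratedAggregate or the interpreted Aggregate, and the choice hinges on the canBeCodeGened/allAggregates helpers: a plan qualifies for codegen only when every aggregate in both the partial and final phases is a Sum or Count. Below is a small self-contained Scala sketch of that eligibility check; the expression classes are simplified stand-ins, not Catalyst's actual AggregateExpression hierarchy.

// Illustrative stand-ins for Catalyst expressions; only the shape of the check matters here.
object CodegenEligibilitySketch extends App {
  sealed trait Expr { def children: Seq[Expr] = Nil }
  sealed trait AggExpr extends Expr
  case class Column(name: String) extends Expr
  case class Sum(child: Expr) extends AggExpr { override def children = Seq(child) }
  case class Count(child: Expr) extends AggExpr { override def children = Seq(child) }
  case class CountDistinct(child: Expr) extends AggExpr { override def children = Seq(child) }

  // Collect every aggregate appearing anywhere inside the given expression trees.
  def allAggregates(exprs: Seq[Expr]): Seq[AggExpr] =
    exprs.flatMap { e =>
      (e match { case a: AggExpr => Seq(a); case _ => Nil }) ++ allAggregates(e.children)
    }

  // Only Sum and Count have generated implementations; anything else disqualifies the plan.
  def canBeCodeGened(aggs: Seq[AggExpr]): Boolean = !aggs.exists {
    case _: Sum | _: Count => false
    case _                 => true
  }

  println(canBeCodeGened(allAggregates(Seq(Sum(Column("a")), Count(Column("b"))))))          // true
  println(canBeCodeGened(allAggregates(Seq(Sum(Column("a")), CountDistinct(Column("b"))))))  // false
}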
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 966d8f95fc..174eda8f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
override def output = projectList.map(_.toAttribute)
- override def execute() = child.execute().mapPartitions { iter =>
- @transient val reusableProjection = new MutableProjection(projectList)
- iter.map(reusableProjection)
+ @transient lazy val buildProjection = newMutableProjection(projectList, child.output)
+
+ def execute() = child.execute().mapPartitions { iter =>
+    val reusableProjection = buildProjection()
+    iter.map(reusableProjection)
}
}
@@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def output = child.output
- override def execute() = child.execute().mapPartitions { iter =>
- iter.filter(condition.eval(_).asInstanceOf[Boolean])
+ @transient lazy val conditionEvaluator = newPredicate(condition, child.output)
+
+ def execute() = child.execute().mapPartitions { iter =>
+ iter.filter(conditionEvaluator)
}
}
@@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan]) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output = children.head.output
- override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
-
- override def otherCopyArgs = sqlContext :: Nil
+ override def execute() = sparkContext.union(children.map(_.execute()))
}
/**
@@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex
* repartition all the data to a single partition to compute the global limit.
*/
@DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+case class Limit(limit: Int, child: SparkPlan)
extends UnaryNode {
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again
- override def otherCopyArgs = sqlContext :: Nil
-
override def output = child.output
/**
@@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
* Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
*/
@DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sqlContext: SQLContext) extends UnaryNode {
- override def otherCopyArgs = sqlContext :: Nil
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
override def output = child.output
- @transient
- lazy val ordering = new RowOrdering(sortOrder)
+ val ordering = new RowOrdering(sortOrder, child.output)
+ // TODO: Is this copying for no reason?
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
+ override def execute() = sparkContext.makeRDD(executeCollect(), 1)
}
/**
@@ -189,15 +187,13 @@ case class Sort(
override def requiredChildDistribution =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- @transient
- lazy val ordering = new RowOrdering(sortOrder)
override def execute() = attachTree(this, "sort") {
- // TODO: Optimize sorting operation?
child.execute()
- .mapPartitions(
- iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
- preservesPartitioning = true)
+ .mapPartitions( { iterator =>
+ val ordering = newOrdering(sortOrder, child.output)
+ iterator.map(_.copy()).toArray.sorted(ordering).iterator
+ }, preservesPartitioning = true)
}
override def output = child.output
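Project and Filter above no longer build a MutableProjection or call eval per row; they ask a factory (newMutableProjection, newPredicate) for an evaluator, which is where generated code can be substituted for the interpreted implementation. The following is a minimal sketch of that indirection under simplified assumptions; Row and the factory names here are illustrative, not Spark's actual signatures.

// Illustrative sketch of a "generated or interpreted" evaluator factory; not Spark's API.
object EvaluatorFactorySketch extends App {
  type Row = Seq[Any]

  // Interpreted projection: look up the requested columns for every input row.
  def interpretedProjection(columns: Seq[Int]): Row => Row =
    row => columns.map(row(_))

  // Placeholder for a compiled projection; here it simply reuses the interpreted one.
  def generatedProjection(columns: Seq[Int]): Row => Row =
    interpretedProjection(columns)

  // The operator keeps the factory; each partition invokes it once to get its own evaluator.
  def newMutableProjection(columns: Seq[Int], codegenEnabled: Boolean): () => (Row => Row) =
    () => if (codegenEnabled) generatedProjection(columns) else interpretedProjection(columns)

  val buildProjection = newMutableProjection(Seq(2, 0), codegenEnabled = false)
  val project = buildProjection()          // one evaluator per partition, reused for every row
  println(project(Seq("a", "b", "c")))     // List(c, a)
}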
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index c6fbd6d2f6..5ef46c32d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -41,13 +41,13 @@ package object debug {
*/
@DeveloperApi
implicit class DebugQuery(query: SchemaRDD) {
- def debug(implicit sc: SparkContext): Unit = {
+ def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[Long]()
val debugPlan = plan transform {
case s: SparkPlan if !visited.contains(s.id) =>
visited += s.id
- DebugNode(sc, s)
+ DebugNode(s)
}
println(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
@@ -57,9 +57,7 @@ package object debug {
}
}
- private[sql] case class DebugNode(
- @transient sparkContext: SparkContext,
- child: SparkPlan) extends UnaryNode {
+ private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
def references = Set.empty
def output = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 7d1f11caae..2750ddbce8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide
trait HashJoin {
+ self: SparkPlan =>
+
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
@@ -56,9 +58,9 @@ trait HashJoin {
def output = left.output ++ right.output
- @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+ @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
- () => new MutableProjection(streamedKeys, streamedPlan.output)
+ newMutableProjection(streamedKeys, streamedPlan.output)
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation.
@@ -217,9 +219,8 @@ case class BroadcastHashJoin(
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
- right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {
+ right: SparkPlan) extends BinaryNode with HashJoin {
- override def otherCopyArgs = sqlContext :: Nil
override def outputPartitioning: Partitioning = left.outputPartitioning
@@ -228,7 +229,7 @@ case class BroadcastHashJoin(
@transient
lazy val broadcastFuture = future {
- sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
+ sparkContext.broadcast(buildPlan.executeCollect())
}
def execute() = {
@@ -248,14 +249,11 @@ case class BroadcastHashJoin(
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sqlContext :: Nil
-
def output = left.output
/** The Streamed Relation */
@@ -271,7 +269,7 @@ case class LeftSemiJoinBNL(
def execute() = {
val broadcastedRelation =
- sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
@@ -300,8 +298,14 @@ case class LeftSemiJoinBNL(
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
def output = left.output ++ right.output
- def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map {
- case (l: Row, r: Row) => buildRow(l ++ r)
+ def execute() = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.cartesian(rightResults).mapPartitions { iter =>
+ val joinedRow = new JoinedRow
+ iter.map(r => joinedRow(r._1, r._2))
+ }
}
}
@@ -311,14 +315,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
@DeveloperApi
case class BroadcastNestedLoopJoin(
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
- (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sqlContext :: Nil
-
override def output = {
joinType match {
case LeftOuter =>
@@ -345,13 +346,14 @@ case class BroadcastNestedLoopJoin(
def execute() = {
val broadcastedRelation =
- sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
+ val rightNulls = new GenericMutableRow(right.output.size)
streamedIter.foreach { streamedRow =>
var i = 0
@@ -361,7 +363,7 @@ case class BroadcastNestedLoopJoin(
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
- matchedRows += buildRow(streamedRow ++ broadcastedRow)
+ matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
matched = true
includedBroadcastTuples += i
}
@@ -369,7 +371,7 @@ case class BroadcastNestedLoopJoin(
}
if (!matched && (joinType == LeftOuter || joinType == FullOuter)) {
- matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null))
+ matchedRows += joinedRow(streamedRow, rightNulls).copy()
}
}
Iterator((matchedRows, includedBroadcastTuples))
@@ -383,20 +385,20 @@ case class BroadcastNestedLoopJoin(
streamedPlusMatches.map(_._2).reduce(_ ++ _)
}
+ val leftNulls = new GenericMutableRow(left.output.size)
val rightOuterMatches: Seq[Row] =
if (joinType == RightOuter || joinType == FullOuter) {
broadcastedRelation.value.zipWithIndex.filter {
case (row, i) => !allIncludedBroadcastTuples.contains(i)
}.map {
- // TODO: Use projection.
- case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row)
+ case (row, _) => new JoinedRow(leftNulls, row)
}
} else {
Vector()
}
// TODO: Breaks lineage.
- sqlContext.sparkContext.union(
- streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
+ sparkContext.union(
+ streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches))
}
}
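CartesianProduct and BroadcastNestedLoopJoin now pair rows through a reusable JoinedRow, with pre-allocated null rows for the unmatched outer side, instead of concatenating arrays per match via buildRow. A rough sketch of why the wrapper helps, using Seq[Any] in place of Spark's Row type:

// Simplified JoinedRow: presents two rows as one and copies values only when copy() is called.
object JoinedRowSketch extends App {
  class JoinedRow {
    private var left: Seq[Any] = Nil
    private var right: Seq[Any] = Nil
    // Re-point the same wrapper at new inputs; no allocation per joined pair.
    def apply(l: Seq[Any], r: Seq[Any]): JoinedRow = { left = l; right = r; this }
    def get(i: Int): Any = if (i < left.length) left(i) else right(i - left.length)
    // Materialize only when the joined row must outlive the current iteration.
    def copy(): Seq[Any] = left ++ right
  }

  val joined = new JoinedRow
  val rightNulls = Seq.fill[Any](2)(null)              // reusable padding for unmatched rows
  println(joined(Seq(1, "x"), Seq(2.0, true)).get(2))  // 2.0
  println(joined(Seq(1, "x"), rightNulls).copy())      // List(1, x, null, null)
}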
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 8c7dbd5eb4..b3bae5db0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
*/
private[sql] case class ParquetRelation(
path: String,
- @transient conf: Option[Configuration] = None)
+ @transient conf: Option[Configuration],
+ @transient sqlContext: SQLContext)
extends LeafNode with MultiInstanceRelation {
self: Product =>
@@ -61,7 +62,7 @@ private[sql] case class ParquetRelation(
/** Attributes */
override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf)
- override def newInstance = ParquetRelation(path).asInstanceOf[this.type]
+ override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
// Equals must also take into account the output attributes so that we can distinguish between
// different instances of the same relation,
@@ -70,6 +71,9 @@ private[sql] case class ParquetRelation(
p.path == path && p.output == output
case _ => false
}
+
+ // TODO: Use data from the footers.
+ override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes)
}
private[sql] object ParquetRelation {
@@ -106,13 +110,14 @@ private[sql] object ParquetRelation {
*/
def create(pathString: String,
child: LogicalPlan,
- conf: Configuration): ParquetRelation = {
+ conf: Configuration,
+ sqlContext: SQLContext): ParquetRelation = {
if (!child.resolved) {
throw new UnresolvedException[LogicalPlan](
child,
"Attempt to create Parquet table from unresolved child (when schema is not available)")
}
- createEmpty(pathString, child.output, false, conf)
+ createEmpty(pathString, child.output, false, conf, sqlContext)
}
/**
@@ -127,14 +132,15 @@ private[sql] object ParquetRelation {
def createEmpty(pathString: String,
attributes: Seq[Attribute],
allowExisting: Boolean,
- conf: Configuration): ParquetRelation = {
+ conf: Configuration,
+ sqlContext: SQLContext): ParquetRelation = {
val path = checkPath(pathString, allowExisting, conf)
if (conf.get(ParquetOutputFormat.COMPRESSION) == null) {
conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name())
}
ParquetRelation.enableLogForwarding()
ParquetTypesConverter.writeMetaData(attributes, path, conf)
- new ParquetRelation(path.toString, Some(conf)) {
+ new ParquetRelation(path.toString, Some(conf), sqlContext) {
override val output = attributes
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index ea74320d06..912a9f002b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -55,8 +55,7 @@ case class ParquetTableScan(
// https://issues.apache.org/jira/browse/SPARK-1367
output: Seq[Attribute],
relation: ParquetRelation,
- columnPruningPred: Seq[Expression])(
- @transient val sqlContext: SQLContext)
+ columnPruningPred: Seq[Expression])
extends LeafNode {
override def execute(): RDD[Row] = {
@@ -99,8 +98,6 @@ case class ParquetTableScan(
.filter(_ != null) // Parquet's record filters may produce null values
}
- override def otherCopyArgs = sqlContext :: Nil
-
/**
* Applies a (candidate) projection.
*
@@ -110,7 +107,7 @@ case class ParquetTableScan(
def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
val success = validateProjection(prunedAttributes)
if (success) {
- ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
+ ParquetTableScan(prunedAttributes, relation, columnPruningPred)
} else {
sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
this
@@ -150,8 +147,7 @@ case class ParquetTableScan(
case class InsertIntoParquetTable(
relation: ParquetRelation,
child: SparkPlan,
- overwrite: Boolean = false)(
- @transient val sqlContext: SQLContext)
+ overwrite: Boolean = false)
extends UnaryNode with SparkHadoopMapReduceUtil {
/**
@@ -171,7 +167,7 @@ case class InsertIntoParquetTable(
val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive)) {
- logger.debug("Initializing MutableRowWriteSupport")
+ log.debug("Initializing MutableRowWriteSupport")
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
} else {
classOf[org.apache.spark.sql.parquet.RowWriteSupport]
@@ -203,8 +199,6 @@ case class InsertIntoParquetTable(
override def output = child.output
- override def otherCopyArgs = sqlContext :: Nil
-
/**
* Stores the given Row RDD as a Hadoop file.
*
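The otherCopyArgs removals in this file, and in basicOperators.scala and joins.scala above, all trace back to the same Scala detail: constructor arguments in a case class's second parameter list are not part of the generated product, so equality, copy, and tree transformations ignore them unless they are re-supplied by hand. A small self-contained illustration, with SketchNode standing in for a plan node rather than any Spark class:

// Second-parameter-list arguments are invisible to the case class product and must be
// threaded through copy() manually; that is the boilerplate otherCopyArgs used to cover.
object CurriedCaseClassSketch extends App {
  case class SketchNode(limit: Int)(val context: String)

  val node = SketchNode(10)("ctx")

  println(node.productArity)                // 1: only `limit` participates in the product
  println(node == SketchNode(10)("other"))  // true: equality ignores the second list

  // copy() defaults cover the first list only; the second list must be passed again.
  val rebuilt = node.copy(limit = 5)(node.context)
  println(rebuilt.context)                  // ctx
}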
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index d4599da711..837ea7695d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
+import org.apache.spark.sql.test.TestSQLContext
import parquet.example.data.{GroupWriter, Group}
import parquet.example.data.simple.SimpleGroup
@@ -103,7 +104,7 @@ private[sql] object ParquetTestData {
val testDir = Utils.createTempDir()
val testFilterDir = Utils.createTempDir()
- lazy val testData = new ParquetRelation(testDir.toURI.toString)
+ lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext)
val testNestedSchema1 =
// based on blogpost example, source:
@@ -202,8 +203,10 @@ private[sql] object ParquetTestData {
val testNestedDir3 = Utils.createTempDir()
val testNestedDir4 = Utils.createTempDir()
- lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString)
- lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString)
+ lazy val testNestedData1 =
+ new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext)
+ lazy val testNestedData2 =
+ new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext)
def writeFile() = {
testDir.delete()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8e1e1971d9..1fd8d27b34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -45,6 +45,7 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution}
|== Exception ==
|$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 215618e852..76b1724471 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite {
test("count is partially aggregated") {
val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
- val planned = PartialAggregation(query).head
- val aggregations = planned.collect { case a: Aggregate => a }
+ val planned = HashAggregation(query).head
+ val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggregations.size === 2)
}
test("count distinct is not partially aggregated") {
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
- val planned = PartialAggregation(query)
+ val planned = HashAggregation(query)
assert(planned.isEmpty)
}
test("mixed aggregates are not partially aggregated") {
val query =
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
- val planned = PartialAggregation(query)
+ val planned = HashAggregation(query)
assert(planned.isEmpty)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
index e55648b8ed..2cab5e0c44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._
* Note: this is only a rough example of how TGFs can be expressed, the final version will likely
* involve a lot more sugar for cleaner use in Scala/Java/etc.
*/
-case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator {
+case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator {
def children = input
protected def makeOutput() = 'nameAndAge.string :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 3c911e9a4e..561f5b4a49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
+
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.util.Utils
@@ -207,10 +209,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("Projection of simple Parquet file") {
+ SparkPlan.currentContext.set(TestSQLContext)
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
- Seq())(TestSQLContext)
+ Seq())
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))
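This test now sets SparkPlan.currentContext before constructing ParquetTableScan by hand, rather than passing TestSQLContext as an extra constructor argument. Presumably the planner performs the same hand-off for every node it creates, so only manually built nodes need the explicit set. A simplified sketch of such a thread-local hand-off, assuming illustrative PlanContext/PlanNode names rather than Spark's classes:

// Sketch of a thread-local "current context" that nodes capture at construction time.
object CurrentContextSketch extends App {
  case class PlanContext(name: String)

  object PlanNode {
    val currentContext = new ThreadLocal[PlanContext]
  }

  class PlanNode {
    // Picked up from the constructing thread; no constructor argument required.
    val context: PlanContext = PlanNode.currentContext.get()
  }

  PlanNode.currentContext.set(PlanContext("test"))
  println(new PlanNode().context)   // PlanContext(test)
}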
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 84d43eaeea..f0a61270da 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -231,7 +231,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
HiveTableScans,
DataSinks,
Scripts,
- PartialAggregation,
+ HashAggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index c2b0b00aa5..39033bdeac 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -131,7 +131,7 @@ case class InsertIntoHiveTable(
conf,
SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf))
- logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
+ log.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
val writer = new SparkHiveHadoopWriter(conf, fileSinkConf)
writer.preSetup()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 8258ee5fef..0c8f676e9c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -67,7 +67,7 @@ case class ScriptTransformation(
}
}
readerThread.start()
- val outputProjection = new Projection(input)
+ val outputProjection = new InterpretedProjection(input, child.output)
iter
.map(outputProjection)
// TODO: Use SerDe
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 057eb60a02..7582b4743d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf(
@transient
protected lazy val function: GenericUDTF = createFunction()
+ @transient
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
+ @transient
protected lazy val outputInspectors = {
val structInspector = function.initialize(inputInspectors.toArray)
structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector)
@@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf(
override def eval(input: Row): TraversableOnce[Row] = {
outputInspectors // Make sure initialized.
- val inputProjection = new Projection(children)
+ val inputProjection = new InterpretedProjection(children)
val collector = new UDTFCollector
function.setCollector(collector)
@@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction(
override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
@transient
- val inputProjection = new Projection(exprs)
+ val inputProjection = new InterpretedProjection(exprs)
def update(input: Row): Unit = {
val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
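The @transient annotations added to the inspector fields follow the standard pattern for non-serializable state such as Hive ObjectInspectors: the lazy val is excluded from the serialized closure and is meant to be rebuilt on first use after deserialization. A self-contained sketch of the serialization side of that pattern, with Inspector and UdtfLike as illustrative stand-ins:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// The inspector-like field is not Serializable; @transient lazy val keeps it out of the
// serialized form entirely and defers its construction until it is first needed.
class Inspector(val detail: String)                       // stand-in for a Hive ObjectInspector

class UdtfLike(val name: String) extends Serializable {
  @transient lazy val inspector: Inspector = new Inspector(s"inspector-for-$name")
}

object TransientLazySketch extends App {
  val udtf = new UdtfLike("example")
  println(udtf.inspector.detail)                          // inspector-for-example, built lazily

  // Serialization succeeds because the non-serializable inspector field is transient.
  val out = new ObjectOutputStream(new ByteArrayOutputStream())
  out.writeObject(udtf)
  out.close()
  println("serialized without touching the inspector")
}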
diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d
new file mode 100644
index 0000000000..00750edc07
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d
@@ -0,0 +1 @@
+3
diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2
new file mode 100644
index 0000000000..00750edc07
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2
@@ -0,0 +1 @@
+3
diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index aadfd2e900..89cc589fb8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution
import scala.util.Try
+import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.{Row, SchemaRDD}
@@ -30,6 +32,15 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {
+ createQueryTest("single case",
+ """SELECT case when true then 1 else 2 end FROM src LIMIT 1""")
+
+ createQueryTest("double case",
+ """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""")
+
+ createQueryTest("case else null",
+ """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""")
+
createQueryTest("having no references",
"SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1")