Diffstat:
-rw-r--r--  pom.xml | 10
-rw-r--r--  project/SparkBuild.scala | 11
-rw-r--r--  sql/catalyst/pom.xml | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 39
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 468
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 76
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala | 98
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala | 48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 219
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala | 80
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala | 28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 71
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala | 18
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 55
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala | 69
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 61
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 200
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 138
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 6
-rw-r--r--  sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d | 1
-rw-r--r--  sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 | 1
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 11
53 files changed, 1889 insertions(+), 297 deletions(-)
diff --git a/pom.xml b/pom.xml
index 39538f9660..ae97bf03c5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -114,6 +114,7 @@
<sbt.project.name>spark</sbt.project.name>
<scala.version>2.10.4</scala.version>
<scala.binary.version>2.10</scala.binary.version>
+ <scala.macros.version>2.0.1</scala.macros.version>
<mesos.version>0.18.1</mesos.version>
<mesos.classifier>shaded-protobuf</mesos.classifier>
<akka.group>org.spark-project.akka</akka.group>
@@ -825,6 +826,15 @@
<javacArg>-target</javacArg>
<javacArg>${java.version}</javacArg>
</javacArgs>
+ <!-- The following plugin is required to use quasiquotes in Scala 2.10 and is used
+ by Spark SQL for code generation. -->
+ <compilerPlugins>
+ <compilerPlugin>
+ <groupId>org.scalamacros</groupId>
+ <artifactId>paradise_${scala.version}</artifactId>
+ <version>${scala.macros.version}</version>
+ </compilerPlugin>
+ </compilerPlugins>
</configuration>
</plugin>
<plugin>
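
The paradise compiler plugin above (together with the quasiquotes dependency added to sql/catalyst below) is what lets Catalyst build Scala ASTs with the q"..." interpolator on Scala 2.10. A minimal sketch of that mechanism outside of Spark (the object name is illustrative; it assumes scala-compiler and the quasiquotes artifact are on the runtime classpath):

    import scala.reflect.runtime.universe._
    import scala.reflect.runtime.{currentMirror => mirror}
    import scala.tools.reflect.ToolBox

    object QuasiquoteSketch {
      def main(args: Array[String]): Unit = {
        // A toolbox compiles and evaluates ASTs built at runtime, which is exactly
        // what the Catalyst code generators introduced in this patch do.
        val toolBox = mirror.mkToolBox()
        val tree: Tree = q"(x: Int) => x + 1"
        val addOne = toolBox.eval(tree).asInstanceOf[Int => Int]
        println(addOne(41)) // 42
      }
    }
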
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 0a6326e722..490fac3cc3 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -167,6 +167,9 @@ object SparkBuild extends PomBuild {
/* Enable unidoc only for the root spark project */
enable(Unidoc.settings)(spark)
+ /* Catalyst macro settings */
+ enable(Catalyst.settings)(catalyst)
+
/* Spark SQL Core console settings */
enable(SQL.settings)(sql)
@@ -189,10 +192,13 @@ object Flume {
lazy val settings = sbtavro.SbtAvro.avroSettings
}
-object SQL {
-
+object Catalyst {
lazy val settings = Seq(
+ addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full))
+}
+object SQL {
+ lazy val settings = Seq(
initialCommands in console :=
"""
|import org.apache.spark.sql.catalyst.analysis._
@@ -207,7 +213,6 @@ object SQL {
|import org.apache.spark.sql.test.TestSQLContext._
|import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin
)
-
}
object Hive {
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 531bfddbf2..54fa96baa1 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -38,9 +38,18 @@
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
+ <artifactId>scala-compiler</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>
<dependency>
+ <groupId>org.scalamacros</groupId>
+ <artifactId>quasiquotes_${scala.binary.version}</artifactId>
+ <version>${scala.macros.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5c8c810d91..f44521d638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -202,7 +202,7 @@ package object dsl {
// Protobuf terminology
def required = a.withNullability(false)
- def at(ordinal: Int) = BoundReference(ordinal, a)
+ def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 9ce1f01056..a3ebec8082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.trees
+
import org.apache.spark.sql.Logging
/**
@@ -28,61 +30,27 @@ import org.apache.spark.sql.Logging
* to be retrieved more efficiently. However, since operations like column pruning can change
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
*/
-case class BoundReference(ordinal: Int, baseReference: Attribute)
- extends Attribute with trees.LeafNode[Expression] {
+case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
+ extends Expression with trees.LeafNode[Expression] {
type EvaluatedType = Any
- override def nullable = baseReference.nullable
- override def dataType = baseReference.dataType
- override def exprId = baseReference.exprId
- override def qualifiers = baseReference.qualifiers
- override def name = baseReference.name
+ override def references = Set.empty
- override def newInstance = BoundReference(ordinal, baseReference.newInstance)
- override def withNullability(newNullability: Boolean) =
- BoundReference(ordinal, baseReference.withNullability(newNullability))
- override def withQualifiers(newQualifiers: Seq[String]) =
- BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
-
- override def toString = s"$baseReference:$ordinal"
+ override def toString = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
}
-/**
- * Used to denote operators that do their own binding of attributes internally.
- */
-trait NoBind { self: trees.TreeNode[_] => }
-
-class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
- import BindReferences._
-
- def apply(plan: TreeNode): TreeNode = {
- plan.transform {
- case n: NoBind => n.asInstanceOf[TreeNode]
- case leafNode if leafNode.children.isEmpty => leafNode
- case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
- bindReference(e, unaryNode.children.head.output)
- }
- }
- }
-}
-
object BindReferences extends Logging {
def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId)
if (ordinal == -1) {
- // TODO: This fallback is required because some operators (such as ScriptTransform)
- // produce new attributes that can't be bound. Likely the right thing to do is remove
- // this rule and require all operators to explicitly bind to the input schema that
- // they specify.
- logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
- a
+ sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
} else {
- BoundReference(ordinal, a)
+ BoundReference(ordinal, a.dataType, a.nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
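
With the fallback removed, bindReference now fails fast when an attribute cannot be found in the input schema. A minimal sketch of the simplified binding and evaluation flow (object and column names are illustrative, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    object BindReferenceSketch {
      def main(args: Array[String]): Unit = {
        // Hypothetical input schema: two integer columns a and b.
        val a = AttributeReference("a", IntegerType, nullable = false)()
        val b = AttributeReference("b", IntegerType, nullable = false)()
        val schema = Seq(a, b)

        // Binding replaces attribute references (matched by exprId) with ordinal-based
        // BoundReferences that carry only the dataType and nullability.
        val bound = BindReferences.bindReference(Add(a, b), schema)
        println(bound)                                         // (input[0] + input[1])
        println(bound.eval(new GenericRow(Array[Any](1, 2))))  // 3
      }
    }
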
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 2c71d2c7b3..8fc5896974 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.catalyst.expressions
+
/**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of the
- * new row. If the schema of the input row is specified, then the given expression will be bound to
- * that schema.
+ * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ * @param expressions a sequence of expressions that determine the value of each column of the
+ * output row.
*/
-class Projection(expressions: Seq[Expression]) extends (Row => Row) {
+class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
@@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
}
/**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of th
- * new row. If the schema of the input row is specified, then the given expression will be bound to
- * that schema.
- *
- * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
- * each time an input row is added. This significantly reduces the cost of calculating the
- * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
- * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
- * and hold on to the returned [[Row]] before calling `next()`.
+ * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
+ * expressions.
+ * @param expressions a sequence of expressions that determine the value of each column of the
+ * output row.
*/
-case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
+case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
private[this] val exprArray = expressions.toArray
- private[this] val mutableRow = new GenericMutableRow(exprArray.size)
+ private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size)
def currentValue: Row = mutableRow
- def apply(input: Row): Row = {
+ override def target(row: MutableRow): MutableProjection = {
+ mutableRow = row
+ this
+ }
+
+ override def apply(input: Row): Row = {
var i = 0
while (i < exprArray.length) {
mutableRow(i) = exprArray(i).eval(input)
@@ -76,6 +77,12 @@ class JoinedRow extends Row {
private[this] var row1: Row = _
private[this] var row2: Row = _
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
def apply(r1: Row, r2: Row): Row = {
row1 = r1
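
The target() method introduced here lets an operator supply its own reusable output buffer. A minimal sketch of the interpreted variant (names are illustrative; it assumes the spark-catalyst module from this patch):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    object MutableProjectionSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()

        // Compute a single output column: a + 1.
        val projection: MutableProjection =
          new InterpretedMutableProjection(Seq(Add(a, Literal(1))), Seq(a))

        // Point the projection at a caller-owned buffer; apply() now writes into it.
        val buffer = new GenericMutableRow(1)
        projection.target(buffer)

        projection(new GenericRow(Array[Any](41)))
        println(buffer.getInt(0)) // 42
      }
    }
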
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 74ae723686..7470cb861b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -88,15 +88,6 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
-
- /**
- * Experimental
- *
- * Returns a mutable string builder for the specified column. A given row should return the
- * result of any mutations made to the returned buffer next time getString is called for the same
- * column.
- */
- def getStringBuilder(ordinal: Int): StringBuilder
}
/**
@@ -180,6 +171,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
values(i).asInstanceOf[String]
}
+ // Custom hashCode function that matches the efficient code generated version.
+ override def hashCode(): Int = {
+ var result: Int = 37
+
+ var i = 0
+ while (i < values.length) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ apply(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+
def copy() = this
}
@@ -187,8 +207,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)
- def getStringBuilder(ordinal: Int): StringBuilder = ???
-
override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 5e089f7618..acddf5e9c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
override def eval(input: Row): Any = {
children.size match {
+ case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
function.asInstanceOf[(Any, Any) => Any](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
new file mode 100644
index 0000000000..5b398695bf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -0,0 +1,468 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import com.google.common.cache.{CacheLoader, CacheBuilder}
+
+import scala.language.existentials
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * A base class for generators of byte code to perform expression evaluation. Includes a set of
+ * helpers for referring to Catalyst types and building trees that perform evaluation of individual
+ * expressions.
+ */
+abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ import scala.tools.reflect.ToolBox
+
+ protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
+
+ protected val rowType = typeOf[Row]
+ protected val mutableRowType = typeOf[MutableRow]
+ protected val genericRowType = typeOf[GenericRow]
+ protected val genericMutableRowType = typeOf[GenericMutableRow]
+
+ protected val projectionType = typeOf[Projection]
+ protected val mutableProjectionType = typeOf[MutableProjection]
+
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ private val javaSeparator = "$"
+
+ /**
+ * Generates a class for a given input expression. Called when there is not cached code
+ * already available.
+ */
+ protected def create(in: InType): OutType
+
+ /**
+ * Canonicalizes an input expression. Used to avoid double caching expressions that differ only
+ * cosmetically.
+ */
+ protected def canonicalize(in: InType): InType
+
+ /** Binds an input expression to a given input schema */
+ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
+
+ /**
+ * A cache of generated classes.
+ *
+ * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
+ * fundamental difference is that a ConcurrentMap persists all elements that are added to it until
+ * they are explicitly removed. A Cache on the other hand is generally configured to evict entries
+ * automatically, in order to constrain its memory footprint
+ */
+ protected val cache = CacheBuilder.newBuilder()
+ .maximumSize(1000)
+ .build(
+ new CacheLoader[InType, OutType]() {
+ override def load(in: InType): OutType = globalLock.synchronized {
+ create(in)
+ }
+ })
+
+ /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
+ def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType =
+ apply(bind(expressions, inputSchema))
+
+ /** Generates the requested evaluator given already bound expression(s). */
+ def apply(expressions: InType): OutType = cache.get(canonicalize(expressions))
+
+ /**
+ * Returns a term name that is unique within this instance of a `CodeGenerator`.
+ *
+ * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
+ * function.)
+ */
+ protected def freshName(prefix: String): TermName = {
+ newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
+ }
+
+ /**
+ * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
+ *
+ * @param code The sequence of statements required to evaluate the expression.
+ * @param nullTerm A term that holds a boolean value representing whether the expression evaluated
+ * to null.
+ * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
+ * valid if `nullTerm` is set to `true`.
+ * @param objectTerm A possibly boxed version of the result of evaluating this expression.
+ */
+ protected case class EvaluatedExpression(
+ code: Seq[Tree],
+ nullTerm: TermName,
+ primitiveTerm: TermName,
+ objectTerm: TermName)
+
+ /**
+ * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
+ * can be used to determine the result of evaluating the expression on an input row.
+ */
+ def expressionEvaluator(e: Expression): EvaluatedExpression = {
+ val primitiveTerm = freshName("primitiveTerm")
+ val nullTerm = freshName("nullTerm")
+ val objectTerm = freshName("objectTerm")
+
+ implicit class Evaluate1(e: Expression) {
+ def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(dataType)}
+ else
+ ${f(eval.primitiveTerm)}
+ """.children
+ }
+ }
+
+ implicit class Evaluate2(expressions: (Expression, Expression)) {
+
+ /**
+ * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
+ * the same type. If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ *
+ * @param f a function from two primitive term names to a tree that evaluates them.
+ */
+ def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+ evaluateAs(expressions._1.dataType)(f)
+
+ def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = {
+ // TODO: Right now some timestamp tests fail if we enforce this...
+ if (expressions._1.dataType != expressions._2.dataType) {
+ log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
+ }
+
+ val eval1 = expressionEvaluator(expressions._1)
+ val eval2 = expressionEvaluator(expressions._2)
+ val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
+ val $primitiveTerm: ${termForType(resultType)} =
+ if($nullTerm) {
+ ${defaultPrimitive(resultType)}
+ } else {
+ $resultCode.asInstanceOf[${termForType(resultType)}]
+ }
+ """.children : Seq[Tree]
+ }
+ }
+
+ val inputTuple = newTermName(s"i")
+
+ // TODO: Skip generation of null handling code when expressions are not nullable.
+ val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+ case b @ BoundReference(ordinal, dataType, nullable) =>
+ val nullValue = q"$inputTuple.isNullAt($ordinal)"
+ q"""
+ val $nullTerm: Boolean = $nullValue
+ val $primitiveTerm: ${termForType(dataType)} =
+ if($nullTerm)
+ ${defaultPrimitive(dataType)}
+ else
+ ${getColumn(inputTuple, dataType, ordinal)}
+ """.children
+
+ case expressions.Literal(null, dataType) =>
+ q"""
+ val $nullTerm = true
+ val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}]
+ """.children
+
+ case expressions.Literal(value: Boolean, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: String, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: Int, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: Long, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case Cast(e @ BinaryType(), StringType) =>
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(StringType)}
+ else
+ new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ """.children
+
+ case Cast(child @ NumericType(), IntegerType) =>
+ child.castOrNull(c => q"$c.toInt", IntegerType)
+
+ case Cast(child @ NumericType(), LongType) =>
+ child.castOrNull(c => q"$c.toLong", LongType)
+
+ case Cast(child @ NumericType(), DoubleType) =>
+ child.castOrNull(c => q"$c.toDouble", DoubleType)
+
+ case Cast(child @ NumericType(), FloatType) =>
+ child.castOrNull(c => q"$c.toFloat", IntegerType)
+
+ // Special handling required for timestamps in hive test cases since the toString function
+ // does not match the expected output.
+ case Cast(e, StringType) if e.dataType != TimestampType =>
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(StringType)}
+ else
+ ${eval.primitiveTerm}.toString
+ """.children
+
+ case EqualTo(e1, e2) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
+
+ /* TODO: Fix null semantics.
+ case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) =>
+ val eval = expressionEvaluator(e1)
+
+ val checks = list.map {
+ case expressions.Literal(v: String, dataType) =>
+ q"if(${eval.primitiveTerm} == $v) return true"
+ case expressions.Literal(v: Int, dataType) =>
+ q"if(${eval.primitiveTerm} == $v) return true"
+ }
+
+ val funcName = newTermName(s"isIn${curId.getAndIncrement()}")
+
+ q"""
+ def $funcName: Boolean = {
+ ..${eval.code}
+ if(${eval.nullTerm}) return false
+ ..$checks
+ return false
+ }
+ val $nullTerm = false
+ val $primitiveTerm = $funcName
+ """.children
+ */
+
+ case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" }
+ case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" }
+ case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" }
+ case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" }
+
+ case And(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = false
+
+ if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) ||
+ (!${eval2.nullTerm} && !${eval2.primitiveTerm})) {
+ $nullTerm = false
+ $primitiveTerm = false
+ } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else {
+ $nullTerm = false
+ $primitiveTerm = true
+ }
+ """.children
+
+ case Or(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = false
+
+ if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) ||
+ (!${eval2.nullTerm} && ${eval2.primitiveTerm})) {
+ $nullTerm = false
+ $primitiveTerm = true
+ } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else {
+ $nullTerm = false
+ $primitiveTerm = false
+ }
+ """.children
+
+ case Not(child) =>
+ // Uh, bad function name...
+ child.castOrNull(c => q"!$c", BooleanType)
+
+ case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
+ case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
+ case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
+ case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" }
+
+ case IsNotNull(e) =>
+ val eval = expressionEvaluator(e)
+ q"""
+ ..${eval.code}
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm}
+ """.children
+
+ case IsNull(e) =>
+ val eval = expressionEvaluator(e)
+ q"""
+ ..${eval.code}
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm}
+ """.children
+
+ case c @ Coalesce(children) =>
+ q"""
+ var $nullTerm = true
+ var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)}
+ """.children ++
+ children.map { c =>
+ val eval = expressionEvaluator(c)
+ q"""
+ if($nullTerm) {
+ ..${eval.code}
+ if(!${eval.nullTerm}) {
+ $nullTerm = false
+ $primitiveTerm = ${eval.primitiveTerm}
+ }
+ }
+ """
+ }
+
+ case i @ expressions.If(condition, trueValue, falseValue) =>
+ val condEval = expressionEvaluator(condition)
+ val trueEval = expressionEvaluator(trueValue)
+ val falseEval = expressionEvaluator(falseValue)
+
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)}
+ ..${condEval.code}
+ if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
+ ..${trueEval.code}
+ $nullTerm = ${trueEval.nullTerm}
+ $primitiveTerm = ${trueEval.primitiveTerm}
+ } else {
+ ..${falseEval.code}
+ $nullTerm = ${falseEval.nullTerm}
+ $primitiveTerm = ${falseEval.primitiveTerm}
+ }
+ """.children
+ }
+
+ // If there was no match in the partial function above, we fall back on calling the interpreted
+ // expression evaluator.
+ val code: Seq[Tree] =
+ primitiveEvaluation.lift.apply(e).getOrElse {
+ log.debug(s"No rules to generate $e")
+ val tree = reify { e }
+ q"""
+ val $objectTerm = $tree.eval(i)
+ val $nullTerm = $objectTerm == null
+ val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
+ """.children
+ }
+
+ EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
+ }
+
+ protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
+ dataType match {
+ case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
+ case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
+ }
+ }
+
+ protected def setColumn(
+ destinationRow: TermName,
+ dataType: DataType,
+ ordinal: Int,
+ value: TermName) = {
+ dataType match {
+ case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
+ case _ => q"$destinationRow.update($ordinal, $value)"
+ }
+ }
+
+ protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
+ protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+
+ protected def primitiveForType(dt: DataType) = dt match {
+ case IntegerType => "Int"
+ case LongType => "Long"
+ case ShortType => "Short"
+ case ByteType => "Byte"
+ case DoubleType => "Double"
+ case FloatType => "Float"
+ case BooleanType => "Boolean"
+ case StringType => "String"
+ }
+
+ protected def defaultPrimitive(dt: DataType) = dt match {
+ case BooleanType => ru.Literal(Constant(false))
+ case FloatType => ru.Literal(Constant(-1.0.toFloat))
+ case StringType => ru.Literal(Constant("<uninit>"))
+ case ShortType => ru.Literal(Constant(-1.toShort))
+ case LongType => ru.Literal(Constant(1L))
+ case ByteType => ru.Literal(Constant(-1.toByte))
+ case DoubleType => ru.Literal(Constant(-1.toDouble))
+ case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+ case IntegerType => ru.Literal(Constant(-1))
+ case _ => ru.Literal(Constant(null))
+ }
+
+ protected def termForType(dt: DataType) = dt match {
+ case n: NativeType => n.tag
+ case _ => typeTag[Any]
+ }
+}
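
Generated classes are cached per generator object, keyed on the canonicalized (alias-stripped) expressions, so the relatively expensive toolbox compilation only happens once per distinct expression shape. A small sketch of that behaviour (names are illustrative; the expressions are already bound, so the single-argument apply is used):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
    import org.apache.spark.sql.catalyst.types._

    object GeneratorCacheSketch {
      def main(args: Array[String]): Unit = {
        val e = Add(BoundReference(0, IntegerType, nullable = false), Literal(1))

        // These differ only by an alias, which ExpressionCanonicalizer strips, so both
        // requests should resolve to the same cached Projection instance.
        val p1 = GenerateProjection(Seq(Alias(e, "x")()))
        val p2 = GenerateProjection(Seq[Expression](e))
        println(p1 eq p2) // expected: true (the second call is a cache hit)

        println(p1(new GenericRow(Array[Any](41)))(0)) // 42
      }
    }
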
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
new file mode 100644
index 0000000000..a419fd7ecb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
+ * input [[Row]] for a fixed set of [[Expression Expressions]].
+ */
+object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ val mutableRowName = newTermName("mutableRow")
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer(_))
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
+ val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) =>
+ val evaluationCode = expressionEvaluator(e)
+
+ evaluationCode.code :+
+ q"""
+ if(${evaluationCode.nullTerm})
+ mutableRow.setNullAt($i)
+ else
+ ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)}
+ """
+ }
+
+ val code =
+ q"""
+ () => { new $mutableProjectionType {
+
+ private[this] var $mutableRowName: $mutableRowType =
+ new $genericMutableRowType(${expressions.size})
+
+ def target(row: $mutableRowType): $mutableProjectionType = {
+ $mutableRowName = row
+ this
+ }
+
+ /* Provide immutable access to the last projected row. */
+ def currentValue: $rowType = mutableRow
+
+ def apply(i: $rowType): $rowType = {
+ ..$projectionCode
+ mutableRow
+ }
+ } }
+ """
+
+ log.debug(s"code for ${expressions.mkString(",")}:\n$code")
+ toolBox.eval(code).asInstanceOf[() => MutableProjection]
+ }
+}
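
The generator returns a factory rather than a single projection, so each task can get its own instance while the compiled class itself is reused from the cache. A minimal sketch (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
    import org.apache.spark.sql.catalyst.types._

    object GeneratedMutableProjectionSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()

        // Compile `a + 1` once; the factory only instantiates the generated class.
        val factory: () => MutableProjection =
          GenerateMutableProjection(Seq(Add(a, Literal(1))), Seq(a))

        val projection = factory()
        val result = projection(new GenericRow(Array[Any](9)))
        println(result(0)) // 10
      }
    }
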
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
new file mode 100644
index 0000000000..4211998f75
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import com.typesafe.scalalogging.slf4j.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.{StringType, NumericType}
+
+/**
+ * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
+ * [[Expression Expressions]].
+ */
+object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
+ in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder])
+
+ protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
+ val a = newTermName("a")
+ val b = newTermName("b")
+ val comparisons = ordering.zipWithIndex.map { case (order, i) =>
+ val evalA = expressionEvaluator(order.child)
+ val evalB = expressionEvaluator(order.child)
+
+ val compare = order.child.dataType match {
+ case _: NumericType =>
+ q"""
+ val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
+ if(comp != 0) {
+ return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
+ }
+ """
+ case StringType =>
+ if (order.direction == Ascending) {
+ q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
+ } else {
+ q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
+ }
+ }
+
+ q"""
+ i = $a
+ ..${evalA.code}
+ i = $b
+ ..${evalB.code}
+ if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+ // Nothing
+ } else if (${evalA.nullTerm}) {
+ return ${if (order.direction == Ascending) q"-1" else q"1"}
+ } else if (${evalB.nullTerm}) {
+ return ${if (order.direction == Ascending) q"1" else q"-1"}
+ } else {
+ $compare
+ }
+ """
+ }
+
+ val q"class $orderingName extends $orderingType { ..$body }" = reify {
+ class SpecificOrdering extends Ordering[Row] {
+ val o = ordering
+ }
+ }.tree.children.head
+
+ val code = q"""
+ class $orderingName extends $orderingType {
+ ..$body
+ def compare(a: $rowType, b: $rowType): Int = {
+ var i: $rowType = null // Holds current row being evaluated.
+ ..$comparisons
+ return 0
+ }
+ }
+ new $orderingName()
+ """
+ logger.debug(s"Generated Ordering: $code")
+ toolBox.eval(code).asInstanceOf[Ordering[Row]]
+ }
+}
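
A minimal sketch of using the generated ordering to sort rows (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
    import org.apache.spark.sql.catalyst.types._

    object GeneratedOrderingSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()

        // Compile an Ordering[Row] that sorts descending on column `a`.
        val ordering: Ordering[Row] = GenerateOrdering(Seq(SortOrder(a, Descending)), Seq(a))

        val rows = Seq(2, 5, 1).map(i => new GenericRow(Array[Any](i)))
        println(rows.sorted(ordering).map(_(0))) // List(5, 2, 1)
      }
    }
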
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
new file mode 100644
index 0000000000..2a0935c790
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]].
+ */
+object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in)
+
+ protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
+ BindReferences.bindReference(in, inputSchema)
+
+ protected def create(predicate: Expression): ((Row) => Boolean) = {
+ val cEval = expressionEvaluator(predicate)
+
+ val code =
+ q"""
+ (i: $rowType) => {
+ ..${cEval.code}
+ if (${cEval.nullTerm}) false else ${cEval.primitiveTerm}
+ }
+ """
+
+ log.debug(s"Generated predicate '$predicate':\n$code")
+ toolBox.eval(code).asInstanceOf[Row => Boolean]
+ }
+}
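
A minimal sketch of using the generated predicate as an ordinary Row => Boolean filter (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
    import org.apache.spark.sql.catalyst.types._

    object GeneratedPredicateSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()

        // Compile `a > 3` against a single-column schema.
        val pred: Row => Boolean = GeneratePredicate(GreaterThan(a, Literal(3)), Seq(a))

        val rows = Seq(1, 4, 7).map(i => new GenericRow(Array[Any](i)))
        println(rows.filter(pred).map(_(0))) // List(4, 7)
      }
    }
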
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
new file mode 100644
index 0000000000..77fa02c13d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+
+/**
+ * Generates bytecode that produces a new [[Row]] object based on a fixed set of input
+ * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom
+ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values.
+ */
+object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer(_))
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ // Make Mutability optional...
+ protected def create(expressions: Seq[Expression]): Projection = {
+ val tupleLength = ru.Literal(Constant(expressions.length))
+ val lengthDef = q"final val length = $tupleLength"
+
+ /* TODO: Configurable...
+ val nullFunctions =
+ q"""
+ private final val nullSet = new org.apache.spark.util.collection.BitSet(length)
+ final def setNullAt(i: Int) = nullSet.set(i)
+ final def isNullAt(i: Int) = nullSet.get(i)
+ """
+ */
+
+ val nullFunctions =
+ q"""
+ private[this] var nullBits = new Array[Boolean](${expressions.size})
+ final def setNullAt(i: Int) = { nullBits(i) = true }
+ final def isNullAt(i: Int) = nullBits(i)
+ """.children
+
+ val tupleElements = expressions.zipWithIndex.flatMap {
+ case (e, i) =>
+ val elementName = newTermName(s"c$i")
+ val evaluatedExpression = expressionEvaluator(e)
+ val iLit = ru.Literal(Constant(i))
+
+ q"""
+ var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _
+ {
+ ..${evaluatedExpression.code}
+ if(${evaluatedExpression.nullTerm})
+ setNullAt($iLit)
+ else
+ $elementName = ${evaluatedExpression.primitiveTerm}
+ }
+ """.children : Seq[Tree]
+ }
+
+ val iteratorFunction = {
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+ q"final def iterator = Iterator[Any](..$allColumns)"
+ }
+
+ val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
+ val applyFunction = {
+ val cases = (0 until expressions.size).map { i =>
+ val ordinal = ru.Literal(Constant(i))
+ val elementName = newTermName(s"c$i")
+ val iLit = ru.Literal(Constant(i))
+
+ q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }"
+ }
+ q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }"
+ }
+
+ val updateFunction = {
+ val cases = expressions.zipWithIndex.map {case (e, i) =>
+ val ordinal = ru.Literal(Constant(i))
+ val elementName = newTermName(s"c$i")
+ val iLit = ru.Literal(Constant(i))
+
+ q"""
+ if(i == $ordinal) {
+ if(value == null) {
+ setNullAt(i)
+ } else {
+ $elementName = value.asInstanceOf[${termForType(e.dataType)}]
+ return
+ }
+ }"""
+ }
+ q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
+ }
+
+ val specificAccessorFunctions = NativeType.all.map { dataType =>
+ val ifStatements = expressions.zipWithIndex.flatMap {
+ case (e, i) if e.dataType == dataType =>
+ val elementName = newTermName(s"c$i")
+ // TODO: The string of ifs gets pretty inefficient as the row grows in size.
+ // TODO: Optional null checks?
+ q"if(i == $i) return $elementName" :: Nil
+ case _ => Nil
+ }
+
+ q"""
+ final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
+
+ val specificMutatorFunctions = NativeType.all.map { dataType =>
+ val ifStatements = expressions.zipWithIndex.flatMap {
+ case (e, i) if e.dataType == dataType =>
+ val elementName = newTermName(s"c$i")
+ // TODO: The string of ifs gets pretty inefficient as the row grows in size.
+ // TODO: Optional null checks?
+ q"if(i == $i) { $elementName = value; return }" :: Nil
+ case _ => Nil
+ }
+
+ q"""
+ final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
+
+ val hashValues = expressions.zipWithIndex.map { case (e,i) =>
+ val elementName = newTermName(s"c$i")
+ val nonNull = e.dataType match {
+ case BooleanType => q"if ($elementName) 0 else 1"
+ case ByteType | ShortType | IntegerType => q"$elementName.toInt"
+ case LongType => q"($elementName ^ ($elementName >>> 32)).toInt"
+ case FloatType => q"java.lang.Float.floatToIntBits($elementName)"
+ case DoubleType =>
+ q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }"
+ case _ => q"$elementName.hashCode"
+ }
+ q"if (isNullAt($i)) 0 else $nonNull"
+ }
+
+ val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree)
+
+ val hashCodeFunction =
+ q"""
+ override def hashCode(): Int = {
+ var result: Int = 37
+ ..$hashUpdates
+ result
+ }
+ """
+
+ val columnChecks = (0 until expressions.size).map { i =>
+ val elementName = newTermName(s"c$i")
+ q"if (this.$elementName != specificType.$elementName) return false"
+ }
+
+ val equalsFunction =
+ q"""
+ override def equals(other: Any): Boolean = other match {
+ case specificType: SpecificRow =>
+ ..$columnChecks
+ return true
+ case other => super.equals(other)
+ }
+ """
+
+ val copyFunction =
+ q"""
+ final def copy() = new $genericRowType(this.toArray)
+ """
+
+ val classBody =
+ nullFunctions ++ (
+ lengthDef +:
+ iteratorFunction +:
+ applyFunction +:
+ updateFunction +:
+ equalsFunction +:
+ hashCodeFunction +:
+ copyFunction +:
+ (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
+
+ val code = q"""
+ final class SpecificRow(i: $rowType) extends $mutableRowType {
+ ..$classBody
+ }
+
+ new $projectionType { def apply(r: $rowType) = new SpecificRow(r) }
+ """
+
+ log.debug(
+ s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}")
+ toolBox.eval(code).asInstanceOf[Projection]
+ }
+}
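
Because the generated SpecificRow stores every column in a field of its exact type, results can be read back through the primitive accessors without boxing, and its hashCode is generated to line up with the implementation added to GenericRow in Row.scala above. A minimal sketch (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
    import org.apache.spark.sql.catalyst.types._

    object GeneratedProjectionSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()
        val s = AttributeReference("s", StringType, nullable = true)()

        val project = GenerateProjection(Seq(Add(a, Literal(10)), s), Seq(a, s))
        val out = project(new GenericRow(Array[Any](1, "hello")))

        println(out.getInt(0))    // 11, read through the specialized accessor
        println(out.getString(1)) // hello
        // Should agree with GenericRow's hashCode for the same values.
        println(out.hashCode == new GenericRow(Array[Any](11, "hello")).hashCode) // true
      }
    }
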
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
new file mode 100644
index 0000000000..80c7dfd376
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.rules
+import org.apache.spark.sql.catalyst.util
+
+/**
+ * A collection of generators that build custom bytecode at runtime for performing the evaluation
+ * of catalyst expressions.
+ */
+package object codegen {
+
+ /**
+ * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala
+ * 2.10.
+ */
+ protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+ /** Canonicalizes an expression so those that differ only by names can reuse the same code. */
+ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
+ val batches =
+ Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil
+
+ object CleanExpressions extends rules.Rule[Expression] {
+ def apply(e: Expression): Expression = e transform {
+ case Alias(c, _) => c
+ }
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Dumps the bytecode from a class to the screen using javap.
+ */
+ @DeveloperApi
+ object DumpByteCode {
+ import scala.sys.process._
+ val dumpDirectory = util.getTempFilePath("sparkSqlByteCode")
+ dumpDirectory.mkdir()
+
+ def apply(obj: Any): Unit = {
+ val generatedClass = obj.getClass
+ val classLoader =
+ generatedClass
+ .getClassLoader
+ .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader]
+ val generatedBytes = classLoader.classBytes(generatedClass.getName)
+
+ val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName)
+ if (!packageDir.exists()) { packageDir.mkdir() }
+
+ val classFile =
+ new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class")
+
+ val outfile = new java.io.FileOutputStream(classFile)
+ outfile.write(generatedBytes)
+ outfile.close()
+
+ println(
+ s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!)
+ }
+ }
+}
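
A small sketch of the canonicalization step, which keeps the generator caches from compiling the same expression twice just because of a rename (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen.ExpressionCanonicalizer
    import org.apache.spark.sql.catalyst.types._

    object CanonicalizerSketch {
      def main(args: Array[String]): Unit = {
        val e = Add(BoundReference(0, IntegerType, nullable = false), Literal(1))

        // Aliases affect only naming, not evaluation, so they are stripped before the
        // expression is used as a cache key.
        println(ExpressionCanonicalizer(Alias(e, "renamed")()) == e) // true
      }
    }
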
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index b6f2451b52..55d95991c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst
* ==Evaluation==
* The result of expressions can be evaluated using the `Expression.apply(Row)` method.
*/
-package object expressions
+package object expressions {
+
+ /**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
+ * new row. If the schema of the input row is specified, then the given expression will be bound
+ * to that schema.
+ */
+ abstract class Projection extends (Row => Row)
+
+ /**
+ * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
+ * new row. If the schema of the input row is specified, then the given expression will be bound
+ * to that schema.
+ *
+ * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
+ * each time an input row is added. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
+ * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
+ * and hold on to the returned [[Row]] before calling `next()`.
+ */
+ abstract class MutableProjection extends Projection {
+ def currentValue: Row
+
+ /** Uses the given row to store the output of the projection. */
+ def target(row: MutableRow): MutableProjection
+ }
+}
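
The copy() caveat above applies to both the interpreted and the generated implementations; a small sketch of why it matters (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    object MutableProjectionCopySketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()
        val project = new InterpretedMutableProjection(Seq(Add(a, Literal(1))), Seq(a))
        val input = Seq(1, 2, 3).map(i => new GenericRow(Array[Any](i)))

        // Wrong: every element aliases the same reused buffer, so the last value wins.
        println(input.iterator.map(project).toList.map(_(0)))                // List(4, 4, 4)

        // Right: copy() each produced row before holding on to it.
        println(input.iterator.map(r => project(r).copy()).toList.map(_(0))) // List(2, 3, 4)
      }
    }
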
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 06b94a98d3..5976b0ddf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType
object InterpretedPredicate {
+ def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
+ apply(BindReferences.bindReference(expression, inputSchema))
+
def apply(expression: Expression): (Row => Boolean) = {
(r: Row) => expression.eval(r).asInstanceOf[Boolean]
}
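
A short sketch of the new overload, which binds the predicate to a schema before evaluating it (names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    object InterpretedPredicateSketch {
      def main(args: Array[String]): Unit = {
        val a = AttributeReference("a", IntegerType, nullable = false)()
        val pred: Row => Boolean = InterpretedPredicate(GreaterThan(a, Literal(3)), Seq(a))
        println(Seq(2, 6).map(i => pred(new GenericRow(Array[Any](i))))) // List(false, true)
      }
    }
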
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
new file mode 100644
index 0000000000..3b3e206055
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+package object catalyst {
+ /**
+ * A JVM-global lock that should be used to prevent thread safety issues when using things in
+ * scala.reflect.*. Note that the Scala Reflection API is thread-safe in 2.11, but not yet for
+ * 2.10.* builds. See SI-6240 for more details.
+ */
+ protected[catalyst] object ScalaReflectionLock
+}
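A hedged example of how the lock is intended to be used: code that touches scala.reflect under Scala 2.10 synchronizes on this one JVM-global object. The helper below is hypothetical and, because the lock is protected[catalyst], would need to live inside the catalyst package.

    package org.apache.spark.sql.catalyst

    import scala.reflect.runtime.universe._

    object ReflectionExample {
      // Hypothetical helper: serialize all reflection calls through the global lock (SI-6240).
      def typeOfSafely[T: TypeTag]: Type = ScalaReflectionLock.synchronized {
        typeOf[T]
      }
    }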
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 026692abe0..418f8686bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -105,6 +105,77 @@ object PhysicalOperation extends PredicateHelper {
}
/**
+ * Matches a logical aggregation that can be performed on distributed data in two steps. The first
+ * operates on the data in each partition, performing partial aggregation for each group. The second
+ * occurs after the shuffle and completes the aggregation.
+ *
+ * This pattern will only match if all aggregate expressions can be computed partially and will
+ * return the rewritten aggregation expressions for both phases.
+ *
+ * The returned values for this match are as follows:
+ * - Grouping attributes for the final aggregation.
+ * - Aggregates for the final aggregation.
+ * - Grouping expressions for the partial aggregation.
+ * - Partial aggregate expressions.
+ * - Input to the aggregation.
+ */
+object PartialAggregation {
+ type ReturnType =
+ (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+ case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ // Collect all aggregate expressions.
+ val allAggregates =
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ // Collect all aggregate expressions that can be computed partially.
+ val partialAggregates =
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+
+ // Only do partial aggregation if supported by all aggregate expressions.
+ if (allAggregates.size == partialAggregates.size) {
+ // Create a map of expressions to their partial evaluations for all aggregate expressions.
+ val partialEvaluations: Map[Long, SplitEvaluation] =
+ partialAggregates.map(a => (a.id, a.asPartial)).toMap
+
+ // We need to pass all grouping expressions through so the grouping can happen a second
+ // time. However, some of them might be unnamed, so we alias them, allowing them to be
+ // referenced in the second aggregation.
+ val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
+ case n: NamedExpression => (n, n)
+ case other => (other, Alias(other, "PartialGroup")())
+ }.toMap
+
+ // Replace aggregations with a new expression that computes the result from the already
+ // computed partial evaluations and grouping values.
+ val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
+ case e: Expression if partialEvaluations.contains(e.id) =>
+ partialEvaluations(e.id).finalEvaluation
+ case e: Expression if namedGroupingExpressions.contains(e) =>
+ namedGroupingExpressions(e).toAttribute
+ }).asInstanceOf[Seq[NamedExpression]]
+
+ val partialComputation =
+ (namedGroupingExpressions.values ++
+ partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+
+ val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
+
+ Some(
+ (namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child))
+ } else {
+ None
+ }
+ case _ => None
+ }
+}
+
+
+/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ac85f95b52..888cb08e95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -112,7 +112,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
self: Product =>
override lazy val statistics: Statistics =
- throw new UnsupportedOperationException("default leaf nodes don't have meaningful Statistics")
+ throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
// Leaf nodes by definition cannot reference any input attributes.
override def references = Set.empty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index a357c6ffb8..481a5a4f21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -35,7 +35,7 @@ abstract class Command extends LeafNode {
*/
case class NativeCommand(cmd: String) extends Command {
override def output =
- Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)()))
+ Seq(AttributeReference("result", StringType, nullable = false)())
}
/**
@@ -43,7 +43,7 @@ case class NativeCommand(cmd: String) extends Command {
*/
case class SetCommand(key: Option[String], value: Option[String]) extends Command {
override def output = Seq(
- BoundReference(1, AttributeReference("", StringType, nullable = false)()))
+ AttributeReference("", StringType, nullable = false)())
}
/**
@@ -52,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman
*/
case class ExplainCommand(plan: LogicalPlan) extends Command {
override def output =
- Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)()))
+ Seq(AttributeReference("plan", StringType, nullable = false)())
}
/**
@@ -71,7 +71,7 @@ case class DescribeCommand(
isExtended: Boolean) extends Command {
override def output = Seq(
// Column names are based on Hive.
- BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()),
- BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()),
- BoundReference(2, AttributeReference("comment", StringType, nullable = false)()))
+ AttributeReference("col_name", StringType, nullable = false)(),
+ AttributeReference("data_type", StringType, nullable = false)(),
+ AttributeReference("comment", StringType, nullable = false)())
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index e32adb76fe..e300bdbece 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -72,7 +72,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
}
iteration += 1
if (iteration > batch.strategy.maxIterations) {
- logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}")
+ // Only log if this is a batch that is supposed to run more than once.
+ if (iteration != 2) {
+ logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+ }
continue = false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index cd4b5e9c1b..71808f76d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -23,16 +23,13 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
import scala.util.parsing.combinator.RegexParsers
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
/**
- * A JVM-global lock that should be used to prevent thread safety issues when using things in
- * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for
- * 2.10.* builds. See SI-6240 for more details.
+ * Utility functions for working with DataTypes.
*/
-protected[catalyst] object ScalaReflectionLock
-
object DataType extends RegexParsers {
protected lazy val primitiveType: Parser[DataType] =
"StringType" ^^^ StringType |
@@ -99,6 +96,13 @@ abstract class DataType {
case object NullType extends DataType
+object NativeType {
+ def all = Seq(
+ IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+
+ def unapply(dt: DataType): Boolean = all.contains(dt)
+}
+
trait PrimitiveType extends DataType {
override def isPrimitive = true
}
@@ -149,6 +153,10 @@ abstract class NumericType extends NativeType with PrimitiveType {
val numeric: Numeric[JvmType]
}
+object NumericType {
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+}
+
/** Matcher for any expressions that evaluate to [[IntegralType]]s */
object IntegralType {
def unapply(a: Expression): Boolean = a match {
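A hedged illustration of the two new boolean extractors added in this file: NativeType matches DataTypes, while NumericType matches Expressions whose dataType is numeric. The describe and isNumeric helpers are made up.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical helpers showing how the boolean unapply methods are matched.
    def describe(dt: DataType): String = dt match {
      case NativeType() => "native"            // IntegerType, StringType, BooleanType, ...
      case _            => "complex or null"
    }

    def isNumeric(e: Expression): Boolean = e match {
      case NumericType() => true               // e.dataType is a NumericType
      case _             => false
    }

    describe(StringType)     // "native"
    describe(NullType)       // "complex or null"
    isNumeric(Literal(1))    // true
    isNumeric(Literal("x"))  // false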
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 58f8c341e6..999c9fff38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
class ExpressionEvaluationSuite extends FunSuite {
test("literals") {
- assert((Literal(1) + Literal(1)).eval(null) === 2)
+ checkEvaluation(Literal(1), 1)
+ checkEvaluation(Literal(true), true)
+ checkEvaluation(Literal(0L), 0L)
+ checkEvaluation(Literal("test"), "test")
+ checkEvaluation(Literal(1) + Literal(1), 2)
}
/**
@@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite {
test("3VL Not") {
notTrueTable.foreach {
case (v, answer) =>
- val expr = ! Literal(v, BooleanType)
- val result = expr.eval(null)
- if (result != answer)
- fail(s"$expr should not evaluate to $result, expected: $answer") }
+ checkEvaluation(!Literal(v, BooleanType), answer)
+ }
}
booleanLogicTest("AND", _ && _,
@@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite {
}
}
+ test("IN") {
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
+ checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
+ checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
@@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
- checkEvaluation("23" cast DoubleType, 23)
+ checkEvaluation("23" cast DoubleType, 23d)
checkEvaluation("23" cast IntegerType, 23)
- checkEvaluation("23" cast FloatType, 23)
- checkEvaluation("23" cast DecimalType, 23)
- checkEvaluation("23" cast ByteType, 23)
- checkEvaluation("23" cast ShortType, 23)
+ checkEvaluation("23" cast FloatType, 23f)
+ checkEvaluation("23" cast DecimalType, 23: BigDecimal)
+ checkEvaluation("23" cast ByteType, 23.toByte)
+ checkEvaluation("23" cast ShortType, 23.toShort)
checkEvaluation("2012-12-11" cast DoubleType, null)
checkEvaluation(Literal(123) cast IntegerType, 123)
- checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24)
+ checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d)
checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
- checkEvaluation(Literal(23f) + Cast(true, FloatType), 24)
- checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24)
- checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24)
- checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
+ checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f)
+ checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal)
+ checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte)
+ checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort)
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
@@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite {
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)
- checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ checkEvaluation(GetItem(BoundReference(3, typeMap, true),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ checkEvaluation(GetItem(BoundReference(3, typeMap, true),
Literal(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(null, IntegerType)), null, row)
- checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
+ checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
val typeS_notNullable = StructType(
@@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite {
:: StructField("b", StringType, nullable = false) :: Nil
)
- assert(GetField(BoundReference(2,
- AttributeReference("c", typeS)()), "a").nullable === true)
- assert(GetField(BoundReference(2,
- AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false)
+ assert(GetField(BoundReference(2, typeS, nullable = true), "a").nullable === true)
+ assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
new file mode 100644
index 0000000000..245a2e1480
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Overrides our expression evaluation tests to use code generation for evaluation.
+ */
+class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ val plan = try {
+ GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)()
+ } catch {
+ case e: Throwable =>
+ val evaluated = GenerateProjection.expressionEvaluator(expression)
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code.mkString("\n")}
+ |$e
+ """.stripMargin)
+ }
+
+ val actual = plan(inputRow).apply(0)
+ if (actual != expected) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+
+
+ test("multithreaded eval") {
+ import scala.concurrent._
+ import ExecutionContext.Implicits.global
+ import scala.concurrent.duration._
+
+ val futures = (1 to 20).map { _ =>
+ future {
+ GeneratePredicate(EqualTo(Literal(1), Literal(1)))
+ GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil)
+ }
+ }
+
+ futures.foreach(Await.result(_, 10.seconds))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
new file mode 100644
index 0000000000..887aabb1d5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Overrides our expression evaluation tests to use generated code on mutable rows.
+ */
+class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
+
+ val plan = try {
+ GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil)
+ } catch {
+ case e: Throwable =>
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code.mkString("\n")}
+ |$e
+ """.stripMargin)
+ }
+
+ val actual = plan(inputRow)
+ val expectedRow = new GenericRow(Array[Any](expected))
+ if (actual.hashCode() != expectedRow.hashCode()) {
+ fail(
+ s"""
+ |Mismatched hashCodes for values: $actual, $expectedRow
+ |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
+ |${evaluated.code.mkString("\n")}
+ """.stripMargin)
+ }
+ if (actual != expectedRow) {
+ val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index 4896f1b955..e2ae0d25db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
- Batch("Combine Limit", FixedPoint(2),
+ Batch("Combine Limit", FixedPoint(10),
CombineLimits) ::
- Batch("Constant Folding", FixedPoint(3),
+ Batch("Constant Folding", FixedPoint(10),
NullPropagation,
ConstantFolding,
BooleanSimplification) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 5d85a0fd4e..2d407077be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -24,8 +24,11 @@ import scala.collection.JavaConverters._
object SQLConf {
val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
- val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
+ val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size"
+ val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
+ val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables"
+ val CODEGEN_ENABLED = "spark.sql.codegen"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -57,6 +60,18 @@ trait SQLConf {
private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt
/**
+ * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
+ * that evaluates expressions found in queries. In general this custom code runs much faster
+ * than interpreted evaluation, but there are significant start-up costs due to compilation.
+ * As a result, codegen is only beneficial when queries run for a long time, or when the same
+ * expressions are used multiple times.
+ *
+ * Defaults to false as this feature is currently experimental.
+ */
+ private[spark] def codegenEnabled: Boolean =
+ get(CODEGEN_ENABLED, "false") == "true"
+
+ /**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
* effectively disables auto conversion.
@@ -111,5 +126,5 @@ trait SQLConf {
private[spark] def clear() {
settings.clear()
}
-
}
+
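A hedged example of turning the experimental flag on for a session; it assumes the SQLConf set(key, value) setter that accompanies the get calls used above, and the key string mirrors SQLConf.CODEGEN_ENABLED.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext("local", "codegen-demo")
    val sqlContext = new SQLContext(sc)

    // Same key as SQLConf.CODEGEN_ENABLED; defaults to "false" while the feature is experimental.
    sqlContext.set("spark.sql.codegen", "true")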
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c2bdef7323..e4b6810180 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def parquetFile(path: String): SchemaRDD =
- new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration)))
+ new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this))
/**
* Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]].
@@ -160,7 +160,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
conf: Configuration = new Configuration()): SchemaRDD = {
new SchemaRDD(
this,
- ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf))
+ ParquetRelation.createEmpty(
+ path, ScalaReflection.attributesFor[A], allowExisting, conf, this))
}
/**
@@ -228,12 +229,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
val sqlContext: SQLContext = self
+ def codegenEnabled = self.codegenEnabled
+
def numPartitions = self.numShufflePartitions
val strategies: Seq[Strategy] =
CommandStrategy(self) ::
TakeOrdered ::
- PartialAggregation ::
+ HashAggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
@@ -291,27 +294,30 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1)
/**
- * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and
- * inserting shuffle operations as needed.
+ * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed.
*/
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
- Batch("Add exchange", Once, AddExchange(self)) ::
- Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil
+ Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
/**
+ * :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
* access to the intermediate phases of query execution for developers.
*/
+ @DeveloperApi
protected abstract class QueryExecution {
def logical: LogicalPlan
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
- lazy val sparkPlan = planner(optimizedPlan).next()
+ lazy val sparkPlan = {
+ SparkPlan.currentContext.set(self)
+ planner(optimizedPlan).next()
+ }
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
@@ -331,6 +337,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
|${stringOrError(optimizedPlan)}
|== Physical Plan ==
|${stringOrError(executedPlan)}
+ |Code Generation: ${executedPlan.codegenEnabled}
+ |== RDD ==
+ |${stringOrError(toRdd.toDebugString)}
""".stripMargin.trim
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 806097c917..85726bae54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -72,7 +72,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
conf: Configuration = new Configuration()): JavaSchemaRDD = {
new JavaSchemaRDD(
sqlContext,
- ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf))
+ ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext))
}
/**
@@ -101,7 +101,7 @@ class JavaSQLContext(val sqlContext: SQLContext) {
def parquetFile(path: String): JavaSchemaRDD =
new JavaSchemaRDD(
sqlContext,
- ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration)))
+ ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext))
/**
* Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index c1ced8bfa4..463a1d32d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -42,8 +42,8 @@ case class Aggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: SparkPlan)(@transient sqlContext: SQLContext)
- extends UnaryNode with NoBind {
+ child: SparkPlan)
+ extends UnaryNode {
override def requiredChildDistribution =
if (partial) {
@@ -56,8 +56,6 @@ case class Aggregate(
}
}
- override def otherCopyArgs = sqlContext :: Nil
-
// HACK: Generators don't correctly preserve their output through serializations so we grab
// out child's output attributes statically here.
private[this] val childOutput = child.output
@@ -138,7 +136,7 @@ case class Aggregate(
i += 1
}
}
- val resultProjection = new Projection(resultExpressions, computedSchema)
+ val resultProjection = new InterpretedProjection(resultExpressions, computedSchema)
val aggregateResults = new GenericMutableRow(computedAggregates.length)
var i = 0
@@ -152,7 +150,7 @@ case class Aggregate(
} else {
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
- val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+ val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput)
var currentRow: Row = null
while (iter.hasNext) {
@@ -175,7 +173,8 @@ case class Aggregate(
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
- new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+ new InterpretedMutableProjection(
+ resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val joinedRow = new JoinedRow
override final def hasNext: Boolean = hashTableIter.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 00010ef6e7..392a7f3be3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind {
+case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
override def outputPartitioning = newPartitioning
@@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
val rdd = child.execute().mapPartitions { iter =>
- val hashExpressions = new MutableProjection(expressions, child.output)
+ @transient val hashExpressions =
+ newMutableProjection(expressions, child.output)()
+
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 47b3d00262..c386fd121c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -47,23 +47,26 @@ case class Generate(
}
}
- override def output =
+ // This must be a val since the generator output expr ids are not preserved by serialization.
+ override val output =
if (join) child.output ++ generatorOutput else generatorOutput
+ val boundGenerator = BindReferences.bindReference(generator, child.output)
+
override def execute() = {
if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
// Used to produce rows with no matches when outer = true.
val outerProjection =
- new Projection(child.output ++ nullValues, child.output)
+ newProjection(child.output ++ nullValues, child.output)
val joinProjection =
- new Projection(child.output ++ generator.output, child.output ++ generator.output)
+ newProjection(child.output ++ generator.output, child.output ++ generator.output)
val joinedRow = new JoinedRow
iter.flatMap {row =>
- val outputRows = generator.eval(row)
+ val outputRows = boundGenerator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
@@ -72,7 +75,7 @@ case class Generate(
}
}
} else {
- child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
+ child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row)))
}
}
}
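A hedged illustration of what BindReferences.bindReference does for the generator above: attributes are rewritten into ordinal BoundReferences against the child's output, so the bound expression can be evaluated directly on input rows. The attributes below are made up.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    val a = AttributeReference("a", IntegerType, nullable = true)()
    val b = AttributeReference("b", IntegerType, nullable = true)()

    // Binding against the schema [a, b] rewrites `b` into BoundReference(1, IntegerType, true).
    val bound = BindReferences.bindReference(Add(b, Literal(1)), Seq(a, b))

    bound.eval(new GenericRow(Array[Any](5, 7)))   // 8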
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
new file mode 100644
index 0000000000..4a26934c49
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types._
+
+case class AggregateEvaluation(
+ schema: Seq[Attribute],
+ initialValues: Seq[Expression],
+ update: Seq[Expression],
+ result: Expression)
+
+/**
+ * :: DeveloperApi ::
+ * Alternate version of aggregation that leverages projection and thus code generation.
+ * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
+ * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
+ *
+ * @param partial if true, aggregation is done partially on local data without first shuffling to
+ * ensure that all values where `groupingExpressions` are equal are present.
+ * @param groupingExpressions expressions that are evaluated to determine grouping.
+ * @param aggregateExpressions expressions that are computed for each group.
+ * @param child the input data source.
+ */
+@DeveloperApi
+case class GeneratedAggregate(
+ partial: Boolean,
+ groupingExpressions: Seq[Expression],
+ aggregateExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override def requiredChildDistribution =
+ if (partial) {
+ UnspecifiedDistribution :: Nil
+ } else {
+ if (groupingExpressions == Nil) {
+ AllTuples :: Nil
+ } else {
+ ClusteredDistribution(groupingExpressions) :: Nil
+ }
+ }
+
+ override def output = aggregateExpressions.map(_.toAttribute)
+
+ override def execute() = {
+ val aggregatesToCompute = aggregateExpressions.flatMap { a =>
+ a.collect { case agg: AggregateExpression => agg}
+ }
+
+ val computeFunctions = aggregatesToCompute.map {
+ case c @ Count(expr) =>
+ val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+ val initialValue = Literal(0L)
+ val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val result = currentCount
+
+ AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+ case Sum(expr) =>
+ val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val initialValue = Cast(Literal(0L), expr.dataType)
+
+ // Coalesce avoids double calculation...
+ // but really, common sub-expression elimination would be better....
+ val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
+ val result = currentSum
+
+ AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
+ case a @ Average(expr) =>
+ val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
+ val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
+ val initialCount = Literal(0L)
+ val initialSum = Cast(Literal(0L), expr.dataType)
+ val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
+
+ val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType))
+
+ AggregateEvaluation(
+ currentCount :: currentSum :: Nil,
+ initialCount :: initialSum :: Nil,
+ updateCount :: updateSum :: Nil,
+ result
+ )
+ }
+
+ val computationSchema = computeFunctions.flatMap(_.schema)
+
+ val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
+ case (agg, func) => agg.id -> func.result
+ }.toMap
+
+ val namedGroups = groupingExpressions.zipWithIndex.map {
+ case (ne: NamedExpression, _) => (ne, ne)
+ case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+ }
+
+ val groupMap: Map[Expression, Attribute] =
+ namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
+
+ // The set of expressions that produce the final output given the aggregation buffer and the
+ // grouping expressions.
+ val resultExpressions = aggregateExpressions.map(_.transform {
+ case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
+ case e: Expression if groupMap.contains(e) => groupMap(e)
+ })
+
+ child.execute().mapPartitions { iter =>
+ // Builds a new custom class for holding the results of aggregation for a group.
+ val initialValues = computeFunctions.flatMap(_.initialValues)
+ val newAggregationBuffer = newProjection(initialValues, child.output)
+ log.info(s"Initial values: ${initialValues.mkString(",")}")
+
+ // A projection that computes the group given an input tuple.
+ val groupProjection = newProjection(groupingExpressions, child.output)
+ log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
+
+ // A projection that is used to update the aggregate values for a group given a new tuple.
+ // This projection should be targeted at the current values for the group and then applied
+ // to a joined row of the current values with the new input row.
+ val updateExpressions = computeFunctions.flatMap(_.update)
+ val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
+ val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
+ log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
+
+ // A projection that produces the final result, given a computation.
+ val resultProjectionBuilder =
+ newMutableProjection(
+ resultExpressions,
+ (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+ log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
+
+ val joinedRow = new JoinedRow
+
+ if (groupingExpressions.isEmpty) {
+ // TODO: Codegening anything other than the updateProjection is probably overkill.
+ val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+ var currentRow: Row = null
+ updateProjection.target(buffer)
+
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ updateProjection(joinedRow(buffer, currentRow))
+ }
+
+ val resultProjection = resultProjectionBuilder()
+ Iterator(resultProjection(buffer))
+ } else {
+ val buffers = new java.util.HashMap[Row, MutableRow]()
+
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ val currentGroup = groupProjection(currentRow)
+ var currentBuffer = buffers.get(currentGroup)
+ if (currentBuffer == null) {
+ currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
+ buffers.put(currentGroup, currentBuffer)
+ }
+ // Target the projection at the current aggregation buffer and then project the updated
+ // values.
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
+ }
+
+ new Iterator[Row] {
+ private[this] val resultIterator = buffers.entrySet.iterator()
+ private[this] val resultProjection = resultProjectionBuilder()
+
+ def hasNext = resultIterator.hasNext
+
+ def next() = {
+ val currentGroup = resultIterator.next()
+ resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
+ }
+ }
+ }
+ }
+ }
+}
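To make the buffer-as-projection idea concrete, here is a hedged, interpreted simulation of the Sum case above; it uses only expression classes that appear in this patch (no code generation), and the column names are made up.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types._

    // One input column `x` and one buffer slot `currentSum`, mirroring the Sum case above.
    val x = AttributeReference("x", IntegerType, nullable = true)()
    val currentSum = AttributeReference("currentSum", IntegerType, nullable = false)()

    val initialValue = Cast(Literal(0L), IntegerType)
    val update = Coalesce(Add(x, currentSum) :: currentSum :: Nil)

    // The update expression is evaluated against the buffer joined with the input row,
    // just as GeneratedAggregate does via JoinedRow (buffer schema first, then child output).
    val updateProjection =
      new InterpretedMutableProjection(update :: Nil, currentSum :: x :: Nil)

    val buffer = new GenericMutableRow(1)
    buffer(0) = initialValue.eval(EmptyRow)

    val joinedRow = new JoinedRow
    for (value <- Seq(1, 2, 3)) {
      buffer(0) = updateProjection(joinedRow(buffer, new GenericRow(Array[Any](value))))(0)
    }
    // buffer(0) is now 6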
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 77c874d031..21cbbc9772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -18,22 +18,55 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Logging, Row, SQLContext}
+
+
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
+
+object SparkPlan {
+ protected[sql] val currentContext = new ThreadLocal[SQLContext]()
+}
+
/**
* :: DeveloperApi ::
*/
@DeveloperApi
-abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
+abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
self: Product =>
+ /**
+ * A handle to the SQL Context that was used to create this plan. Since many operators need
+ * access to the sqlContext for RDD operations or configuration, this field is automatically
+ * populated by the query planning infrastructure.
+ */
+ @transient
+ protected val sqlContext = SparkPlan.currentContext.get()
+
+ protected def sparkContext = sqlContext.sparkContext
+
+ // sqlContext will be null when we are being deserialized on the slaves. In this instance
+ // the value of codegenEnabled will be set by the deserializer after the constructor has run.
+ val codegenEnabled: Boolean = if (sqlContext != null) {
+ sqlContext.codegenEnabled
+ } else {
+ false
+ }
+
+ /** Overridden makeCopy also propagates sqlContext to the copied plan. */
+ override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+ SparkPlan.currentContext.set(sqlContext)
+ super.makeCopy(newArgs)
+ }
+
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
*/
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
- protected def buildRow(values: Seq[Any]): Row =
- new GenericRow(values.toArray)
+ protected def newProjection(
+ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+ log.debug(
+ s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ GenerateProjection(expressions, inputSchema)
+ } else {
+ new InterpretedProjection(expressions, inputSchema)
+ }
+ }
+
+ protected def newMutableProjection(
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): () => MutableProjection = {
+ log.debug(
+ s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ GenerateMutableProjection(expressions, inputSchema)
+ } else {
+ () => new InterpretedMutableProjection(expressions, inputSchema)
+ }
+ }
+
+
+ protected def newPredicate(
+ expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
+ if (codegenEnabled) {
+ GeneratePredicate(expression, inputSchema)
+ } else {
+ InterpretedPredicate(expression, inputSchema)
+ }
+ }
+
+ protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
+ if (codegenEnabled) {
+ GenerateOrdering(order, inputSchema)
+ } else {
+ new RowOrdering(order, inputSchema)
+ }
+ }
}
/**
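A hedged sketch of how a physical operator consumes these new factory methods; the operator itself is hypothetical (the real Filter in basicOperators.scala below does essentially the same thing), while newPredicate and codegenEnabled come from the code above.

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.catalyst.expressions._

    // Hypothetical operator: filters child rows with a predicate that is either generated or
    // interpreted, depending on this plan's codegenEnabled flag.
    case class ExampleFilter(condition: Expression, child: SparkPlan) extends UnaryNode {
      override def output = child.output

      override def execute() = child.execute().mapPartitions { iter =>
        val predicate = newPredicate(condition, child.output)   // built once per partition
        iter.filter(predicate)
      }
    }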
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 404d48ae05..5f1fe99f75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution
-import scala.util.Try
-
import org.apache.spark.sql.{SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
@@ -41,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
- planLater(left), planLater(right), condition)(sqlContext) :: Nil
+ planLater(left), planLater(right), condition) :: Nil
case _ => Nil
}
}
@@ -60,6 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* will instead be used to decide the build side in a [[execution.ShuffledHashJoin]].
*/
object HashJoin extends Strategy with PredicateHelper {
+
private[this] def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@@ -68,24 +67,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition: Option[Expression],
side: BuildSide) = {
val broadcastHashJoin = execution.BroadcastHashJoin(
- leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
+ leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if Try(sqlContext.autoBroadcastJoinThreshold > 0 &&
- right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) =>
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if Try(sqlContext.autoBroadcastJoinThreshold > 0 &&
- left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) =>
+ if sqlContext.autoBroadcastJoinThreshold > 0 &&
+ left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
val buildSide =
- if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) {
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
BuildRight
} else {
BuildLeft
@@ -99,65 +98,65 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object PartialAggregation extends Strategy {
+ object HashAggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
- // Collect all aggregate expressions.
- val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
- // Collect all aggregate expressions that can be computed partially.
- val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })
-
- // Only do partial aggregation if supported by all aggregate expressions.
- if (allAggregates.size == partialAggregates.size) {
- // Create a map of expressions to their partial evaluations for all aggregate expressions.
- val partialEvaluations: Map[Long, SplitEvaluation] =
- partialAggregates.map(a => (a.id, a.asPartial)).toMap
-
- // We need to pass all grouping expressions though so the grouping can happen a second
- // time. However some of them might be unnamed so we alias them allowing them to be
- // referenced in the second aggregation.
- val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
- case n: NamedExpression => (n, n)
- case other => (other, Alias(other, "PartialGroup")())
- }.toMap
+ // Aggregations that can be performed in two phases, before and after the shuffle.
- // Replace aggregations with a new expression that computes the result from the already
- // computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
- case e: Expression if partialEvaluations.contains(e.id) =>
- partialEvaluations(e.id).finalEvaluation
- case e: Expression if namedGroupingExpressions.contains(e) =>
- namedGroupingExpressions(e).toAttribute
- }).asInstanceOf[Seq[NamedExpression]]
-
- val partialComputation =
- (namedGroupingExpressions.values ++
- partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
-
- // Construct two phased aggregation.
- execution.Aggregate(
+ // Cases where all aggregates can be codegened.
+ case PartialAggregation(
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child)
+ if canBeCodeGened(
+ allAggregates(partialComputation) ++
+ allAggregates(rewrittenAggregateExpressions)) &&
+ codegenEnabled =>
+ execution.GeneratedAggregate(
partial = false,
- namedGroupingExpressions.values.map(_.toAttribute).toSeq,
+ namedGroupingAttributes,
rewrittenAggregateExpressions,
- execution.Aggregate(
+ execution.GeneratedAggregate(
partial = true,
groupingExpressions,
partialComputation,
- planLater(child))(sqlContext))(sqlContext) :: Nil
- } else {
- Nil
- }
+ planLater(child))) :: Nil
+
+ // Cases where some aggregate cannot be codegened
+ case PartialAggregation(
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child) =>
+ execution.Aggregate(
+ partial = false,
+ namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ execution.Aggregate(
+ partial = true,
+ groupingExpressions,
+ partialComputation,
+ planLater(child))) :: Nil
+
case _ => Nil
}
+
+ def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
+ case _: Sum | _: Count => false
+ case _ => true
+ }
+
+ def allAggregates(exprs: Seq[Expression]) =
+ exprs.flatMap(_.collect { case a: AggregateExpression => a })
}
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
execution.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
+ planLater(left), planLater(right), joinType, condition) :: Nil
case _ => Nil
}
}
@@ -176,16 +175,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
- def convertToCatalyst(a: Any): Any = a match {
- case s: Seq[Any] => s.map(convertToCatalyst)
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case other => other
- }
-
object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
+ execution.TakeOrdered(limit, order, planLater(child)) :: Nil
case _ => Nil
}
}
@@ -195,11 +188,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// TODO: need to support writing to other types of files. Unify the below code paths.
case logical.WriteToFile(path, child) =>
val relation =
- ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
+ ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext)
// Note: overwrite=false because otherwise the metadata we just created will be deleted
- InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
+ InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
- InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
+ InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -228,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
projectList,
filters,
prunePushedDownFilters,
- ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
+ ParquetTableScan(_, relation, filters)) :: Nil
case _ => Nil
}
@@ -266,20 +259,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Filter(condition, child) =>
execution.Filter(condition, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
- execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
+ execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
- val dataAsRdd =
- sparkContext.parallelize(data.map(r =>
- new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
- execution.ExistingRdd(output, dataAsRdd) :: Nil
+ ExistingRdd(
+ output,
+ ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
- execution.Limit(limit, planLater(child))(sqlContext) :: Nil
+ execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
- execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
- case logical.Except(left,right) =>
- execution.Except(planLater(left),planLater(right)) :: Nil
+ execution.Union(unionChildren.map(planLater)) :: Nil
+ case logical.Except(left, right) =>
+ execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
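A hedged, end-to-end illustration of when the new HashAggregation strategy emits GeneratedAggregate operators; the table, case class, and column names are made up, and it assumes registerAsTable and the createSchemaRDD implicit from this code base.

    // Assumes an existing SQLContext named sqlContext, with its createSchemaRDD implicit in scope.
    case class Record(key: Int, value: Int)

    import sqlContext.createSchemaRDD
    sqlContext.sparkContext
      .parallelize(Seq(Record(1, 10), Record(1, 20), Record(2, 5)))
      .registerAsTable("records")

    sqlContext.set("spark.sql.codegen", "true")

    // Both aggregates are in the SUM/COUNT whitelist, so canBeCodeGened holds and the physical
    // plan is a partial/final pair of GeneratedAggregate operators; any aggregate outside the
    // whitelist would fall back to the interpreted Aggregate pair in the second case above.
    val grouped = sqlContext.sql("SELECT key, SUM(value), COUNT(value) FROM records GROUP BY key")
    println(grouped.queryExecution)   // the debug output now also reports "Code Generation: true"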
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 966d8f95fc..174eda8f1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
override def output = projectList.map(_.toAttribute)
- override def execute() = child.execute().mapPartitions { iter =>
- @transient val reusableProjection = new MutableProjection(projectList)
- iter.map(reusableProjection)
+ @transient lazy val buildProjection = newMutableProjection(projectList, child.output)
+
+ def execute() = child.execute().mapPartitions { iter =>
+ val reusableProjection = buildProjection()
+ iter.map(reusableProjection)
}
}
@@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def output = child.output
- override def execute() = child.execute().mapPartitions { iter =>
- iter.filter(condition.eval(_).asInstanceOf[Boolean])
+ @transient lazy val conditionEvaluator = newPredicate(condition, child.output)
+
+ def execute() = child.execute().mapPartitions { iter =>
+ iter.filter(conditionEvaluator)
}
}
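Project and Filter now declare their evaluators as @transient lazy builders and instantiate them once per partition inside mapPartitions, which matters once the projection or predicate is produced by runtime code generation and should not be serialized with the operator. A rough sketch of the per-partition instantiation pattern on plain Scala collections, with a hypothetical compilePredicate standing in for newPredicate:

object PerPartitionEval {
  // Pretend this is expensive (e.g. runtime code generation) and not serializable.
  def compilePredicate(threshold: Int): Int => Boolean = x => x > threshold

  def filterPartitions(partitions: Seq[Seq[Int]], threshold: Int): Seq[Seq[Int]] =
    partitions.map { partition =>
      // Built once per partition, on the worker side, then reused for every row.
      val predicate = compilePredicate(threshold)
      partition.filter(predicate)
    }

  def main(args: Array[String]): Unit = {
    println(filterPartitions(Seq(Seq(1, 5, 9), Seq(2, 12)), 4)) // List(List(5, 9), List(12))
  }
}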
@@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
* :: DeveloperApi ::
*/
@DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan]) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
override def output = children.head.output
- override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
-
- override def otherCopyArgs = sqlContext :: Nil
+ override def execute() = sparkContext.union(children.map(_.execute()))
}
/**
@@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex
* repartition all the data to a single partition to compute the global limit.
*/
@DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+case class Limit(limit: Int, child: SparkPlan)
extends UnaryNode {
// TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
// partition local limit -> exchange into one partition -> partition local limit again
- override def otherCopyArgs = sqlContext :: Nil
-
override def output = child.output
/**
@@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
* Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
*/
@DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sqlContext: SQLContext) extends UnaryNode {
- override def otherCopyArgs = sqlContext :: Nil
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
override def output = child.output
- @transient
- lazy val ordering = new RowOrdering(sortOrder)
+ val ordering = new RowOrdering(sortOrder, child.output)
+ // TODO: Is this copying for no reason?
override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
+ override def execute() = sparkContext.makeRDD(executeCollect(), 1)
}
/**
@@ -189,15 +187,13 @@ case class Sort(
override def requiredChildDistribution =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- @transient
- lazy val ordering = new RowOrdering(sortOrder)
override def execute() = attachTree(this, "sort") {
- // TODO: Optimize sorting operation?
child.execute()
- .mapPartitions(
- iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator,
- preservesPartitioning = true)
+ .mapPartitions({ iterator =>
+ val ordering = newOrdering(sortOrder, child.output)
+ iterator.map(_.copy()).toArray.sorted(ordering).iterator
+ }, preservesPartitioning = true)
}
override def output = child.output
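The deletions of (@transient sqlContext: SQLContext) and otherCopyArgs throughout this file all follow from dropping the curried second parameter list: arguments outside the primary list are not part of a case class's Product fields, so tree rebuilds lose them unless otherCopyArgs re-supplies them. A small sketch of what goes missing, with hypothetical OldLimit/NewLimit nodes:

// Old style: the context rode along in a curried, transient parameter list,
// invisible to Product-based rebuilds, so nodes had to declare otherCopyArgs.
case class OldLimit(limit: Int)(val context: String) {
  def rebuildArgs: Seq[Any] = productIterator.toSeq // context is not included
}

// New style: one parameter list; the context comes from the environment instead.
case class NewLimit(limit: Int)

object CopySketch {
  def main(args: Array[String]): Unit = {
    println(OldLimit(10)("ctx").rebuildArgs) // List(10) -- the curried arg is lost
    println(NewLimit(10).copy(limit = 5))    // NewLimit(5)
  }
}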
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index c6fbd6d2f6..5ef46c32d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -41,13 +41,13 @@ package object debug {
*/
@DeveloperApi
implicit class DebugQuery(query: SchemaRDD) {
- def debug(implicit sc: SparkContext): Unit = {
+ def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[Long]()
val debugPlan = plan transform {
case s: SparkPlan if !visited.contains(s.id) =>
visited += s.id
- DebugNode(sc, s)
+ DebugNode(s)
}
println(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
@@ -57,9 +57,7 @@ package object debug {
}
}
- private[sql] case class DebugNode(
- @transient sparkContext: SparkContext,
- child: SparkPlan) extends UnaryNode {
+ private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
def references = Set.empty
def output = child.output
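debug() now walks the executed plan and wraps each operator in a DebugNode that obtains its SparkContext from the plan itself rather than from an implicit parameter. A toy version of the wrap-every-node transform over a hypothetical tree type:

sealed trait Node { def children: Seq[Node] }
case class Leaf(name: String) extends Node { def children = Nil }
case class Branch(name: String, children: Seq[Node]) extends Node
case class DebugWrap(child: Node) extends Node { def children = Seq(child) }

object DebugSketch {
  // Rebuild the tree bottom-up, wrapping every operator in a debug node,
  // roughly what `plan transform { case s => DebugNode(s) }` does above.
  def instrument(node: Node): Node = node match {
    case Leaf(n)       => DebugWrap(Leaf(n))
    case Branch(n, cs) => DebugWrap(Branch(n, cs.map(instrument)))
    case w: DebugWrap  => w // already instrumented
  }

  def main(args: Array[String]): Unit =
    println(instrument(Branch("filter", Seq(Leaf("scan")))))
}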
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 7d1f11caae..2750ddbce8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide
trait HashJoin {
+ self: SparkPlan =>
+
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
@@ -56,9 +58,9 @@ trait HashJoin {
def output = left.output ++ right.output
- @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+ @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
- () => new MutableProjection(streamedKeys, streamedPlan.output)
+ newMutableProjection(streamedKeys, streamedPlan.output)
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation.
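The new self-type annotation (self: SparkPlan =>) lets the HashJoin trait call helpers such as newProjection and newMutableProjection that live on SparkPlan without itself extending SparkPlan. A small sketch of the self-type idiom with hypothetical Operator/JoinLike types:

// Host class that provides shared helpers (standing in for SparkPlan).
abstract class Operator {
  def newProjection(columns: Seq[String]): Map[String, Any] => Seq[Any] =
    row => columns.map(row)
}

// A mixin that can only be mixed into Operator subclasses.
trait JoinLike { self: Operator =>
  def keys: Seq[String]
  // Legal only because the self-type guarantees newProjection is in scope.
  @transient lazy val keyGenerator = newProjection(keys)
}

class ToyHashJoin(val keys: Seq[String]) extends Operator with JoinLike

object SelfTypeSketch {
  def main(args: Array[String]): Unit = {
    val join = new ToyHashJoin(Seq("id"))
    println(join.keyGenerator(Map("id" -> 1, "name" -> "a"))) // List(1)
  }
}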
@@ -217,9 +219,8 @@ case class BroadcastHashJoin(
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
- right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {
+ right: SparkPlan) extends BinaryNode with HashJoin {
- override def otherCopyArgs = sqlContext :: Nil
override def outputPartitioning: Partitioning = left.outputPartitioning
@@ -228,7 +229,7 @@ case class BroadcastHashJoin(
@transient
lazy val broadcastFuture = future {
- sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
+ sparkContext.broadcast(buildPlan.executeCollect())
}
def execute() = {
@@ -248,14 +249,11 @@ case class BroadcastHashJoin(
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sqlContext :: Nil
-
def output = left.output
/** The Streamed Relation */
@@ -271,7 +269,7 @@ case class LeftSemiJoinBNL(
def execute() = {
val broadcastedRelation =
- sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
@@ -300,8 +298,14 @@ case class LeftSemiJoinBNL(
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
def output = left.output ++ right.output
- def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map {
- case (l: Row, r: Row) => buildRow(l ++ r)
+ def execute() = {
+ val leftResults = left.execute().map(_.copy())
+ val rightResults = right.execute().map(_.copy())
+
+ leftResults.cartesian(rightResults).mapPartitions { iter =>
+ val joinedRow = new JoinedRow
+ iter.map(r => joinedRow(r._1, r._2))
+ }
}
}
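CartesianProduct now allocates one JoinedRow per partition and re-points it at each (left, right) pair, copying only when a row has to outlive the current iterator step. A rough sketch of that reuse pattern with a hypothetical mutable PairView in place of JoinedRow:

// A mutable two-sided view, loosely in the spirit of catalyst's JoinedRow.
final class PairView[A, B] {
  private var left: A = _
  private var right: B = _
  def apply(l: A, r: B): this.type = { left = l; right = r; this }
  override def toString = s"($left, $right)"
}

object JoinedRowSketch {
  def main(args: Array[String]): Unit = {
    val lefts  = Seq("a", "b")
    val rights = Seq(1, 2)
    val view = new PairView[String, Int] // one allocation, reused for every pair
    // Materialize (here: toString) before buffering, since the view is overwritten each step.
    val pairs = for (l <- lefts; r <- rights) yield view(l, r).toString
    println(pairs) // List((a, 1), (a, 2), (b, 1), (b, 2))
  }
}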
@@ -311,14 +315,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
@DeveloperApi
case class BroadcastNestedLoopJoin(
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
- (@transient sqlContext: SQLContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
- override def otherCopyArgs = sqlContext :: Nil
-
override def output = {
joinType match {
case LeftOuter =>
@@ -345,13 +346,14 @@ case class BroadcastNestedLoopJoin(
def execute() = {
val broadcastedRelation =
- sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
+ val rightNulls = new GenericMutableRow(right.output.size)
streamedIter.foreach { streamedRow =>
var i = 0
@@ -361,7 +363,7 @@ case class BroadcastNestedLoopJoin(
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
- matchedRows += buildRow(streamedRow ++ broadcastedRow)
+ matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
matched = true
includedBroadcastTuples += i
}
@@ -369,7 +371,7 @@ case class BroadcastNestedLoopJoin(
}
if (!matched && (joinType == LeftOuter || joinType == FullOuter)) {
- matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null))
+ matchedRows += joinedRow(streamedRow, rightNulls).copy()
}
}
Iterator((matchedRows, includedBroadcastTuples))
@@ -383,20 +385,20 @@ case class BroadcastNestedLoopJoin(
streamedPlusMatches.map(_._2).reduce(_ ++ _)
}
+ val leftNulls = new GenericMutableRow(left.output.size)
val rightOuterMatches: Seq[Row] =
if (joinType == RightOuter || joinType == FullOuter) {
broadcastedRelation.value.zipWithIndex.filter {
case (row, i) => !allIncludedBroadcastTuples.contains(i)
}.map {
- // TODO: Use projection.
- case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row)
+ case (row, _) => new JoinedRow(leftNulls, row)
}
} else {
Vector()
}
// TODO: Breaks lineage.
- sqlContext.sparkContext.union(
- streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
+ sparkContext.union(
+ streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches))
}
}
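For the outer-join branches, buildRow over a freshly allocated array of nulls is replaced with a single pre-built all-null row (GenericMutableRow) joined against the streamed side. A sketch of null-padding with one shared null row, using plain Scala sequences as stand-in rows:

object NullPaddingSketch {
  type Row = IndexedSeq[Any]

  def main(args: Array[String]): Unit = {
    val rightWidth = 3
    // Allocate the null side once per partition instead of once per unmatched row.
    val rightNulls: Row = IndexedSeq.fill(rightWidth)(null)

    val unmatchedLeftRows: Seq[Row] = Seq(IndexedSeq(1, "a"), IndexedSeq(2, "b"))
    val padded = unmatchedLeftRows.map(l => l ++ rightNulls)
    padded.foreach(println)
    // Vector(1, a, null, null, null)
    // Vector(2, b, null, null, null)
  }
}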
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 8c7dbd5eb4..b3bae5db0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
*/
private[sql] case class ParquetRelation(
path: String,
- @transient conf: Option[Configuration] = None)
+ @transient conf: Option[Configuration],
+ @transient sqlContext: SQLContext)
extends LeafNode with MultiInstanceRelation {
self: Product =>
@@ -61,7 +62,7 @@ private[sql] case class ParquetRelation(
/** Attributes */
override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf)
- override def newInstance = ParquetRelation(path).asInstanceOf[this.type]
+ override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
// Equals must also take into account the output attributes so that we can distinguish between
// different instances of the same relation,
@@ -70,6 +71,9 @@ private[sql] case class ParquetRelation(
p.path == path && p.output == output
case _ => false
}
+
+ // TODO: Use data from the footers.
+ override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes)
}
private[sql] object ParquetRelation {
@@ -106,13 +110,14 @@ private[sql] object ParquetRelation {
*/
def create(pathString: String,
child: LogicalPlan,
- conf: Configuration): ParquetRelation = {
+ conf: Configuration,
+ sqlContext: SQLContext): ParquetRelation = {
if (!child.resolved) {
throw new UnresolvedException[LogicalPlan](
child,
"Attempt to create Parquet table from unresolved child (when schema is not available)")
}
- createEmpty(pathString, child.output, false, conf)
+ createEmpty(pathString, child.output, false, conf, sqlContext)
}
/**
@@ -127,14 +132,15 @@ private[sql] object ParquetRelation {
def createEmpty(pathString: String,
attributes: Seq[Attribute],
allowExisting: Boolean,
- conf: Configuration): ParquetRelation = {
+ conf: Configuration,
+ sqlContext: SQLContext): ParquetRelation = {
val path = checkPath(pathString, allowExisting, conf)
if (conf.get(ParquetOutputFormat.COMPRESSION) == null) {
conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name())
}
ParquetRelation.enableLogForwarding()
ParquetTypesConverter.writeMetaData(attributes, path, conf)
- new ParquetRelation(path.toString, Some(conf)) {
+ new ParquetRelation(path.toString, Some(conf), sqlContext) {
override val output = attributes
}
}
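ParquetRelation now receives its SQLContext explicitly so it can report a size estimate through statistics, and the create/createEmpty factories thread the context through. A minimal sketch of a relation deriving default statistics from its context, with hypothetical ToyContext/ToyRelation types:

// Hypothetical stand-ins; not Spark's classes.
case class Statistics(sizeInBytes: Long)

class ToyContext(val defaultSizeInBytes: Long)

case class ToyRelation(path: String, context: ToyContext) {
  // Until real footer metadata is read, fall back to the context-wide default,
  // mirroring the TODO in the patch above.
  lazy val statistics: Statistics = Statistics(sizeInBytes = context.defaultSizeInBytes)
}

object RelationSketch {
  def main(args: Array[String]): Unit = {
    val ctx = new ToyContext(defaultSizeInBytes = 10L * 1024 * 1024)
    println(ToyRelation("/tmp/table", ctx).statistics) // Statistics(10485760)
  }
}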
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index ea74320d06..912a9f002b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -55,8 +55,7 @@ case class ParquetTableScan(
// https://issues.apache.org/jira/browse/SPARK-1367
output: Seq[Attribute],
relation: ParquetRelation,
- columnPruningPred: Seq[Expression])(
- @transient val sqlContext: SQLContext)
+ columnPruningPred: Seq[Expression])
extends LeafNode {
override def execute(): RDD[Row] = {
@@ -99,8 +98,6 @@ case class ParquetTableScan(
.filter(_ != null) // Parquet's record filters may produce null values
}
- override def otherCopyArgs = sqlContext :: Nil
-
/**
* Applies a (candidate) projection.
*
@@ -110,7 +107,7 @@ case class ParquetTableScan(
def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
val success = validateProjection(prunedAttributes)
if (success) {
- ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
+ ParquetTableScan(prunedAttributes, relation, columnPruningPred)
} else {
sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
this
@@ -150,8 +147,7 @@ case class ParquetTableScan(
case class InsertIntoParquetTable(
relation: ParquetRelation,
child: SparkPlan,
- overwrite: Boolean = false)(
- @transient val sqlContext: SQLContext)
+ overwrite: Boolean = false)
extends UnaryNode with SparkHadoopMapReduceUtil {
/**
@@ -171,7 +167,7 @@ case class InsertIntoParquetTable(
val writeSupport =
if (child.output.map(_.dataType).forall(_.isPrimitive)) {
- logger.debug("Initializing MutableRowWriteSupport")
+ log.debug("Initializing MutableRowWriteSupport")
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
} else {
classOf[org.apache.spark.sql.parquet.RowWriteSupport]
@@ -203,8 +199,6 @@ case class InsertIntoParquetTable(
override def output = child.output
- override def otherCopyArgs = sqlContext :: Nil
-
/**
* Stores the given Row RDD as a Hadoop file.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index d4599da711..837ea7695d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
+import org.apache.spark.sql.test.TestSQLContext
import parquet.example.data.{GroupWriter, Group}
import parquet.example.data.simple.SimpleGroup
@@ -103,7 +104,7 @@ private[sql] object ParquetTestData {
val testDir = Utils.createTempDir()
val testFilterDir = Utils.createTempDir()
- lazy val testData = new ParquetRelation(testDir.toURI.toString)
+ lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext)
val testNestedSchema1 =
// based on blogpost example, source:
@@ -202,8 +203,10 @@ private[sql] object ParquetTestData {
val testNestedDir3 = Utils.createTempDir()
val testNestedDir4 = Utils.createTempDir()
- lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString)
- lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString)
+ lazy val testNestedData1 =
+ new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext)
+ lazy val testNestedData2 =
+ new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext)
def writeFile() = {
testDir.delete()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8e1e1971d9..1fd8d27b34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -45,6 +45,7 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution}
|== Exception ==
|$e
+ |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 215618e852..76b1724471 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite {
test("count is partially aggregated") {
val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
- val planned = PartialAggregation(query).head
- val aggregations = planned.collect { case a: Aggregate => a }
+ val planned = HashAggregation(query).head
+ val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggregations.size === 2)
}
test("count distinct is not partially aggregated") {
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
- val planned = PartialAggregation(query)
+ val planned = HashAggregation(query)
assert(planned.isEmpty)
}
test("mixed aggregates are not partially aggregated") {
val query =
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
- val planned = PartialAggregation(query)
+ val planned = HashAggregation(query)
assert(planned.isEmpty)
}
}
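The planner tests now look for any node whose nodeName contains "Aggregate" instead of matching a single Aggregate class, since hash-based and code-generated aggregates are separate physical operators. A sketch of that name-based collect over a hypothetical plan tree:

sealed trait Plan { def nodeName: String = getClass.getSimpleName; def children: Seq[Plan] }
case class HashAggregate(partial: Boolean, child: Plan) extends Plan { def children = Seq(child) }
case class Scan(table: String) extends Plan { def children = Nil }

object CollectSketch {
  // Pre-order collect of all nodes matching a predicate.
  def collect(plan: Plan)(p: Plan => Boolean): Seq[Plan] =
    (if (p(plan)) Seq(plan) else Nil) ++ plan.children.flatMap(c => collect(c)(p))

  def main(args: Array[String]): Unit = {
    val planned = HashAggregate(partial = false, HashAggregate(partial = true, Scan("t")))
    val aggs = collect(planned)(_.nodeName contains "Aggregate")
    assert(aggs.size == 2) // partial + final, as in the test above
    println(aggs.map(_.nodeName))
  }
}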
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
index e55648b8ed..2cab5e0c44 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._
* Note: this is only a rough example of how TGFs can be expressed, the final version will likely
* involve a lot more sugar for cleaner use in Scala/Java/etc.
*/
-case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator {
+case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator {
def children = input
protected def makeOutput() = 'nameAndAge.string :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 3c911e9a4e..561f5b4a49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
+
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser}
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType}
import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.util.Utils
@@ -207,10 +209,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
test("Projection of simple Parquet file") {
+ SparkPlan.currentContext.set(TestSQLContext)
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
- Seq())(TestSQLContext)
+ Seq())
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))
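The test seeds SparkPlan.currentContext before building the scan because physical nodes now capture their SQLContext from a shared holder at construction time instead of taking it as a constructor argument. A rough sketch of the capture-at-construction idiom with a hypothetical thread-local holder:

object CurrentContext {
  // One slot per thread; the planner sets it before constructing physical nodes.
  private val holder = new ThreadLocal[String] // a real system would hold the context object
  def set(ctx: String): Unit = holder.set(ctx)
  def get: String = Option(holder.get).getOrElse(sys.error("no context set"))
}

class ToyScan {
  // Captured once when the node is constructed, not passed as a parameter.
  val context: String = CurrentContext.get
}

object CurrentContextSketch {
  def main(args: Array[String]): Unit = {
    CurrentContext.set("TestSQLContext")
    println(new ToyScan().context) // TestSQLContext
  }
}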
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 84d43eaeea..f0a61270da 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -231,7 +231,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
HiveTableScans,
DataSinks,
Scripts,
- PartialAggregation,
+ HashAggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index c2b0b00aa5..39033bdeac 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -131,7 +131,7 @@ case class InsertIntoHiveTable(
conf,
SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf))
- logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
+ log.debug("Saving as hadoop file of type " + valueClass.getSimpleName)
val writer = new SparkHiveHadoopWriter(conf, fileSinkConf)
writer.preSetup()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 8258ee5fef..0c8f676e9c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -67,7 +67,7 @@ case class ScriptTransformation(
}
}
readerThread.start()
- val outputProjection = new Projection(input)
+ val outputProjection = new InterpretedProjection(input, child.output)
iter
.map(outputProjection)
// TODO: Use SerDe
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 057eb60a02..7582b4743d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf(
@transient
protected lazy val function: GenericUDTF = createFunction()
+ @transient
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
+ @transient
protected lazy val outputInspectors = {
val structInspector = function.initialize(inputInspectors.toArray)
structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector)
@@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf(
override def eval(input: Row): TraversableOnce[Row] = {
outputInspectors // Make sure initialized.
- val inputProjection = new Projection(children)
+ val inputProjection = new InterpretedProjection(children)
val collector = new UDTFCollector
function.setCollector(collector)
@@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction(
override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
@transient
- val inputProjection = new Projection(exprs)
+ val inputProjection = new InterpretedProjection(exprs)
def update(input: Row): Unit = {
val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d
new file mode 100644
index 0000000000..00750edc07
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d
@@ -0,0 +1 @@
+3
diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2
new file mode 100644
index 0000000000..00750edc07
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2
@@ -0,0 +1 @@
+3
diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index aadfd2e900..89cc589fb8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution
import scala.util.Try
+import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.{Row, SchemaRDD}
@@ -30,6 +32,15 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {
+ createQueryTest("single case",
+ """SELECT case when true then 1 else 2 end FROM src LIMIT 1""")
+
+ createQueryTest("double case",
+ """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""")
+
+ createQueryTest("case else null",
+ """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""")
+
createQueryTest("having no references",
"SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1")