Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/pom.xml | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 39
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 468
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 76
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala | 98
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala | 48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 219
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala | 80
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala | 28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 71
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala | 18
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 55
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala | 69
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala | 61
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala | 4
24 files changed, 1376 insertions, 109 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 531bfddbf2..54fa96baa1 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -38,9 +38,18 @@
<dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
+ <artifactId>scala-compiler</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
</dependency>
<dependency>
+ <groupId>org.scalamacros</groupId>
+ <artifactId>quasiquotes_${scala.binary.version}</artifactId>
+ <version>${scala.macros.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5c8c810d91..f44521d638 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -202,7 +202,7 @@ package object dsl {
// Protobuf terminology
def required = a.withNullability(false)
- def at(ordinal: Int) = BoundReference(ordinal, a)
+ def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 9ce1f01056..a3ebec8082 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.trees
+
import org.apache.spark.sql.Logging
/**
@@ -28,61 +30,27 @@ import org.apache.spark.sql.Logging
* to be retrieved more efficiently. However, since operations like column pruning can change
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
*/
-case class BoundReference(ordinal: Int, baseReference: Attribute)
- extends Attribute with trees.LeafNode[Expression] {
+case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
+ extends Expression with trees.LeafNode[Expression] {
type EvaluatedType = Any
- override def nullable = baseReference.nullable
- override def dataType = baseReference.dataType
- override def exprId = baseReference.exprId
- override def qualifiers = baseReference.qualifiers
- override def name = baseReference.name
+ override def references = Set.empty
- override def newInstance = BoundReference(ordinal, baseReference.newInstance)
- override def withNullability(newNullability: Boolean) =
- BoundReference(ordinal, baseReference.withNullability(newNullability))
- override def withQualifiers(newQualifiers: Seq[String]) =
- BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
-
- override def toString = s"$baseReference:$ordinal"
+ override def toString = s"input[$ordinal]"
override def eval(input: Row): Any = input(ordinal)
}
-/**
- * Used to denote operators that do their own binding of attributes internally.
- */
-trait NoBind { self: trees.TreeNode[_] => }
-
-class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
- import BindReferences._
-
- def apply(plan: TreeNode): TreeNode = {
- plan.transform {
- case n: NoBind => n.asInstanceOf[TreeNode]
- case leafNode if leafNode.children.isEmpty => leafNode
- case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
- bindReference(e, unaryNode.children.head.output)
- }
- }
- }
-}
-
object BindReferences extends Logging {
def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId)
if (ordinal == -1) {
- // TODO: This fallback is required because some operators (such as ScriptTransform)
- // produce new attributes that can't be bound. Likely the right thing to do is remove
- // this rule and require all operators to explicitly bind to the input schema that
- // they specify.
- logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
- a
+ sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
} else {
- BoundReference(ordinal, a)
+ BoundReference(ordinal, a.dataType, a.nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
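
For illustration, a minimal sketch (not part of the patch) of what the new binding produces: an AttributeReference inside an expression tree is replaced by a positional BoundReference that carries only the ordinal, data type, and nullability. The attribute names and values below are hypothetical.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val a = AttributeReference("a", IntegerType, nullable = false)()
val b = AttributeReference("b", IntegerType, nullable = false)()

// Bind `b + 1` against the schema [a, b]: the reference to `b` is rewritten
// to BoundReference(1, IntegerType, false), so eval() becomes input(1) + 1.
val bound = BindReferences.bindReference(Add(b, Literal(1)), Seq(a, b))

bound.eval(new GenericRow(Array[Any](10, 20)))  // 21
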
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 2c71d2c7b3..8fc5896974 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -17,12 +17,13 @@
package org.apache.spark.sql.catalyst.expressions
+
/**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of the
- * new row. If the schema of the input row is specified, then the given expression will be bound to
- * that schema.
+ * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ * @param expressions a sequence of expressions that determine the value of each column of the
+ * output row.
*/
-class Projection(expressions: Seq[Expression]) extends (Row => Row) {
+class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
@@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
}
/**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of th
- * new row. If the schema of the input row is specified, then the given expression will be bound to
- * that schema.
- *
- * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
- * each time an input row is added. This significantly reduces the cost of calculating the
- * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
- * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
- * and hold on to the returned [[Row]] before calling `next()`.
+ * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
+ * expressions.
+ * @param expressions a sequence of expressions that determine the value of each column of the
+ * output row.
*/
-case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
+case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
private[this] val exprArray = expressions.toArray
- private[this] val mutableRow = new GenericMutableRow(exprArray.size)
+ private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size)
def currentValue: Row = mutableRow
- def apply(input: Row): Row = {
+ override def target(row: MutableRow): MutableProjection = {
+ mutableRow = row
+ this
+ }
+
+ override def apply(input: Row): Row = {
var i = 0
while (i < exprArray.length) {
mutableRow(i) = exprArray(i).eval(input)
@@ -76,6 +77,12 @@ class JoinedRow extends Row {
private[this] var row1: Row = _
private[this] var row2: Row = _
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
def apply(r1: Row, r2: Row): Row = {
row1 = r1
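
A hedged usage sketch of the renamed interpreted projections (illustrative only; the bound expressions and input values are made up). It highlights the reuse semantics: every call to apply returns the same underlying mutable row, optionally retargeted via `target`.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val exprs = Seq(
  BoundReference(0, IntegerType, nullable = false),
  Add(BoundReference(0, IntegerType, nullable = false), Literal(1)))

// Retarget the projection at an explicitly supplied mutable row.
val project = new InterpretedMutableProjection(exprs).target(new GenericMutableRow(2))

val first  = project(new GenericRow(Array[Any](10)))  // shared row reads [10, 11]
val second = project(new GenericRow(Array[Any](20)))
// `first` and `second` are the same object; both now read [20, 21].
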
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 74ae723686..7470cb861b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -88,15 +88,6 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
-
- /**
- * Experimental
- *
- * Returns a mutable string builder for the specified column. A given row should return the
- * result of any mutations made to the returned buffer next time getString is called for the same
- * column.
- */
- def getStringBuilder(ordinal: Int): StringBuilder
}
/**
@@ -180,6 +171,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
values(i).asInstanceOf[String]
}
+ // Custom hashCode function that matches the efficient code generated version.
+ override def hashCode(): Int = {
+ var result: Int = 37
+
+ var i = 0
+ while (i < values.length) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ apply(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
+
def copy() = this
}
@@ -187,8 +207,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)
- def getStringBuilder(ordinal: Int): StringBuilder = ???
-
override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 5e089f7618..acddf5e9c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
override def eval(input: Row): Any = {
children.size match {
+ case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
function.asInstanceOf[(Any, Any) => Any](
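
The new `case 0` branch makes nullary functions usable as UDFs. A tiny hedged sketch (the function is hypothetical):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

// A UDF with no children never touches the input row, so evaluating it
// against a null input is safe here.
val nowUdf = ScalaUdf(() => System.currentTimeMillis(), LongType, Nil)
nowUdf.eval(null)  // current time in milliseconds
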
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
new file mode 100644
index 0000000000..5b398695bf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -0,0 +1,468 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import com.google.common.cache.{CacheLoader, CacheBuilder}
+
+import scala.language.existentials
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * A base class for generators of byte code to perform expression evaluation. Includes a set of
+ * helpers for referring to Catalyst types and building trees that perform evaluation of individual
+ * expressions.
+ */
+abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ import scala.tools.reflect.ToolBox
+
+ protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()
+
+ protected val rowType = typeOf[Row]
+ protected val mutableRowType = typeOf[MutableRow]
+ protected val genericRowType = typeOf[GenericRow]
+ protected val genericMutableRowType = typeOf[GenericMutableRow]
+
+ protected val projectionType = typeOf[Projection]
+ protected val mutableProjectionType = typeOf[MutableProjection]
+
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ private val javaSeparator = "$"
+
+ /**
+ * Generates a class for a given input expression. Called when there is not cached code
+ * already available.
+ */
+ protected def create(in: InType): OutType
+
+ /**
+ * Canonicalizes an input expression. Used to avoid double caching expressions that differ only
+ * cosmetically.
+ */
+ protected def canonicalize(in: InType): InType
+
+ /** Binds an input expression to a given input schema */
+ protected def bind(in: InType, inputSchema: Seq[Attribute]): InType
+
+ /**
+ * A cache of generated classes.
+ *
+ * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
+ * fundamental difference is that a ConcurrentMap persists all elements that are added to it until
+ * they are explicitly removed. A Cache on the other hand is generally configured to evict entries
+ * automatically, in order to constrain its memory footprint
+ */
+ protected val cache = CacheBuilder.newBuilder()
+ .maximumSize(1000)
+ .build(
+ new CacheLoader[InType, OutType]() {
+ override def load(in: InType): OutType = globalLock.synchronized {
+ create(in)
+ }
+ })
+
+ /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
+ def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType =
+ apply(bind(expressions, inputSchema))
+
+ /** Generates the requested evaluator given already bound expression(s). */
+ def apply(expressions: InType): OutType = cache.get(canonicalize(expressions))
+
+ /**
+ * Returns a term name that is unique within this instance of a `CodeGenerator`.
+ *
+ * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
+ * function.)
+ */
+ protected def freshName(prefix: String): TermName = {
+ newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}")
+ }
+
+ /**
+ * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
+ *
+ * @param code The sequence of statements required to evaluate the expression.
+ * @param nullTerm A term that holds a boolean value representing whether the expression evaluated
+ * to null.
+ * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
+ * valid if `nullTerm` is set to `false`.
+ * @param objectTerm A possibly boxed version of the result of evaluating this expression.
+ */
+ protected case class EvaluatedExpression(
+ code: Seq[Tree],
+ nullTerm: TermName,
+ primitiveTerm: TermName,
+ objectTerm: TermName)
+
+ /**
+ * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
+ * can be used to determine the result of evaluating the expression on an input row.
+ */
+ def expressionEvaluator(e: Expression): EvaluatedExpression = {
+ val primitiveTerm = freshName("primitiveTerm")
+ val nullTerm = freshName("nullTerm")
+ val objectTerm = freshName("objectTerm")
+
+ implicit class Evaluate1(e: Expression) {
+ def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = {
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(dataType)}
+ else
+ ${f(eval.primitiveTerm)}
+ """.children
+ }
+ }
+
+ implicit class Evaluate2(expressions: (Expression, Expression)) {
+
+ /**
+ * Short hand for generating binary evaluation code, which depends on two sub-evaluations of
+ * the same type. If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ *
+ * @param f a function from two primitive term names to a tree that evaluates them.
+ */
+ def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+ evaluateAs(expressions._1.dataType)(f)
+
+ def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = {
+ // TODO: Right now some timestamp tests fail if we enforce this...
+ if (expressions._1.dataType != expressions._2.dataType) {
+ log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
+ }
+
+ val eval1 = expressionEvaluator(expressions._1)
+ val eval2 = expressionEvaluator(expressions._2)
+ val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
+ val $primitiveTerm: ${termForType(resultType)} =
+ if($nullTerm) {
+ ${defaultPrimitive(resultType)}
+ } else {
+ $resultCode.asInstanceOf[${termForType(resultType)}]
+ }
+ """.children : Seq[Tree]
+ }
+ }
+
+ val inputTuple = newTermName(s"i")
+
+ // TODO: Skip generation of null handling code when expression are not nullable.
+ val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+ case b @ BoundReference(ordinal, dataType, nullable) =>
+ val nullValue = q"$inputTuple.isNullAt($ordinal)"
+ q"""
+ val $nullTerm: Boolean = $nullValue
+ val $primitiveTerm: ${termForType(dataType)} =
+ if($nullTerm)
+ ${defaultPrimitive(dataType)}
+ else
+ ${getColumn(inputTuple, dataType, ordinal)}
+ """.children
+
+ case expressions.Literal(null, dataType) =>
+ q"""
+ val $nullTerm = true
+ val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}]
+ """.children
+
+ case expressions.Literal(value: Boolean, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: String, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: Int, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case expressions.Literal(value: Long, dataType) =>
+ q"""
+ val $nullTerm = ${value == null}
+ val $primitiveTerm: ${termForType(dataType)} = $value
+ """.children
+
+ case Cast(e @ BinaryType(), StringType) =>
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(StringType)}
+ else
+ new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
+ """.children
+
+ case Cast(child @ NumericType(), IntegerType) =>
+ child.castOrNull(c => q"$c.toInt", IntegerType)
+
+ case Cast(child @ NumericType(), LongType) =>
+ child.castOrNull(c => q"$c.toLong", LongType)
+
+ case Cast(child @ NumericType(), DoubleType) =>
+ child.castOrNull(c => q"$c.toDouble", DoubleType)
+
+ case Cast(child @ NumericType(), FloatType) =>
+ child.castOrNull(c => q"$c.toFloat", IntegerType)
+
+ // Special handling required for timestamps in hive test cases since the toString function
+ // does not match the expected output.
+ case Cast(e, StringType) if e.dataType != TimestampType =>
+ val eval = expressionEvaluator(e)
+ eval.code ++
+ q"""
+ val $nullTerm = ${eval.nullTerm}
+ val $primitiveTerm =
+ if($nullTerm)
+ ${defaultPrimitive(StringType)}
+ else
+ ${eval.primitiveTerm}.toString
+ """.children
+
+ case EqualTo(e1, e2) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" }
+
+ /* TODO: Fix null semantics.
+ case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) =>
+ val eval = expressionEvaluator(e1)
+
+ val checks = list.map {
+ case expressions.Literal(v: String, dataType) =>
+ q"if(${eval.primitiveTerm} == $v) return true"
+ case expressions.Literal(v: Int, dataType) =>
+ q"if(${eval.primitiveTerm} == $v) return true"
+ }
+
+ val funcName = newTermName(s"isIn${curId.getAndIncrement()}")
+
+ q"""
+ def $funcName: Boolean = {
+ ..${eval.code}
+ if(${eval.nullTerm}) return false
+ ..$checks
+ return false
+ }
+ val $nullTerm = false
+ val $primitiveTerm = $funcName
+ """.children
+ */
+
+ case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" }
+ case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" }
+ case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" }
+ case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" }
+
+ case And(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = false
+
+ if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) ||
+ (!${eval2.nullTerm} && !${eval2.primitiveTerm})) {
+ $nullTerm = false
+ $primitiveTerm = false
+ } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else {
+ $nullTerm = false
+ $primitiveTerm = true
+ }
+ """.children
+
+ case Or(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = false
+
+ if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) ||
+ (!${eval2.nullTerm} && ${eval2.primitiveTerm})) {
+ $nullTerm = false
+ $primitiveTerm = true
+ } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) {
+ $nullTerm = true
+ } else {
+ $nullTerm = false
+ $primitiveTerm = false
+ }
+ """.children
+
+ case Not(child) =>
+ // Uh, bad function name...
+ child.castOrNull(c => q"!$c", BooleanType)
+
+ case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" }
+ case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" }
+ case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" }
+ case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" }
+
+ case IsNotNull(e) =>
+ val eval = expressionEvaluator(e)
+ q"""
+ ..${eval.code}
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm}
+ """.children
+
+ case IsNull(e) =>
+ val eval = expressionEvaluator(e)
+ q"""
+ ..${eval.code}
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm}
+ """.children
+
+ case c @ Coalesce(children) =>
+ q"""
+ var $nullTerm = true
+ var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)}
+ """.children ++
+ children.map { c =>
+ val eval = expressionEvaluator(c)
+ q"""
+ if($nullTerm) {
+ ..${eval.code}
+ if(!${eval.nullTerm}) {
+ $nullTerm = false
+ $primitiveTerm = ${eval.primitiveTerm}
+ }
+ }
+ """
+ }
+
+ case i @ expressions.If(condition, trueValue, falseValue) =>
+ val condEval = expressionEvaluator(condition)
+ val trueEval = expressionEvaluator(trueValue)
+ val falseEval = expressionEvaluator(falseValue)
+
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)}
+ ..${condEval.code}
+ if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
+ ..${trueEval.code}
+ $nullTerm = ${trueEval.nullTerm}
+ $primitiveTerm = ${trueEval.primitiveTerm}
+ } else {
+ ..${falseEval.code}
+ $nullTerm = ${falseEval.nullTerm}
+ $primitiveTerm = ${falseEval.primitiveTerm}
+ }
+ """.children
+ }
+
+ // If there was no match in the partial function above, we fall back on calling the interpreted
+ // expression evaluator.
+ val code: Seq[Tree] =
+ primitiveEvaluation.lift.apply(e).getOrElse {
+ log.debug(s"No rules to generate $e")
+ val tree = reify { e }
+ q"""
+ val $objectTerm = $tree.eval(i)
+ val $nullTerm = $objectTerm == null
+ val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}]
+ """.children
+ }
+
+ EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
+ }
+
+ protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
+ dataType match {
+ case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
+ case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
+ }
+ }
+
+ protected def setColumn(
+ destinationRow: TermName,
+ dataType: DataType,
+ ordinal: Int,
+ value: TermName) = {
+ dataType match {
+ case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
+ case _ => q"$destinationRow.update($ordinal, $value)"
+ }
+ }
+
+ protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
+ protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+
+ protected def primitiveForType(dt: DataType) = dt match {
+ case IntegerType => "Int"
+ case LongType => "Long"
+ case ShortType => "Short"
+ case ByteType => "Byte"
+ case DoubleType => "Double"
+ case FloatType => "Float"
+ case BooleanType => "Boolean"
+ case StringType => "String"
+ }
+
+ protected def defaultPrimitive(dt: DataType) = dt match {
+ case BooleanType => ru.Literal(Constant(false))
+ case FloatType => ru.Literal(Constant(-1.0.toFloat))
+ case StringType => ru.Literal(Constant("<uninit>"))
+ case ShortType => ru.Literal(Constant(-1.toShort))
+ case LongType => ru.Literal(Constant(1L))
+ case ByteType => ru.Literal(Constant(-1.toByte))
+ case DoubleType => ru.Literal(Constant(-1.toDouble))
+ case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+ case IntegerType => ru.Literal(Constant(-1))
+ case _ => ru.Literal(Constant(null))
+ }
+
+ protected def termForType(dt: DataType) = dt match {
+ case n: NativeType => n.tag
+ case _ => typeTag[Any]
+ }
+}
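
A standalone sketch of the runtime-compilation pattern the class above relies on, assuming Scala 2.10 with the quasiquotes plugin added in the pom change: statements are assembled as trees, the last expression is a function literal, and a ToolBox compiles and instantiates it. The term name and arithmetic are made up.

import scala.reflect.runtime.universe._
import scala.tools.reflect.ToolBox

val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox()

// Analogous to EvaluatedExpression: a fresh term name plus the statements
// that compute it, spliced into a function compiled at runtime.
val resultTerm = newTermName("result")
val code =
  q"""
    val $resultTerm: Int = 1 + 2
    (i: Int) => i + $resultTerm
  """

val addThree = toolBox.eval(code).asInstanceOf[Int => Int]
addThree(10)  // 13
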
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
new file mode 100644
index 0000000000..a419fd7ecb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
+ * input [[Row]] for a fixed set of [[Expression Expressions]].
+ */
+object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ val mutableRowName = newTermName("mutableRow")
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer(_))
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
+ val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) =>
+ val evaluationCode = expressionEvaluator(e)
+
+ evaluationCode.code :+
+ q"""
+ if(${evaluationCode.nullTerm})
+ mutableRow.setNullAt($i)
+ else
+ ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)}
+ """
+ }
+
+ val code =
+ q"""
+ () => { new $mutableProjectionType {
+
+ private[this] var $mutableRowName: $mutableRowType =
+ new $genericMutableRowType(${expressions.size})
+
+ def target(row: $mutableRowType): $mutableProjectionType = {
+ $mutableRowName = row
+ this
+ }
+
+ /* Provide immutable access to the last projected row. */
+ def currentValue: $rowType = mutableRow
+
+ def apply(i: $rowType): $rowType = {
+ ..$projectionCode
+ mutableRow
+ }
+ } }
+ """
+
+ log.debug(s"code for ${expressions.mkString(",")}:\n$code")
+ toolBox.eval(code).asInstanceOf[() => MutableProjection]
+ }
+}
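
A hedged usage sketch (attribute names and values are hypothetical): the generator compiles once per canonicalized expression list and returns a factory, so each task can instantiate its own MutableProjection and point it at its own target row.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.types._

val x = AttributeReference("x", IntegerType, nullable = false)()

// Bind, canonicalize, compile, and cache; returns a () => MutableProjection.
val factory: () => MutableProjection =
  GenerateMutableProjection(Seq(Add(x, Literal(1))), Seq(x))

val projection = factory().target(new GenericMutableRow(1))
projection(new GenericRow(Array[Any](41)))  // target row now holds [42]
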
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
new file mode 100644
index 0000000000..4211998f75
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import com.typesafe.scalalogging.slf4j.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.{StringType, NumericType}
+
+/**
+ * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of
+ * [[Expression Expressions]].
+ */
+object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
+ in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder])
+
+ protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ protected def create(ordering: Seq[SortOrder]): Ordering[Row] = {
+ val a = newTermName("a")
+ val b = newTermName("b")
+ val comparisons = ordering.zipWithIndex.map { case (order, i) =>
+ val evalA = expressionEvaluator(order.child)
+ val evalB = expressionEvaluator(order.child)
+
+ val compare = order.child.dataType match {
+ case _: NumericType =>
+ q"""
+ val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
+ if(comp != 0) {
+ return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"}
+ }
+ """
+ case StringType =>
+ if (order.direction == Ascending) {
+ q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})"""
+ } else {
+ q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})"""
+ }
+ }
+
+ q"""
+ i = $a
+ ..${evalA.code}
+ i = $b
+ ..${evalB.code}
+ if (${evalA.nullTerm} && ${evalB.nullTerm}) {
+ // Nothing
+ } else if (${evalA.nullTerm}) {
+ return ${if (order.direction == Ascending) q"-1" else q"1"}
+ } else if (${evalB.nullTerm}) {
+ return ${if (order.direction == Ascending) q"1" else q"-1"}
+ } else {
+ $compare
+ }
+ """
+ }
+
+ val q"class $orderingName extends $orderingType { ..$body }" = reify {
+ class SpecificOrdering extends Ordering[Row] {
+ val o = ordering
+ }
+ }.tree.children.head
+
+ val code = q"""
+ class $orderingName extends $orderingType {
+ ..$body
+ def compare(a: $rowType, b: $rowType): Int = {
+ var i: $rowType = null // Holds current row being evaluated.
+ ..$comparisons
+ return 0
+ }
+ }
+ new $orderingName()
+ """
+ logger.debug(s"Generated Ordering: $code")
+ toolBox.eval(code).asInstanceOf[Ordering[Row]]
+ }
+}
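
Usage sketch (hypothetical schema): generate an Ordering[Row] for a single ascending sort key and use it to sort in-memory rows.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.types._

val key = AttributeReference("key", IntegerType, nullable = false)()

// ORDER BY key ASC, bound to a single-column input schema.
val ordering: Ordering[Row] = GenerateOrdering(Seq(SortOrder(key, Ascending)), Seq(key))

val rows = Seq(new GenericRow(Array[Any](3)), new GenericRow(Array[Any](1)))
rows.sorted(ordering)  // rows now ordered 1, 3
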
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
new file mode 100644
index 0000000000..2a0935c790
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]].
+ */
+object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in)
+
+ protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
+ BindReferences.bindReference(in, inputSchema)
+
+ protected def create(predicate: Expression): ((Row) => Boolean) = {
+ val cEval = expressionEvaluator(predicate)
+
+ val code =
+ q"""
+ (i: $rowType) => {
+ ..${cEval.code}
+ if (${cEval.nullTerm}) false else ${cEval.primitiveTerm}
+ }
+ """
+
+ log.debug(s"Generated predicate '$predicate':\n$code")
+ toolBox.eval(code).asInstanceOf[Row => Boolean]
+ }
+}
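
Usage sketch (hypothetical attribute): the two apply overloads on CodeGenerator mean a predicate can be handed either pre-bound expressions or an expression plus the input schema.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.types._

val a = AttributeReference("a", IntegerType, nullable = false)()

// Unbound expression plus schema: bound, canonicalized, compiled, and cached.
val pred: Row => Boolean = GeneratePredicate(GreaterThan(a, Literal(5)), Seq(a))

pred(new GenericRow(Array[Any](10)))  // true
pred(new GenericRow(Array[Any](1)))   // false, served by the same cached class
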
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
new file mode 100644
index 0000000000..77fa02c13d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+
+/**
+ * Generates bytecode that produces a new [[Row]] object based on a fixed set of input
+ * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom
+ * generated based on the output types of the [[Expression]] to avoid boxing of primitive values.
+ */
+object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
+ import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.universe._
+
+ protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+ in.map(ExpressionCanonicalizer(_))
+
+ protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
+ in.map(BindReferences.bindReference(_, inputSchema))
+
+ // Make Mutablility optional...
+ protected def create(expressions: Seq[Expression]): Projection = {
+ val tupleLength = ru.Literal(Constant(expressions.length))
+ val lengthDef = q"final val length = $tupleLength"
+
+ /* TODO: Configurable...
+ val nullFunctions =
+ q"""
+ private final val nullSet = new org.apache.spark.util.collection.BitSet(length)
+ final def setNullAt(i: Int) = nullSet.set(i)
+ final def isNullAt(i: Int) = nullSet.get(i)
+ """
+ */
+
+ val nullFunctions =
+ q"""
+ private[this] var nullBits = new Array[Boolean](${expressions.size})
+ final def setNullAt(i: Int) = { nullBits(i) = true }
+ final def isNullAt(i: Int) = nullBits(i)
+ """.children
+
+ val tupleElements = expressions.zipWithIndex.flatMap {
+ case (e, i) =>
+ val elementName = newTermName(s"c$i")
+ val evaluatedExpression = expressionEvaluator(e)
+ val iLit = ru.Literal(Constant(i))
+
+ q"""
+ var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _
+ {
+ ..${evaluatedExpression.code}
+ if(${evaluatedExpression.nullTerm})
+ setNullAt($iLit)
+ else
+ $elementName = ${evaluatedExpression.primitiveTerm}
+ }
+ """.children : Seq[Tree]
+ }
+
+ val iteratorFunction = {
+ val allColumns = (0 until expressions.size).map { i =>
+ val iLit = ru.Literal(Constant(i))
+ q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }"
+ }
+ q"final def iterator = Iterator[Any](..$allColumns)"
+ }
+
+ val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)"""
+ val applyFunction = {
+ val cases = (0 until expressions.size).map { i =>
+ val ordinal = ru.Literal(Constant(i))
+ val elementName = newTermName(s"c$i")
+ val iLit = ru.Literal(Constant(i))
+
+ q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }"
+ }
+ q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }"
+ }
+
+ val updateFunction = {
+ val cases = expressions.zipWithIndex.map {case (e, i) =>
+ val ordinal = ru.Literal(Constant(i))
+ val elementName = newTermName(s"c$i")
+ val iLit = ru.Literal(Constant(i))
+
+ q"""
+ if(i == $ordinal) {
+ if(value == null) {
+ setNullAt(i)
+ } else {
+ $elementName = value.asInstanceOf[${termForType(e.dataType)}]
+ return
+ }
+ }"""
+ }
+ q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
+ }
+
+ val specificAccessorFunctions = NativeType.all.map { dataType =>
+ val ifStatements = expressions.zipWithIndex.flatMap {
+ case (e, i) if e.dataType == dataType =>
+ val elementName = newTermName(s"c$i")
+ // TODO: The string of ifs gets pretty inefficient as the row grows in size.
+ // TODO: Optional null checks?
+ q"if(i == $i) return $elementName" :: Nil
+ case _ => Nil
+ }
+
+ q"""
+ final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
+
+ val specificMutatorFunctions = NativeType.all.map { dataType =>
+ val ifStatements = expressions.zipWithIndex.flatMap {
+ case (e, i) if e.dataType == dataType =>
+ val elementName = newTermName(s"c$i")
+ // TODO: The string of ifs gets pretty inefficient as the row grows in size.
+ // TODO: Optional null checks?
+ q"if(i == $i) { $elementName = value; return }" :: Nil
+ case _ => Nil
+ }
+
+ q"""
+ final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = {
+ ..$ifStatements;
+ $accessorFailure
+ }"""
+ }
+
+ val hashValues = expressions.zipWithIndex.map { case (e,i) =>
+ val elementName = newTermName(s"c$i")
+ val nonNull = e.dataType match {
+ case BooleanType => q"if ($elementName) 0 else 1"
+ case ByteType | ShortType | IntegerType => q"$elementName.toInt"
+ case LongType => q"($elementName ^ ($elementName >>> 32)).toInt"
+ case FloatType => q"java.lang.Float.floatToIntBits($elementName)"
+ case DoubleType =>
+ q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }"
+ case _ => q"$elementName.hashCode"
+ }
+ q"if (isNullAt($i)) 0 else $nonNull"
+ }
+
+ val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree)
+
+ val hashCodeFunction =
+ q"""
+ override def hashCode(): Int = {
+ var result: Int = 37
+ ..$hashUpdates
+ result
+ }
+ """
+
+ val columnChecks = (0 until expressions.size).map { i =>
+ val elementName = newTermName(s"c$i")
+ q"if (this.$elementName != specificType.$elementName) return false"
+ }
+
+ val equalsFunction =
+ q"""
+ override def equals(other: Any): Boolean = other match {
+ case specificType: SpecificRow =>
+ ..$columnChecks
+ return true
+ case other => super.equals(other)
+ }
+ """
+
+ val copyFunction =
+ q"""
+ final def copy() = new $genericRowType(this.toArray)
+ """
+
+ val classBody =
+ nullFunctions ++ (
+ lengthDef +:
+ iteratorFunction +:
+ applyFunction +:
+ updateFunction +:
+ equalsFunction +:
+ hashCodeFunction +:
+ copyFunction +:
+ (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions))
+
+ val code = q"""
+ final class SpecificRow(i: $rowType) extends $mutableRowType {
+ ..$classBody
+ }
+
+ new $projectionType { def apply(r: $rowType) = new SpecificRow(r) }
+ """
+
+ log.debug(
+ s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}")
+ toolBox.eval(code).asInstanceOf[Projection]
+ }
+}
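
Usage sketch (hypothetical schema): the generated SpecificRow stores each column in a field of its declared type, so the typed accessors avoid boxing.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.types._

val name = AttributeReference("name", StringType, nullable = true)()
val age  = AttributeReference("age", IntegerType, nullable = false)()

val project: Projection =
  GenerateProjection(Seq(name, Add(age, Literal(1))), Seq(name, age))

val out = project(new GenericRow(Array[Any]("alice", 29)))
out.getString(0)  // "alice", via the generated getString accessor
out.getInt(1)     // 30, read from a primitive Int field
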
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
new file mode 100644
index 0000000000..80c7dfd376
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.rules
+import org.apache.spark.sql.catalyst.util
+
+/**
+ * A collection of generators that build custom bytecode at runtime for performing the evaluation
+ * of catalyst expression.
+ */
+package object codegen {
+
+ /**
+ * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala
+ * 2.10.
+ */
+ protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+ /** Canonicalizes an expression so those that differ only by names can reuse the same code. */
+ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
+ val batches =
+ Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil
+
+ object CleanExpressions extends rules.Rule[Expression] {
+ def apply(e: Expression): Expression = e transform {
+ case Alias(c, _) => c
+ }
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Dumps the bytecode from a class to the screen using javap.
+ */
+ @DeveloperApi
+ object DumpByteCode {
+ import scala.sys.process._
+ val dumpDirectory = util.getTempFilePath("sparkSqlByteCode")
+ dumpDirectory.mkdir()
+
+ def apply(obj: Any): Unit = {
+ val generatedClass = obj.getClass
+ val classLoader =
+ generatedClass
+ .getClassLoader
+ .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader]
+ val generatedBytes = classLoader.classBytes(generatedClass.getName)
+
+ val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName)
+ if (!packageDir.exists()) { packageDir.mkdir() }
+
+ val classFile =
+ new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class")
+
+ val outfile = new java.io.FileOutputStream(classFile)
+ outfile.write(generatedBytes)
+ outfile.close()
+
+ println(
+ s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!)
+ }
+ }
+}
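
A small sketch of what canonicalization buys for caching: two expressions that differ only by an output alias reduce to the same tree, so the Guava cache in CodeGenerator compiles one class for both. The expression below is illustrative.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.types._

val x = BoundReference(0, IntegerType, nullable = false)

val canonical = ExpressionCanonicalizer(Alias(Add(x, Literal(1)), "y")())
// canonical is Add(input[0], 1): the cosmetic Alias has been stripped.
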
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index b6f2451b52..55d95991c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst
* ==Evaluation==
* The result of expressions can be evaluated using the `Expression.apply(Row)` method.
*/
-package object expressions
+package object expressions {
+
+ /**
+ * Converts a [[Row]] to another Row given a sequence of expression that define each column of the
+ * new row. If the schema of the input row is specified, then the given expression will be bound
+ * to that schema.
+ */
+ abstract class Projection extends (Row => Row)
+
+ /**
+ * Converts a [[Row]] to another Row given a sequence of expression that define each column of the
+ * new row. If the schema of the input row is specified, then the given expression will be bound
+ * to that schema.
+ *
+ * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
+ * each time an input row is added. This significantly reduces the cost of calculating the
+ * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
+ * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
+ * and hold on to the returned [[Row]] before calling `next()`.
+ */
+ abstract class MutableProjection extends Projection {
+ def currentValue: Row
+
+ /** Uses the given row to store the output of the projection. */
+ def target(row: MutableRow): MutableProjection
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 06b94a98d3..5976b0ddf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType
object InterpretedPredicate {
+ def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
+ apply(BindReferences.bindReference(expression, inputSchema))
+
def apply(expression: Expression): (Row => Boolean) = {
(r: Row) => expression.eval(r).asInstanceOf[Boolean]
}
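
Sketch of the new overload (hypothetical attribute): binding to the input schema happens before evaluation, mirroring GeneratePredicate's entry point.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._

val flag = AttributeReference("flag", BooleanType, nullable = false)()

val pred: Row => Boolean = InterpretedPredicate(flag, Seq(flag))
pred(new GenericRow(Array[Any](true)))   // true
pred(new GenericRow(Array[Any](false)))  // false
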
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
new file mode 100644
index 0000000000..3b3e206055
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+package object catalyst {
+ /**
+ * A JVM-global lock that should be used to prevent thread safety issues when using things in
+ * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for
+ * 2.10.* builds. See SI-6240 for more details.
+ */
+ protected[catalyst] object ScalaReflectionLock
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 026692abe0..418f8686bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -105,6 +105,77 @@ object PhysicalOperation extends PredicateHelper {
}
/**
+ * Matches a logical aggregation that can be performed on distributed data in two steps. The first
+ * operates on the data in each partition performing partial aggregation for each group. The second
+ * occurs after the shuffle and completes the aggregation.
+ *
+ * This pattern will only match if all aggregate expressions can be computed partially and will
+ * return the rewritten aggregation expressions for both phases.
+ *
+ * The returned values for this match are as follows:
+ * - Grouping attributes for the final aggregation.
+ * - Aggregates for the final aggregation.
+ * - Grouping expressions for the partial aggregation.
+ * - Partial aggregate expressions.
+ * - Input to the aggregation.
+ */
+object PartialAggregation {
+ type ReturnType =
+ (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
+ case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ // Collect all aggregate expressions.
+ val allAggregates =
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ // Collect all aggregate expressions that can be computed partially.
+ val partialAggregates =
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+
+ // Only do partial aggregation if supported by all aggregate expressions.
+ if (allAggregates.size == partialAggregates.size) {
+ // Create a map of expressions to their partial evaluations for all aggregate expressions.
+ val partialEvaluations: Map[Long, SplitEvaluation] =
+ partialAggregates.map(a => (a.id, a.asPartial)).toMap
+
+ // We need to pass all grouping expressions though so the grouping can happen a second
+ // time. However some of them might be unnamed so we alias them allowing them to be
+ // referenced in the second aggregation.
+ val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map {
+ case n: NamedExpression => (n, n)
+ case other => (other, Alias(other, "PartialGroup")())
+ }.toMap
+
+ // Replace aggregations with a new expression that computes the result from the already
+ // computed partial evaluations and grouping values.
+ val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
+ case e: Expression if partialEvaluations.contains(e.id) =>
+ partialEvaluations(e.id).finalEvaluation
+ case e: Expression if namedGroupingExpressions.contains(e) =>
+ namedGroupingExpressions(e).toAttribute
+ }).asInstanceOf[Seq[NamedExpression]]
+
+ val partialComputation =
+ (namedGroupingExpressions.values ++
+ partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+
+ val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
+
+ Some(
+ (namedGroupingAttributes,
+ rewrittenAggregateExpressions,
+ groupingExpressions,
+ partialComputation,
+ child))
+ } else {
+ None
+ }
+ case _ => None
+ }
+}
+
+
+/**
* A pattern that finds joins with equality conditions that can be evaluated using equi-join.
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
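
An illustrative sketch of extracting the PartialAggregation pattern above (the relation stand-in and column names are made up, not part of this patch): a SUM grouped by a key is a PartialAggregate, so the extractor returns the rewritten final aggregation together with the per-partition partial computation.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PartialAggregation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.types._

// Hypothetical stand-in for a scanned relation.
case class FakeRelation(output: Seq[Attribute]) extends logical.LeafNode

val key   = AttributeReference("key", StringType, nullable = true)()
val value = AttributeReference("value", IntegerType, nullable = false)()

// SELECT key, SUM(value) AS total FROM relation GROUP BY key
val aggregate = logical.Aggregate(
  Seq(key),
  Seq(key, Alias(Sum(value), "total")()),
  FakeRelation(Seq(key, value)))

aggregate match {
  case PartialAggregation(finalGroups, finalAggs, partialGroups, partialComputation, child) =>
    // partialComputation holds the grouping columns plus the partial sums that
    // run before the shuffle; finalAggs combines them after the shuffle.
    (finalGroups, partialComputation)
}
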
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ac85f95b52..888cb08e95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -112,7 +112,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
self: Product =>
override lazy val statistics: Statistics =
- throw new UnsupportedOperationException("default leaf nodes don't have meaningful Statistics")
+ throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
// Leaf nodes by definition cannot reference any input attributes.
override def references = Set.empty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index a357c6ffb8..481a5a4f21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -35,7 +35,7 @@ abstract class Command extends LeafNode {
*/
case class NativeCommand(cmd: String) extends Command {
override def output =
- Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)()))
+ Seq(AttributeReference("result", StringType, nullable = false)())
}
/**
@@ -43,7 +43,7 @@ case class NativeCommand(cmd: String) extends Command {
*/
case class SetCommand(key: Option[String], value: Option[String]) extends Command {
override def output = Seq(
- BoundReference(1, AttributeReference("", StringType, nullable = false)()))
+ AttributeReference("", StringType, nullable = false)())
}
/**
@@ -52,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman
*/
case class ExplainCommand(plan: LogicalPlan) extends Command {
override def output =
- Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)()))
+ Seq(AttributeReference("plan", StringType, nullable = false)())
}
/**
@@ -71,7 +71,7 @@ case class DescribeCommand(
isExtended: Boolean) extends Command {
override def output = Seq(
// Column names are based on Hive.
- BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()),
- BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()),
- BoundReference(2, AttributeReference("comment", StringType, nullable = false)()))
+ AttributeReference("col_name", StringType, nullable = false)(),
+ AttributeReference("data_type", StringType, nullable = false)(),
+ AttributeReference("comment", StringType, nullable = false)())
}
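The rewrites above declare each command's output schema as unbound attributes; resolving those attributes to row positions is left to execution, which is why the explicit BoundReference wrappers (and their hard-coded ordinals) are dropped. A minimal sketch of the resulting pattern, with an arbitrary object and value name:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.types.StringType

object CommandOutputSketch {
  // A one-column, non-nullable string schema; no row ordinal is baked in at this point.
  val explainOutput = Seq(AttributeReference("plan", StringType, nullable = false)())
}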
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index e32adb76fe..e300bdbece 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -72,7 +72,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
}
iteration += 1
if (iteration > batch.strategy.maxIterations) {
- logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}")
+          // Only log if this batch is supposed to run more than once (i.e. maxIterations > 1).
+ if (iteration != 2) {
+ logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
+ }
continue = false
}
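A simplified sketch of the loop this hunk patches, using stand-in names rather than the real RuleExecutor API: a batch with strategy Once has maxIterations == 1, so the loop always exits with iteration == 2, and warning about "max iterations reached" for such a batch would be pure noise.

object FixedPointLoopSketch {
  // Run `step` until it reports no change or `maxIterations` passes have been made;
  // returns the number of passes actually performed.
  def run(maxIterations: Int)(step: () => Boolean): Int = {
    var iteration = 1
    var continue = true
    while (continue) {
      val changed = step() // apply every rule in the batch once
      iteration += 1
      if (iteration > maxIterations) {
        if (iteration != 2) { // suppress the message for single-pass (Once) batches
          println(s"Max iterations (${iteration - 1}) reached")
        }
        continue = false
      }
      if (!changed) continue = false
    }
    iteration - 1
  }
}

For example, FixedPointLoopSketch.run(1)(() => true) performs exactly one pass and prints nothing.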
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index cd4b5e9c1b..71808f76d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -23,16 +23,13 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
import scala.util.parsing.combinator.RegexParsers
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.util.Utils
/**
- * A JVM-global lock that should be used to prevent thread safety issues when using things in
- * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for
- * 2.10.* builds. See SI-6240 for more details.
+ * Utility functions for working with DataTypes.
*/
-protected[catalyst] object ScalaReflectionLock
-
object DataType extends RegexParsers {
protected lazy val primitiveType: Parser[DataType] =
"StringType" ^^^ StringType |
@@ -99,6 +96,13 @@ abstract class DataType {
case object NullType extends DataType
+object NativeType {
+ def all = Seq(
+ IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+
+ def unapply(dt: DataType): Boolean = all.contains(dt)
+}
+
trait PrimitiveType extends DataType {
override def isPrimitive = true
}
@@ -149,6 +153,10 @@ abstract class NumericType extends NativeType with PrimitiveType {
val numeric: Numeric[JvmType]
}
+object NumericType {
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+}
+
/** Matcher for any expressions that evaluate to [[IntegralType]]s */
object IntegralType {
def unapply(a: Expression): Boolean = a match {
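The NativeType and NumericType objects added above are boolean extractors, so they serve as guards in pattern matches rather than destructuring anything. The helper functions below are examples written for this note, not code from the patch.

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.types._

object TypeMatchExamples {
  // True for the primitive "native" types enumerated in NativeType.all.
  def isNative(dt: DataType): Boolean = dt match {
    case NativeType() => true
    case _            => false
  }

  // True when the expression's data type is numeric, mirroring the IntegralType matcher.
  def isNumericExpr(e: Expression): Boolean = e match {
    case NumericType() => true
    case _             => false
  }
}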
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 58f8c341e6..999c9fff38 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
class ExpressionEvaluationSuite extends FunSuite {
test("literals") {
- assert((Literal(1) + Literal(1)).eval(null) === 2)
+ checkEvaluation(Literal(1), 1)
+ checkEvaluation(Literal(true), true)
+ checkEvaluation(Literal(0L), 0L)
+ checkEvaluation(Literal("test"), "test")
+ checkEvaluation(Literal(1) + Literal(1), 2)
}
/**
@@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite {
test("3VL Not") {
notTrueTable.foreach {
case (v, answer) =>
- val expr = ! Literal(v, BooleanType)
- val result = expr.eval(null)
- if (result != answer)
- fail(s"$expr should not evaluate to $result, expected: $answer") }
+ checkEvaluation(!Literal(v, BooleanType), answer)
+ }
}
booleanLogicTest("AND", _ && _,
@@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite {
}
}
+ test("IN") {
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
+ checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
+ checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
@@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
- checkEvaluation("23" cast DoubleType, 23)
+ checkEvaluation("23" cast DoubleType, 23d)
checkEvaluation("23" cast IntegerType, 23)
- checkEvaluation("23" cast FloatType, 23)
- checkEvaluation("23" cast DecimalType, 23)
- checkEvaluation("23" cast ByteType, 23)
- checkEvaluation("23" cast ShortType, 23)
+ checkEvaluation("23" cast FloatType, 23f)
+ checkEvaluation("23" cast DecimalType, 23: BigDecimal)
+ checkEvaluation("23" cast ByteType, 23.toByte)
+ checkEvaluation("23" cast ShortType, 23.toShort)
checkEvaluation("2012-12-11" cast DoubleType, null)
checkEvaluation(Literal(123) cast IntegerType, 123)
- checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24)
+ checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d)
checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
- checkEvaluation(Literal(23f) + Cast(true, FloatType), 24)
- checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24)
- checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24)
- checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
+ checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f)
+ checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal)
+ checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte)
+ checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort)
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
@@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite {
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)
- checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ checkEvaluation(GetItem(BoundReference(3, typeMap, true),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
+ checkEvaluation(GetItem(BoundReference(3, typeMap, true),
Literal(null, StringType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
- checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
+ checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(null, IntegerType)), null, row)
- checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
+ checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
val typeS_notNullable = StructType(
@@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite {
:: StructField("b", StringType, nullable = false) :: Nil
)
- assert(GetField(BoundReference(2,
- AttributeReference("c", typeS)()), "a").nullable === true)
- assert(GetField(BoundReference(2,
- AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false)
+    assert(GetField(BoundReference(2, typeS, nullable = true), "a").nullable === true)
+ assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
new file mode 100644
index 0000000000..245a2e1480
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Overrides our expression evaluation tests to use code generation for evaluation.
+ */
+class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ val plan = try {
+ GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)()
+ } catch {
+ case e: Throwable =>
+ val evaluated = GenerateProjection.expressionEvaluator(expression)
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code.mkString("\n")}
+ |$e
+ """.stripMargin)
+ }
+
+ val actual = plan(inputRow).apply(0)
+    if (actual != expected) {
+      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+
+
+ test("multithreaded eval") {
+ import scala.concurrent._
+ import ExecutionContext.Implicits.global
+ import scala.concurrent.duration._
+
+ val futures = (1 to 20).map { _ =>
+ future {
+ GeneratePredicate(EqualTo(Literal(1), Literal(1)))
+ GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil)
+ }
+ }
+
+ futures.foreach(Await.result(_, 10.seconds))
+ }
+}
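As a rough standalone illustration of the codegen entry point this override drives (the expression and alias name below are arbitrary, chosen only for this sketch):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection

object GenerateMutableProjectionSketch {
  def main(args: Array[String]): Unit = {
    // GenerateMutableProjection(...) returns a factory; the trailing () builds the compiled
    // projection, mirroring the call in checkEvaluation above.
    val projection = GenerateMutableProjection(Alias(Add(Literal(1), Literal(2)), "sum")() :: Nil)()
    // Evaluate against an empty input row and read column 0 of the produced row.
    println(projection(EmptyRow).apply(0)) // expected: 3
  }
}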
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
new file mode 100644
index 0000000000..887aabb1d5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+/**
+ * Overrides our expression evaluation tests to use generated code on mutable rows.
+ */
+class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
+ override def checkEvaluation(
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
+ lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
+
+ val plan = try {
+ GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil)
+ } catch {
+ case e: Throwable =>
+ fail(
+ s"""
+ |Code generation of $expression failed:
+ |${evaluated.code.mkString("\n")}
+ |$e
+ """.stripMargin)
+ }
+
+ val actual = plan(inputRow)
+ val expectedRow = new GenericRow(Array[Any](expected))
+ if (actual.hashCode() != expectedRow.hashCode()) {
+ fail(
+ s"""
+ |Mismatched hashCodes for values: $actual, $expectedRow
+ |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
+ |${evaluated.code.mkString("\n")}
+ """.stripMargin)
+ }
+ if (actual != expectedRow) {
+      val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+}
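The hash-code comparison in this suite exists because GenerateProjection emits a row class that implements its own hashCode and equals, and those must agree with GenericRow. A brief sketch, with an arbitrary literal and alias name:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection

object GenerateProjectionSketch {
  def main(args: Array[String]): Unit = {
    // Compile a projection producing a single string column from a literal.
    val project = GenerateProjection(Alias(Literal("hello"), "greeting")() :: Nil)
    val row = project(EmptyRow)
    val expected = new GenericRow(Array[Any]("hello"))
    // The generated row should agree with GenericRow on both value and hash code.
    assert(row(0) == "hello" && row.hashCode == expected.hashCode)
  }
}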
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index 4896f1b955..e2ae0d25db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
- Batch("Combine Limit", FixedPoint(2),
+ Batch("Combine Limit", FixedPoint(10),
CombineLimits) ::
- Batch("Constant Folding", FixedPoint(3),
+ Batch("Constant Folding", FixedPoint(10),
NullPropagation,
ConstantFolding,
BooleanSimplification) :: Nil