diff options
Diffstat (limited to 'sql')
51 files changed, 1871 insertions, 294 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 531bfddbf2..54fa96baa1 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -38,9 +38,18 @@ <dependencies> <dependency> <groupId>org.scala-lang</groupId> + <artifactId>scala-compiler</artifactId> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> <artifactId>scala-reflect</artifactId> </dependency> <dependency> + <groupId>org.scalamacros</groupId> + <artifactId>quasiquotes_${scala.binary.version}</artifactId> + <version>${scala.macros.version}</version> + </dependency> + <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.binary.version}</artifactId> <version>${project.version}</version> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5c8c810d91..f44521d638 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -202,7 +202,7 @@ package object dsl { // Protobuf terminology def required = a.withNullability(false) - def at(ordinal: Int) = BoundReference(ordinal, a) + def at(ordinal: Int) = BoundReference(ordinal, a.dataType, a.nullable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 9ce1f01056..a3ebec8082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.trees + import org.apache.spark.sql.Logging /** @@ -28,61 +30,27 @@ import org.apache.spark.sql.Logging * to be retrieved more efficiently. However, since operations like column pruning can change * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ -case class BoundReference(ordinal: Int, baseReference: Attribute) - extends Attribute with trees.LeafNode[Expression] { +case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) + extends Expression with trees.LeafNode[Expression] { type EvaluatedType = Any - override def nullable = baseReference.nullable - override def dataType = baseReference.dataType - override def exprId = baseReference.exprId - override def qualifiers = baseReference.qualifiers - override def name = baseReference.name + override def references = Set.empty - override def newInstance = BoundReference(ordinal, baseReference.newInstance) - override def withNullability(newNullability: Boolean) = - BoundReference(ordinal, baseReference.withNullability(newNullability)) - override def withQualifiers(newQualifiers: Seq[String]) = - BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) - - override def toString = s"$baseReference:$ordinal" + override def toString = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) } -/** - * Used to denote operators that do their own binding of attributes internally. - */ -trait NoBind { self: trees.TreeNode[_] => } - -class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { - import BindReferences._ - - def apply(plan: TreeNode): TreeNode = { - plan.transform { - case n: NoBind => n.asInstanceOf[TreeNode] - case leafNode if leafNode.children.isEmpty => leafNode - case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => - bindReference(e, unaryNode.children.head.output) - } - } - } -} - object BindReferences extends Logging { def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - // TODO: This fallback is required because some operators (such as ScriptTransform) - // produce new attributes that can't be bound. Likely the right thing to do is remove - // this rule and require all operators to explicitly bind to the input schema that - // they specify. - logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") - a + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } else { - BoundReference(ordinal, a) + BoundReference(ordinal, a.dataType, a.nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 2c71d2c7b3..8fc5896974 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions + /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of the - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. + * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -class Projection(expressions: Seq[Expression]) extends (Row => Row) { +class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) @@ -40,25 +41,25 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { } /** - * Converts a [[Row]] to another Row given a sequence of expression that define each column of th - * new row. If the schema of the input row is specified, then the given expression will be bound to - * that schema. - * - * In contrast to a normal projection, a MutableProjection reuses the same underlying row object - * each time an input row is added. This significantly reduces the cost of calculating the - * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` - * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` - * and hold on to the returned [[Row]] before calling `next()`. + * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified + * expressions. + * @param expressions a sequence of expressions that determine the value of each column of the + * output row. */ -case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { +case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) private[this] val exprArray = expressions.toArray - private[this] val mutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) def currentValue: Row = mutableRow - def apply(input: Row): Row = { + override def target(row: MutableRow): MutableProjection = { + mutableRow = row + this + } + + override def apply(input: Row): Row = { var i = 0 while (i < exprArray.length) { mutableRow(i) = exprArray(i).eval(input) @@ -76,6 +77,12 @@ class JoinedRow extends Row { private[this] var row1: Row = _ private[this] var row2: Row = _ + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ def apply(r1: Row, r2: Row): Row = { row1 = r1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 74ae723686..7470cb861b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -88,15 +88,6 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - - /** - * Experimental - * - * Returns a mutable string builder for the specified column. A given row should return the - * result of any mutations made to the returned buffer next time getString is called for the same - * column. - */ - def getStringBuilder(ordinal: Int): StringBuilder } /** @@ -180,6 +171,35 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { values(i).asInstanceOf[String] } + // Custom hashCode function that matches the efficient code generated version. + override def hashCode(): Int = { + var result: Int = 37 + + var i = 0 + while (i < values.length) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + apply(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } + def copy() = this } @@ -187,8 +207,6 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { /** No-arg constructor for serialization. */ def this() = this(0) - def getStringBuilder(ordinal: Int): StringBuilder = ??? - override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 5e089f7618..acddf5e9c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi override def eval(input: Row): Any = { children.size match { + case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => function.asInstanceOf[(Any, Any) => Any]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala new file mode 100644 index 0000000000..5b398695bf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.google.common.cache.{CacheLoader, CacheBuilder} + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + import scala.tools.reflect.ToolBox + + protected val toolBox = runtimeMirror(getClass.getClassLoader).mkToolBox() + + protected val rowType = typeOf[Row] + protected val mutableRowType = typeOf[MutableRow] + protected val genericRowType = typeOf[GenericRow] + protected val genericMutableRowType = typeOf[GenericMutableRow] + + protected val projectionType = typeOf[Projection] + protected val mutableProjectionType = typeOf[MutableProjection] + + private val curId = new java.util.concurrent.atomic.AtomicInteger() + private val javaSeparator = "$" + + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType + + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint + */ + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = globalLock.synchronized { + create(in) + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ + def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = + apply(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + protected def freshName(prefix: String): TermName = { + newTermName(s"$prefix$javaSeparator${curId.getAndIncrement}") + } + + /** + * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not + * valid if `nullTerm` is set to `false`. + * @param objectTerm A possibly boxed version of the result of evaluating this expression. + */ + protected case class EvaluatedExpression( + code: Seq[Tree], + nullTerm: TermName, + primitiveTerm: TermName, + objectTerm: TermName) + + /** + * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that + * can be used to determine the result of evaluating the expression on an input row. + */ + def expressionEvaluator(e: Expression): EvaluatedExpression = { + val primitiveTerm = freshName("primitiveTerm") + val nullTerm = freshName("nullTerm") + val objectTerm = freshName("objectTerm") + + implicit class Evaluate1(e: Expression) { + def castOrNull(f: TermName => Tree, dataType: DataType): Seq[Tree] = { + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${f(eval.primitiveTerm)} + """.children + } + } + + implicit class Evaluate2(expressions: (Expression, Expression)) { + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function from two primitive term names to a tree that evaluates them. + */ + def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] = + evaluateAs(expressions._1.dataType)(f) + + def evaluateAs(resultType: DataType)(f: (TermName, TermName) => Tree): Seq[Tree] = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (expressions._1.dataType != expressions._2.dataType) { + log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") + } + + val eval1 = expressionEvaluator(expressions._1) + val eval2 = expressionEvaluator(expressions._2) + val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + + eval1.code ++ eval2.code ++ + q""" + val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm} + val $primitiveTerm: ${termForType(resultType)} = + if($nullTerm) { + ${defaultPrimitive(resultType)} + } else { + $resultCode.asInstanceOf[${termForType(resultType)}] + } + """.children : Seq[Tree] + } + } + + val inputTuple = newTermName(s"i") + + // TODO: Skip generation of null handling code when expression are not nullable. + val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = { + case b @ BoundReference(ordinal, dataType, nullable) => + val nullValue = q"$inputTuple.isNullAt($ordinal)" + q""" + val $nullTerm: Boolean = $nullValue + val $primitiveTerm: ${termForType(dataType)} = + if($nullTerm) + ${defaultPrimitive(dataType)} + else + ${getColumn(inputTuple, dataType, ordinal)} + """.children + + case expressions.Literal(null, dataType) => + q""" + val $nullTerm = true + val $primitiveTerm: ${termForType(dataType)} = null.asInstanceOf[${termForType(dataType)}] + """.children + + case expressions.Literal(value: Boolean, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: String, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Int, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case expressions.Literal(value: Long, dataType) => + q""" + val $nullTerm = ${value == null} + val $primitiveTerm: ${termForType(dataType)} = $value + """.children + + case Cast(e @ BinaryType(), StringType) => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + """.children + + case Cast(child @ NumericType(), IntegerType) => + child.castOrNull(c => q"$c.toInt", IntegerType) + + case Cast(child @ NumericType(), LongType) => + child.castOrNull(c => q"$c.toLong", LongType) + + case Cast(child @ NumericType(), DoubleType) => + child.castOrNull(c => q"$c.toDouble", DoubleType) + + case Cast(child @ NumericType(), FloatType) => + child.castOrNull(c => q"$c.toFloat", IntegerType) + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case Cast(e, StringType) if e.dataType != TimestampType => + val eval = expressionEvaluator(e) + eval.code ++ + q""" + val $nullTerm = ${eval.nullTerm} + val $primitiveTerm = + if($nullTerm) + ${defaultPrimitive(StringType)} + else + ${eval.primitiveTerm}.toString + """.children + + case EqualTo(e1, e2) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } + + /* TODO: Fix null semantics. + case In(e1, list) if !list.exists(!_.isInstanceOf[expressions.Literal]) => + val eval = expressionEvaluator(e1) + + val checks = list.map { + case expressions.Literal(v: String, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + case expressions.Literal(v: Int, dataType) => + q"if(${eval.primitiveTerm} == $v) return true" + } + + val funcName = newTermName(s"isIn${curId.getAndIncrement()}") + + q""" + def $funcName: Boolean = { + ..${eval.code} + if(${eval.nullTerm}) return false + ..$checks + return false + } + val $nullTerm = false + val $primitiveTerm = $funcName + """.children + */ + + case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 > $eval2" } + case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 >= $eval2" } + case LessThan(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 < $eval2" } + case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => + (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 <= $eval2" } + + case And(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && !${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && !${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = false + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = true + } + """.children + + case Or(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = false + + if ((!${eval1.nullTerm} && ${eval1.primitiveTerm}) || + (!${eval2.nullTerm} && ${eval2.primitiveTerm})) { + $nullTerm = false + $primitiveTerm = true + } else if (${eval1.nullTerm} || ${eval2.nullTerm} ) { + $nullTerm = true + } else { + $nullTerm = false + $primitiveTerm = false + } + """.children + + case Not(child) => + // Uh, bad function name... + child.castOrNull(c => q"!$c", BooleanType) + + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } + case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } + case Divide(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 / $eval2" } + + case IsNotNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = !${eval.nullTerm} + """.children + + case IsNull(e) => + val eval = expressionEvaluator(e) + q""" + ..${eval.code} + var $nullTerm = false + var $primitiveTerm: ${termForType(BooleanType)} = ${eval.nullTerm} + """.children + + case c @ Coalesce(children) => + q""" + var $nullTerm = true + var $primitiveTerm: ${termForType(c.dataType)} = ${defaultPrimitive(c.dataType)} + """.children ++ + children.map { c => + val eval = expressionEvaluator(c) + q""" + if($nullTerm) { + ..${eval.code} + if(!${eval.nullTerm}) { + $nullTerm = false + $primitiveTerm = ${eval.primitiveTerm} + } + } + """ + } + + case i @ expressions.If(condition, trueValue, falseValue) => + val condEval = expressionEvaluator(condition) + val trueEval = expressionEvaluator(trueValue) + val falseEval = expressionEvaluator(falseValue) + + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(i.dataType)} = ${defaultPrimitive(i.dataType)} + ..${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ..${trueEval.code} + $nullTerm = ${trueEval.nullTerm} + $primitiveTerm = ${trueEval.primitiveTerm} + } else { + ..${falseEval.code} + $nullTerm = ${falseEval.nullTerm} + $primitiveTerm = ${falseEval.primitiveTerm} + } + """.children + } + + // If there was no match in the partial function above, we fall back on calling the interpreted + // expression evaluator. + val code: Seq[Tree] = + primitiveEvaluation.lift.apply(e).getOrElse { + log.debug(s"No rules to generate $e") + val tree = reify { e } + q""" + val $objectTerm = $tree.eval(i) + val $nullTerm = $objectTerm == null + val $primitiveTerm = $objectTerm.asInstanceOf[${termForType(e.dataType)}] + """.children + } + + EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + } + + protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { + dataType match { + case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" + } + } + + protected def setColumn( + destinationRow: TermName, + dataType: DataType, + ordinal: Int, + value: TermName) = { + dataType match { + case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case _ => q"$destinationRow.update($ordinal, $value)" + } + } + + protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") + protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + + protected def primitiveForType(dt: DataType) = dt match { + case IntegerType => "Int" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case StringType => "String" + } + + protected def defaultPrimitive(dt: DataType) = dt match { + case BooleanType => ru.Literal(Constant(false)) + case FloatType => ru.Literal(Constant(-1.0.toFloat)) + case StringType => ru.Literal(Constant("<uninit>")) + case ShortType => ru.Literal(Constant(-1.toShort)) + case LongType => ru.Literal(Constant(1L)) + case ByteType => ru.Literal(Constant(-1.toByte)) + case DoubleType => ru.Literal(Constant(-1.toDouble)) + case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. + case IntegerType => ru.Literal(Constant(-1)) + case _ => ru.Literal(Constant(null)) + } + + protected def termForType(dt: DataType) = dt match { + case n: NativeType => n.tag + case _ => typeTag[Any] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala new file mode 100644 index 0000000000..a419fd7ecb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new + * input [[Row]] for a fixed set of [[Expression Expressions]]. + */ +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + val mutableRowName = newTermName("mutableRow") + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + val projectionCode = expressions.zipWithIndex.flatMap { case (e, i) => + val evaluationCode = expressionEvaluator(e) + + evaluationCode.code :+ + q""" + if(${evaluationCode.nullTerm}) + mutableRow.setNullAt($i) + else + ${setColumn(mutableRowName, e.dataType, i, evaluationCode.primitiveTerm)} + """ + } + + val code = + q""" + () => { new $mutableProjectionType { + + private[this] var $mutableRowName: $mutableRowType = + new $genericMutableRowType(${expressions.size}) + + def target(row: $mutableRowType): $mutableProjectionType = { + $mutableRowName = row + this + } + + /* Provide immutable access to the last projected row. */ + def currentValue: $rowType = mutableRow + + def apply(i: $rowType): $rowType = { + ..$projectionCode + mutableRow + } + } } + """ + + log.debug(s"code for ${expressions.mkString(",")}:\n$code") + toolBox.eval(code).asInstanceOf[() => MutableProjection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala new file mode 100644 index 0000000000..4211998f75 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NumericType} + +/** + * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of + * [[Expression Expressions]]. + */ +object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] with Logging { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = + in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + + protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(ordering: Seq[SortOrder]): Ordering[Row] = { + val a = newTermName("a") + val b = newTermName("b") + val comparisons = ordering.zipWithIndex.map { case (order, i) => + val evalA = expressionEvaluator(order.child) + val evalB = expressionEvaluator(order.child) + + val compare = order.child.dataType match { + case _: NumericType => + q""" + val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm} + if(comp != 0) { + return ${if (order.direction == Ascending) q"comp.toInt" else q"-comp.toInt"} + } + """ + case StringType => + if (order.direction == Ascending) { + q"""return ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm})""" + } else { + q"""return ${evalB.primitiveTerm}.compare(${evalA.primitiveTerm})""" + } + } + + q""" + i = $a + ..${evalA.code} + i = $b + ..${evalB.code} + if (${evalA.nullTerm} && ${evalB.nullTerm}) { + // Nothing + } else if (${evalA.nullTerm}) { + return ${if (order.direction == Ascending) q"-1" else q"1"} + } else if (${evalB.nullTerm}) { + return ${if (order.direction == Ascending) q"1" else q"-1"} + } else { + $compare + } + """ + } + + val q"class $orderingName extends $orderingType { ..$body }" = reify { + class SpecificOrdering extends Ordering[Row] { + val o = ordering + } + }.tree.children.head + + val code = q""" + class $orderingName extends $orderingType { + ..$body + def compare(a: $rowType, b: $rowType): Int = { + var i: $rowType = null // Holds current row being evaluated. + ..$comparisons + return 0 + } + } + new $orderingName() + """ + logger.debug(s"Generated Ordering: $code") + toolBox.eval(code).asInstanceOf[Ordering[Row]] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala new file mode 100644 index 0000000000..2a0935c790 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[Row]]. + */ +object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + + protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = + BindReferences.bindReference(in, inputSchema) + + protected def create(predicate: Expression): ((Row) => Boolean) = { + val cEval = expressionEvaluator(predicate) + + val code = + q""" + (i: $rowType) => { + ..${cEval.code} + if (${cEval.nullTerm}) false else ${cEval.primitiveTerm} + } + """ + + log.debug(s"Generated predicate '$predicate':\n$code") + toolBox.eval(code).asInstanceOf[Row => Boolean] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala new file mode 100644 index 0000000000..77fa02c13d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + + +/** + * Generates bytecode that produces a new [[Row]] object based on a fixed set of input + * [[Expression Expressions]] and a given input [[Row]]. The returned [[Row]] object is custom + * generated based on the output types of the [[Expression]] to avoid boxing of primitive values. + */ +object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { + import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.universe._ + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer(_)) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + // Make Mutablility optional... + protected def create(expressions: Seq[Expression]): Projection = { + val tupleLength = ru.Literal(Constant(expressions.length)) + val lengthDef = q"final val length = $tupleLength" + + /* TODO: Configurable... + val nullFunctions = + q""" + private final val nullSet = new org.apache.spark.util.collection.BitSet(length) + final def setNullAt(i: Int) = nullSet.set(i) + final def isNullAt(i: Int) = nullSet.get(i) + """ + */ + + val nullFunctions = + q""" + private[this] var nullBits = new Array[Boolean](${expressions.size}) + final def setNullAt(i: Int) = { nullBits(i) = true } + final def isNullAt(i: Int) = nullBits(i) + """.children + + val tupleElements = expressions.zipWithIndex.flatMap { + case (e, i) => + val elementName = newTermName(s"c$i") + val evaluatedExpression = expressionEvaluator(e) + val iLit = ru.Literal(Constant(i)) + + q""" + var ${newTermName(s"c$i")}: ${termForType(e.dataType)} = _ + { + ..${evaluatedExpression.code} + if(${evaluatedExpression.nullTerm}) + setNullAt($iLit) + else + $elementName = ${evaluatedExpression.primitiveTerm} + } + """.children : Seq[Tree] + } + + val iteratorFunction = { + val allColumns = (0 until expressions.size).map { i => + val iLit = ru.Literal(Constant(i)) + q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" + } + q"final def iterator = Iterator[Any](..$allColumns)" + } + + val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" + val applyFunction = { + val cases = (0 until expressions.size).map { i => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" + } + q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }" + } + + val updateFunction = { + val cases = expressions.zipWithIndex.map {case (e, i) => + val ordinal = ru.Literal(Constant(i)) + val elementName = newTermName(s"c$i") + val iLit = ru.Literal(Constant(i)) + + q""" + if(i == $ordinal) { + if(value == null) { + setNullAt(i) + } else { + $elementName = value.asInstanceOf[${termForType(e.dataType)}] + return + } + }""" + } + q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + } + + val specificAccessorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) return $elementName" :: Nil + case _ => Nil + } + + q""" + final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } + + val specificMutatorFunctions = NativeType.all.map { dataType => + val ifStatements = expressions.zipWithIndex.flatMap { + case (e, i) if e.dataType == dataType => + val elementName = newTermName(s"c$i") + // TODO: The string of ifs gets pretty inefficient as the row grows in size. + // TODO: Optional null checks? + q"if(i == $i) { $elementName = value; return }" :: Nil + case _ => Nil + } + + q""" + final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + ..$ifStatements; + $accessorFailure + }""" + } + + val hashValues = expressions.zipWithIndex.map { case (e,i) => + val elementName = newTermName(s"c$i") + val nonNull = e.dataType match { + case BooleanType => q"if ($elementName) 0 else 1" + case ByteType | ShortType | IntegerType => q"$elementName.toInt" + case LongType => q"($elementName ^ ($elementName >>> 32)).toInt" + case FloatType => q"java.lang.Float.floatToIntBits($elementName)" + case DoubleType => + q"{ val b = java.lang.Double.doubleToLongBits($elementName); (b ^ (b >>>32)).toInt }" + case _ => q"$elementName.hashCode" + } + q"if (isNullAt($i)) 0 else $nonNull" + } + + val hashUpdates: Seq[Tree] = hashValues.map(v => q"""result = 37 * result + $v""": Tree) + + val hashCodeFunction = + q""" + override def hashCode(): Int = { + var result: Int = 37 + ..$hashUpdates + result + } + """ + + val columnChecks = (0 until expressions.size).map { i => + val elementName = newTermName(s"c$i") + q"if (this.$elementName != specificType.$elementName) return false" + } + + val equalsFunction = + q""" + override def equals(other: Any): Boolean = other match { + case specificType: SpecificRow => + ..$columnChecks + return true + case other => super.equals(other) + } + """ + + val copyFunction = + q""" + final def copy() = new $genericRowType(this.toArray) + """ + + val classBody = + nullFunctions ++ ( + lengthDef +: + iteratorFunction +: + applyFunction +: + updateFunction +: + equalsFunction +: + hashCodeFunction +: + copyFunction +: + (tupleElements ++ specificAccessorFunctions ++ specificMutatorFunctions)) + + val code = q""" + final class SpecificRow(i: $rowType) extends $mutableRowType { + ..$classBody + } + + new $projectionType { def apply(r: $rowType) = new SpecificRow(r) } + """ + + log.debug( + s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${toolBox.typeCheck(code)}") + toolBox.eval(code).asInstanceOf[Projection] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala new file mode 100644 index 0000000000..80c7dfd376 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.rules +import org.apache.spark.sql.catalyst.util + +/** + * A collection of generators that build custom bytecode at runtime for performing the evaluation + * of catalyst expression. + */ +package object codegen { + + /** + * A lock to protect invoking the scala compiler at runtime, since it is not thread safe in Scala + * 2.10. + */ + protected[codegen] val globalLock = org.apache.spark.sql.catalyst.ScalaReflectionLock + + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ + object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { + val batches = + Batch("CleanExpressions", FixedPoint(20), CleanExpressions) :: Nil + + object CleanExpressions extends rules.Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case Alias(c, _) => c + } + } + } + + /** + * :: DeveloperApi :: + * Dumps the bytecode from a class to the screen using javap. + */ + @DeveloperApi + object DumpByteCode { + import scala.sys.process._ + val dumpDirectory = util.getTempFilePath("sparkSqlByteCode") + dumpDirectory.mkdir() + + def apply(obj: Any): Unit = { + val generatedClass = obj.getClass + val classLoader = + generatedClass + .getClassLoader + .asInstanceOf[scala.tools.nsc.interpreter.AbstractFileClassLoader] + val generatedBytes = classLoader.classBytes(generatedClass.getName) + + val packageDir = new java.io.File(dumpDirectory, generatedClass.getPackage.getName) + if (!packageDir.exists()) { packageDir.mkdir() } + + val classFile = + new java.io.File(packageDir, generatedClass.getName.split("\\.").last + ".class") + + val outfile = new java.io.FileOutputStream(classFile) + outfile.write(generatedBytes) + outfile.close() + + println( + s"javap -p -v -classpath ${dumpDirectory.getCanonicalPath} ${generatedClass.getName}".!!) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index b6f2451b52..55d95991c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -47,4 +47,30 @@ package org.apache.spark.sql.catalyst * ==Evaluation== * The result of expressions can be evaluated using the `Expression.apply(Row)` method. */ -package object expressions +package object expressions { + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + */ + abstract class Projection extends (Row => Row) + + /** + * Converts a [[Row]] to another Row given a sequence of expression that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound + * to that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()` + * and hold on to the returned [[Row]] before calling `next()`. + */ + abstract class MutableProjection extends Projection { + def currentValue: Row + + /** Uses the given row to store the output of the projection. */ + def target(row: MutableRow): MutableProjection + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 06b94a98d3..5976b0ddf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType object InterpretedPredicate { + def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + apply(BindReferences.bindReference(expression, inputSchema)) + def apply(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala new file mode 100644 index 0000000000..3b3e206055 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object catalyst { + /** + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. + */ + protected[catalyst] object ScalaReflectionLock +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 026692abe0..418f8686bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -105,6 +105,77 @@ object PhysicalOperation extends PredicateHelper { } /** + * Matches a logical aggregation that can be performed on distributed data in two steps. The first + * operates on the data in each partition performing partial aggregation for each group. The second + * occurs after the shuffle and completes the aggregation. + * + * This pattern will only match if all aggregate expressions can be computed partially and will + * return the rewritten aggregation expressions for both phases. + * + * The returned values for this match are as follows: + * - Grouping attributes for the final aggregation. + * - Aggregates for the final aggregation. + * - Grouping expressions for the partial aggregation. + * - Partial aggregate expressions. + * - Input to the aggregation. + */ +object PartialAggregation { + type ReturnType = + (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. + val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. + val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + + Some( + (namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child)) + } else { + None + } + case _ => None + } +} + + +/** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ac85f95b52..888cb08e95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -112,7 +112,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { self: Product => override lazy val statistics: Statistics = - throw new UnsupportedOperationException("default leaf nodes don't have meaningful Statistics") + throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") // Leaf nodes by definition cannot reference any input attributes. override def references = Set.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index a357c6ffb8..481a5a4f21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -35,7 +35,7 @@ abstract class Command extends LeafNode { */ case class NativeCommand(cmd: String) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("result", StringType, nullable = false)())) + Seq(AttributeReference("result", StringType, nullable = false)()) } /** @@ -43,7 +43,7 @@ case class NativeCommand(cmd: String) extends Command { */ case class SetCommand(key: Option[String], value: Option[String]) extends Command { override def output = Seq( - BoundReference(1, AttributeReference("", StringType, nullable = false)())) + AttributeReference("", StringType, nullable = false)()) } /** @@ -52,7 +52,7 @@ case class SetCommand(key: Option[String], value: Option[String]) extends Comman */ case class ExplainCommand(plan: LogicalPlan) extends Command { override def output = - Seq(BoundReference(0, AttributeReference("plan", StringType, nullable = false)())) + Seq(AttributeReference("plan", StringType, nullable = false)()) } /** @@ -71,7 +71,7 @@ case class DescribeCommand( isExtended: Boolean) extends Command { override def output = Seq( // Column names are based on Hive. - BoundReference(0, AttributeReference("col_name", StringType, nullable = false)()), - BoundReference(1, AttributeReference("data_type", StringType, nullable = false)()), - BoundReference(2, AttributeReference("comment", StringType, nullable = false)())) + AttributeReference("col_name", StringType, nullable = false)(), + AttributeReference("data_type", StringType, nullable = false)(), + AttributeReference("comment", StringType, nullable = false)()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e32adb76fe..e300bdbece 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -72,7 +72,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { - logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}") + // Only log if this is a rule that is supposed to run more than once. + if (iteration != 2) { + logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + } continue = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cd4b5e9c1b..71808f76d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -23,16 +23,13 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers +import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * A JVM-global lock that should be used to prevent thread safety issues when using things in - * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for - * 2.10.* builds. See SI-6240 for more details. + * Utility functions for working with DataTypes. */ -protected[catalyst] object ScalaReflectionLock - object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | @@ -99,6 +96,13 @@ abstract class DataType { case object NullType extends DataType +object NativeType { + def all = Seq( + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + def unapply(dt: DataType): Boolean = all.contains(dt) +} + trait PrimitiveType extends DataType { override def isPrimitive = true } @@ -149,6 +153,10 @@ abstract class NumericType extends NativeType with PrimitiveType { val numeric: Numeric[JvmType] } +object NumericType { + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] +} + /** Matcher for any expressions that evaluate to [[IntegralType]]s */ object IntegralType { def unapply(a: Expression): Boolean = a match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 58f8c341e6..999c9fff38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -29,7 +29,11 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpressionEvaluationSuite extends FunSuite { test("literals") { - assert((Literal(1) + Literal(1)).eval(null) === 2) + checkEvaluation(Literal(1), 1) + checkEvaluation(Literal(true), true) + checkEvaluation(Literal(0L), 0L) + checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal(1) + Literal(1), 2) } /** @@ -61,10 +65,8 @@ class ExpressionEvaluationSuite extends FunSuite { test("3VL Not") { notTrueTable.foreach { case (v, answer) => - val expr = ! Literal(v, BooleanType) - val result = expr.eval(null) - if (result != answer) - fail(s"$expr should not evaluate to $result, expected: $answer") } + checkEvaluation(!Literal(v, BooleanType), answer) + } } booleanLogicTest("AND", _ && _, @@ -127,6 +129,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + test("IN") { + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) @@ -232,21 +241,21 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) - checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) - checkEvaluation("23" cast FloatType, 23) - checkEvaluation("23" cast DecimalType, 23) - checkEvaluation("23" cast ByteType, 23) - checkEvaluation("23" cast ShortType, 23) + checkEvaluation("23" cast FloatType, 23f) + checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast ByteType, 23.toByte) + checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) - checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24) + checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) - checkEvaluation(Literal(23f) + Cast(true, FloatType), 24) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24) - checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24) - checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24) + checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) + checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) + checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} @@ -391,21 +400,21 @@ class ExpressionEvaluationSuite extends FunSuite { val typeMap = MapType(StringType, StringType) val typeArray = ArrayType(StringType) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal("aa")), "bb", row) checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), + checkEvaluation(GetItem(BoundReference(3, typeMap, true), Literal(null, StringType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(1)), "bb", row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) - checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), + checkEvaluation(GetItem(BoundReference(4, typeArray, true), Literal(null, IntegerType)), null, row) - checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) + checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) val typeS_notNullable = StructType( @@ -413,10 +422,8 @@ class ExpressionEvaluationSuite extends FunSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS)()), "a").nullable === true) - assert(GetField(BoundReference(2, - AttributeReference("c", typeS_notNullable, nullable = false)()), "a").nullable === false) + assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) assert(GetField(Literal(null, typeS), "a").nullable === true) assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala new file mode 100644 index 0000000000..245a2e1480 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use code generation for evaluation. + */ +class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = try { + GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val evaluated = GenerateProjection.expressionEvaluator(expression) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + + test("multithreaded eval") { + import scala.concurrent._ + import ExecutionContext.Implicits.global + import scala.concurrent.duration._ + + val futures = (1 to 20).map { _ => + future { + GeneratePredicate(EqualTo(Literal(1), Literal(1))) + GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + } + } + + futures.foreach(Await.result(_, 10.seconds)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala new file mode 100644 index 0000000000..887aabb1d5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. + */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 4896f1b955..e2ae0d25db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Combine Limit", FixedPoint(2), + Batch("Combine Limit", FixedPoint(10), CombineLimits) :: - Batch("Constant Folding", FixedPoint(3), + Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, BooleanSimplification) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5d85a0fd4e..2d407077be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -24,8 +24,11 @@ import scala.collection.JavaConverters._ object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + val CODEGEN_ENABLED = "spark.sql.codegen" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -57,6 +60,18 @@ trait SQLConf { private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt /** + * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * that evaluates expressions found in queries. In general this custom code runs much faster + * than interpreted evaluation, but there are significant start-up costs due to compilation. + * As a result codegen is only benificial when queries run for a long time, or when the same + * expressions are used multiple times. + * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def codegenEnabled: Boolean = + if (get(CODEGEN_ENABLED, "false") == "true") true else false + + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. @@ -111,5 +126,5 @@ trait SQLConf { private[spark] def clear() { settings.clear() } - } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c2bdef7323..e4b6810180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. @@ -160,7 +160,8 @@ class SQLContext(@transient val sparkContext: SparkContext) conf: Configuration = new Configuration()): SchemaRDD = { new SchemaRDD( this, - ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf)) + ParquetRelation.createEmpty( + path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) } /** @@ -228,12 +229,14 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self + def codegenEnabled = self.codegenEnabled + def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: - PartialAggregation :: + HashAggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: @@ -291,27 +294,30 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** - * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and - * inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. */ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: - Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + Batch("Add exchange", Once, AddExchange(self)) :: Nil } /** + * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. */ + @DeveloperApi protected abstract class QueryExecution { def logical: LogicalPlan lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... - lazy val sparkPlan = planner(optimizedPlan).next() + lazy val sparkPlan = { + SparkPlan.currentContext.set(self) + planner(optimizedPlan).next() + } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) @@ -331,6 +337,9 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} + |Code Generation: ${executedPlan.codegenEnabled} + |== RDD == + |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 806097c917..85726bae54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -72,7 +72,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { conf: Configuration = new Configuration()): JavaSchemaRDD = { new JavaSchemaRDD( sqlContext, - ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf)) + ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext)) } /** @@ -101,7 +101,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD( sqlContext, - ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext)) /** * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c1ced8bfa4..463a1d32d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -42,8 +42,8 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + child: SparkPlan) + extends UnaryNode { override def requiredChildDistribution = if (partial) { @@ -56,8 +56,6 @@ case class Aggregate( } } - override def otherCopyArgs = sqlContext :: Nil - // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. private[this] val childOutput = child.output @@ -138,7 +136,7 @@ case class Aggregate( i += 1 } } - val resultProjection = new Projection(resultExpressions, computedSchema) + val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) val aggregateResults = new GenericMutableRow(computedAggregates.length) var i = 0 @@ -152,7 +150,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) var currentRow: Row = null while (iter.hasNext) { @@ -175,7 +173,8 @@ case class Aggregate( private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + new InterpretedMutableProjection( + resultExpressions, computedSchema ++ namedGroups.map(_._2)) private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef6e7..392a7f3be3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning @@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions, child.output) + @transient val hashExpressions = + newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 47b3d00262..c386fd121c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -47,23 +47,26 @@ case class Generate( } } - override def output = + // This must be a val since the generator output expr ids are not preserved by serialization. + override val output = if (join) child.output ++ generatorOutput else generatorOutput + val boundGenerator = BindReferences.bindReference(generator, child.output) + override def execute() = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = - new Projection(child.output ++ nullValues, child.output) + newProjection(child.output ++ nullValues, child.output) val joinProjection = - new Projection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator.eval(row) + val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -72,7 +75,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) + child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala new file mode 100644 index 0000000000..4a26934c49 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ + +case class AggregateEvaluation( + schema: Seq[Attribute], + initialValues: Seq[Expression], + update: Seq[Expression], + result: Expression) + +/** + * :: DeveloperApi :: + * Alternate version of aggregation that leverages projection and thus code generation. + * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto + * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +@DeveloperApi +case class GeneratedAggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def output = aggregateExpressions.map(_.toAttribute) + + override def execute() = { + val aggregatesToCompute = aggregateExpressions.flatMap { a => + a.collect { case agg: AggregateExpression => agg} + } + + val computeFunctions = aggregatesToCompute.map { + case c @ Count(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val initialValue = Literal(0L) + val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val result = currentCount + + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case Sum(expr) => + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialValue = Cast(Literal(0L), expr.dataType) + + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + val result = currentSum + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case a @ Average(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialCount = Literal(0L) + val initialSum = Cast(Literal(0L), expr.dataType) + val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + + val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + + AggregateEvaluation( + currentCount :: currentSum :: Nil, + initialCount :: initialSum :: Nil, + updateCount :: updateSum :: Nil, + result + ) + } + + val computationSchema = computeFunctions.flatMap(_.schema) + + val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => agg.id -> func.result + }.toMap + + val namedGroups = groupingExpressions.zipWithIndex.map { + case (ne: NamedExpression, _) => (ne, ne) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + } + + val groupMap: Map[Expression, Attribute] = + namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + + // The set of expressions that produce the final output given the aggregation buffer and the + // grouping expressions. + val resultExpressions = aggregateExpressions.map(_.transform { + case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if groupMap.contains(e) => groupMap(e) + }) + + child.execute().mapPartitions { iter => + // Builds a new custom class for holding the results of aggregation for a group. + val initialValues = computeFunctions.flatMap(_.initialValues) + val newAggregationBuffer = newProjection(initialValues, child.output) + log.info(s"Initial values: ${initialValues.mkString(",")}") + + // A projection that computes the group given an input tuple. + val groupProjection = newProjection(groupingExpressions, child.output) + log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") + + // A projection that is used to update the aggregate values for a group given a new tuple. + // This projection should be targeted at the current values for the group and then applied + // to a joined row of the current values with the new input row. + val updateExpressions = computeFunctions.flatMap(_.update) + val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output + val updateProjection = newMutableProjection(updateExpressions, updateSchema)() + log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") + + // A projection that produces the final result, given a computation. + val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + log.info(s"Result Projection: ${resultExpressions.mkString(",")}") + + val joinedRow = new JoinedRow + + if (groupingExpressions.isEmpty) { + // TODO: Codegening anything other than the updateProjection is probably over kill. + val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + var currentRow: Row = null + updateProjection.target(buffer) + + while (iter.hasNext) { + currentRow = iter.next() + updateProjection(joinedRow(buffer, currentRow)) + } + + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(buffer)) + } else { + val buffers = new java.util.HashMap[Row, MutableRow]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupProjection(currentRow) + var currentBuffer = buffers.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + buffers.put(currentGroup, currentBuffer) + } + // Target the projection at the current aggregation buffer and then project the updated + // values. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext = resultIterator.hasNext + + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 77c874d031..21cbbc9772 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,22 +18,55 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Logging, Row, SQLContext} + + +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ + +object SparkPlan { + protected[sql] val currentContext = new ThreadLocal[SQLContext]() +} + /** * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { self: Product => + /** + * A handle to the SQL Context that was used to create this plan. Since many operators need + * access to the sqlContext for RDD operations or configuration this field is automatically + * populated by the query planning infrastructure. + */ + @transient + protected val sqlContext = SparkPlan.currentContext.get() + + protected def sparkContext = sqlContext.sparkContext + + // sqlContext will be null when we are being deserialized on the slaves. In this instance + // the value of codegenEnabled will be set by the desserializer after the constructor has run. + val codegenEnabled: Boolean = if (sqlContext != null) { + sqlContext.codegenEnabled + } else { + false + } + + /** Overridden make copy also propogates sqlContext to copied plan. */ + override def makeCopy(newArgs: Array[AnyRef]): this.type = { + SparkPlan.currentContext.set(sqlContext) + super.makeCopy(newArgs) + } + // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! @@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - protected def buildRow(values: Seq[Any]): Row = - new GenericRow(values.toArray) + protected def newProjection( + expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + GenerateProjection(expressions, inputSchema) + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if(codegenEnabled) { + GenerateMutableProjection(expressions, inputSchema) + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + + + protected def newPredicate( + expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + if (codegenEnabled) { + GeneratePredicate(expression, inputSchema) + } else { + InterpretedPredicate(expression, inputSchema) + } + } + + protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + if (codegenEnabled) { + GenerateOrdering(order, inputSchema) + } else { + new RowOrdering(order, inputSchema) + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 404d48ae05..5f1fe99f75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.util.Try - import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -41,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sqlContext) :: Nil + planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -60,6 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { + private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -68,24 +67,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition: Option[Expression], side: BuildSide) = { val broadcastHashJoin = execution.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if Try(sqlContext.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if Try(sqlContext.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + if sqlContext.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = - if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) { + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { BuildRight } else { BuildLeft @@ -99,65 +98,65 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object PartialAggregation extends Strategy { + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + // Aggregations that can be performed in two phases, before and after the shuffle. - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq - - // Construct two phased aggregation. - execution.Aggregate( + // Cases where all aggregates can be codegened. + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) + if canBeCodeGened( + allAggregates(partialComputation) ++ + allAggregates(rewrittenAggregateExpressions)) && + codegenEnabled => + execution.GeneratedAggregate( partial = false, - namedGroupingExpressions.values.map(_.toAttribute).toSeq, + namedGroupingAttributes, rewrittenAggregateExpressions, - execution.Aggregate( + execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil - } else { - Nil - } + planLater(child))) :: Nil + + // Cases where some aggregate can not be codegened + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) => + execution.Aggregate( + partial = false, + namedGroupingAttributes, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))) :: Nil + case _ => Nil } + + def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + case _: Sum | _: Count => false + case _ => true + } + + def allAggregates(exprs: Seq[Expression]) = + exprs.flatMap(_.collect { case a: AggregateExpression => a }) } object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil + planLater(left), planLater(right), joinType, condition) :: Nil case _ => Nil } } @@ -176,16 +175,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - def convertToCatalyst(a: Any): Any = a match { - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } } @@ -195,11 +188,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // TODO: need to support writing to other types of files. Unify the below code paths. case logical.WriteToFile(path, child) => val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -228,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil + ParquetTableScan(_, relation, filters)) :: Nil case _ => Nil } @@ -266,20 +259,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => - val dataAsRdd = - sparkContext.parallelize(data.map(r => - new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) - execution.ExistingRdd(output, dataAsRdd) :: Nil + ExistingRdd( + output, + ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sqlContext) :: Nil + execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil - case logical.Except(left,right) => - execution.Except(planLater(left),planLater(right)) :: Nil + execution.Union(unionChildren.map(planLater)) :: Nil + case logical.Except(left, right) => + execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 966d8f95fc..174eda8f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output = projectList.map(_.toAttribute) - override def execute() = child.execute().mapPartitions { iter => - @transient val reusableProjection = new MutableProjection(projectList) - iter.map(reusableProjection) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + + def execute() = child.execute().mapPartitions { iter => + val resuableProjection = buildProjection() + iter.map(resuableProjection) } } @@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output = child.output - override def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.eval(_).asInstanceOf[Boolean]) + @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + + def execute() = child.execute().mapPartitions { iter => + iter.filter(conditionEvaluator) } } @@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - - override def otherCopyArgs = sqlContext :: Nil + override def execute() = sparkContext.union(children.map(_.execute())) } /** @@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) +case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sqlContext :: Nil - override def output = child.output /** @@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. */ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sqlContext: SQLContext) extends UnaryNode { - override def otherCopyArgs = sqlContext :: Nil +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output - @transient - lazy val ordering = new RowOrdering(sortOrder) + val ordering = new RowOrdering(sortOrder, child.output) + // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(executeCollect(), 1) } /** @@ -189,15 +187,13 @@ case class Sort( override def requiredChildDistribution = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - @transient - lazy val ordering = new RowOrdering(sortOrder) override def execute() = attachTree(this, "sort") { - // TODO: Optimize sorting operation? child.execute() - .mapPartitions( - iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, - preservesPartitioning = true) + .mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) } override def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c6fbd6d2f6..5ef46c32d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -41,13 +41,13 @@ package object debug { */ @DeveloperApi implicit class DebugQuery(query: SchemaRDD) { - def debug(implicit sc: SparkContext): Unit = { + def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[Long]() val debugPlan = plan transform { case s: SparkPlan if !visited.contains(s.id) => visited += s.id - DebugNode(sc, s) + DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { @@ -57,9 +57,7 @@ package object debug { } } - private[sql] case class DebugNode( - @transient sparkContext: SparkContext, - child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { def references = Set.empty def output = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 7d1f11caae..2750ddbce8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide case object BuildRight extends BuildSide trait HashJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide @@ -56,9 +58,9 @@ trait HashJoin { def output = left.output ++ right.output - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + newMutableProjection(streamedKeys, streamedPlan.output) def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. @@ -217,9 +219,8 @@ case class BroadcastHashJoin( rightKeys: Seq[Expression], buildSide: BuildSide, left: SparkPlan, - right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + right: SparkPlan) extends BinaryNode with HashJoin { - override def otherCopyArgs = sqlContext :: Nil override def outputPartitioning: Partitioning = left.outputPartitioning @@ -228,7 +229,7 @@ case class BroadcastHashJoin( @transient lazy val broadcastFuture = future { - sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + sparkContext.broadcast(buildPlan.executeCollect()) } def execute() = { @@ -248,14 +249,11 @@ case class BroadcastHashJoin( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - def output = left.output /** The Streamed Relation */ @@ -271,7 +269,7 @@ case class LeftSemiJoinBNL( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -300,8 +298,14 @@ case class LeftSemiJoinBNL( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { - case (l: Row, r: Row) => buildRow(l ++ r) + def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } } } @@ -311,14 +315,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod @DeveloperApi case class BroadcastNestedLoopJoin( streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - override def output = { joinType match { case LeftOuter => @@ -345,13 +346,14 @@ case class BroadcastNestedLoopJoin( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 @@ -361,7 +363,7 @@ case class BroadcastNestedLoopJoin( // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += buildRow(streamedRow ++ broadcastedRow) + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() matched = true includedBroadcastTuples += i } @@ -369,7 +371,7 @@ case class BroadcastNestedLoopJoin( } if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + matchedRows += joinedRow(streamedRow, rightNulls).copy() } } Iterator((matchedRows, includedBroadcastTuples)) @@ -383,20 +385,20 @@ case class BroadcastNestedLoopJoin( streamedPlusMatches.map(_._2).reduce(_ ++ _) } + val leftNulls = new GenericMutableRow(left.output.size) val rightOuterMatches: Seq[Row] = if (joinType == RightOuter || joinType == FullOuter) { broadcastedRelation.value.zipWithIndex.filter { case (row, i) => !allIncludedBroadcastTuples.contains(i) }.map { - // TODO: Use projection. - case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + case (row, _) => new JoinedRow(leftNulls, row) } } else { Vector() } // TODO: Breaks lineage. - sqlContext.sparkContext.union( - streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) + sparkContext.union( + streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 8c7dbd5eb4..b3bae5db0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} */ private[sql] case class ParquetRelation( path: String, - @transient conf: Option[Configuration] = None) + @transient conf: Option[Configuration], + @transient sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation { self: Product => @@ -61,7 +62,7 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) - override def newInstance = ParquetRelation(path).asInstanceOf[this.type] + override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, @@ -70,6 +71,9 @@ private[sql] case class ParquetRelation( p.path == path && p.output == output case _ => false } + + // TODO: Use data from the footers. + override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes) } private[sql] object ParquetRelation { @@ -106,13 +110,14 @@ private[sql] object ParquetRelation { */ def create(pathString: String, child: LogicalPlan, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { if (!child.resolved) { throw new UnresolvedException[LogicalPlan]( child, "Attempt to create Parquet table from unresolved child (when schema is not available)") } - createEmpty(pathString, child.output, false, conf) + createEmpty(pathString, child.output, false, conf, sqlContext) } /** @@ -127,14 +132,15 @@ private[sql] object ParquetRelation { def createEmpty(pathString: String, attributes: Seq[Attribute], allowExisting: Boolean, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString, Some(conf)) { + new ParquetRelation(path.toString, Some(conf), sqlContext) { override val output = attributes } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ea74320d06..912a9f002b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -55,8 +55,7 @@ case class ParquetTableScan( // https://issues.apache.org/jira/browse/SPARK-1367 output: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Seq[Expression])( - @transient val sqlContext: SQLContext) + columnPruningPred: Seq[Expression]) extends LeafNode { override def execute(): RDD[Row] = { @@ -99,8 +98,6 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sqlContext :: Nil - /** * Applies a (candidate) projection. * @@ -110,7 +107,7 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) + ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") this @@ -150,8 +147,7 @@ case class ParquetTableScan( case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, - overwrite: Boolean = false)( - @transient val sqlContext: SQLContext) + overwrite: Boolean = false) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -171,7 +167,7 @@ case class InsertIntoParquetTable( val writeSupport = if (child.output.map(_.dataType).forall(_.isPrimitive)) { - logger.debug("Initializing MutableRowWriteSupport") + log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { classOf[org.apache.spark.sql.parquet.RowWriteSupport] @@ -203,8 +199,6 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sqlContext :: Nil - /** * Stores the given Row RDD as a Hadoop file. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d4599da711..837ea7695d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup @@ -103,7 +104,7 @@ private[sql] object ParquetTestData { val testDir = Utils.createTempDir() val testFilterDir = Utils.createTempDir() - lazy val testData = new ParquetRelation(testDir.toURI.toString) + lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) val testNestedSchema1 = // based on blogpost example, source: @@ -202,8 +203,10 @@ private[sql] object ParquetTestData { val testNestedDir3 = Utils.createTempDir() val testNestedDir4 = Utils.createTempDir() - lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) - lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + lazy val testNestedData1 = + new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) + lazy val testNestedData2 = + new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) def writeFile() = { testDir.delete() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e1e1971d9..1fd8d27b34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -45,6 +45,7 @@ class QueryTest extends PlanTest { |${rdd.queryExecution} |== Exception == |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 215618e852..76b1724471 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed - val planned = PartialAggregation(query).head - val aggregations = planned.collect { case a: Aggregate => a } + val planned = HashAggregation(query).head + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } assert(aggregations.size === 2) } test("count distinct is not partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } test("mixed aggregates are not partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index e55648b8ed..2cab5e0c44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._ * Note: this is only a rough example of how TGFs can be expressed, the final version will likely * involve a lot more sugar for cleaner use in Scala/Java/etc. */ -case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { +case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { def children = input protected def makeOutput() = 'nameAndAge.string :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3c911e9a4e..561f5b4a49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job + import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -207,10 +209,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { + SparkPlan.currentContext.set(TestSQLContext) val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - Seq())(TestSQLContext) + Seq()) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 84d43eaeea..f0a61270da 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,7 +231,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveTableScans, DataSinks, Scripts, - PartialAggregation, + HashAggregation, LeftSemiJoin, HashJoin, BasicOperators, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c2b0b00aa5..39033bdeac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -131,7 +131,7 @@ case class InsertIntoHiveTable( conf, SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) writer.preSetup() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8258ee5fef..0c8f676e9c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new Projection(input) + val outputProjection = new InterpretedProjection(input, child.output) iter .map(outputProjection) // TODO: Use SerDe diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 057eb60a02..7582b4743d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val function: GenericUDTF = createFunction() + @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + @transient protected lazy val outputInspectors = { val structInspector = function.initialize(inputInspectors.toArray) structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) @@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf( override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. - val inputProjection = new Projection(children) + val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) @@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction( override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new Projection(exprs) + val inputProjection = new InterpretedProjection(exprs) def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d new file mode 100644 index 0000000000..00750edc07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 new file mode 100644 index 0000000000..00750edc07 --- /dev/null +++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 @@ -0,0 +1 @@ +3 diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 new file mode 100644 index 0000000000..d00491fd7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index aadfd2e900..89cc589fb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} @@ -30,6 +32,15 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("single case", + """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") + + createQueryTest("double case", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""") + + createQueryTest("case else null", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""") + createQueryTest("having no references", "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1") |