author    Michael Armbrust <michael@databricks.com>  2014-08-23 16:19:10 -0700
committer Michael Armbrust <michael@databricks.com>  2014-08-23 16:19:10 -0700
commit    7e191fe29bb09a8560cd75d453c4f7f662dff406 (patch)
tree      38c9db34f2ff9dfd3042cacc3ca6fe351c576ece /sql/catalyst
parent    2fb1c72ea21e137c8b60a72e5aecd554c71b16e1 (diff)
[SPARK-2554][SQL] CountDistinct partial aggregation and object allocation improvements
Author: Michael Armbrust <michael@databricks.com>
Author: Gregory Owen <greowen@gmail.com>

Closes #1935 from marmbrus/countDistinctPartial and squashes the following commits:

5c7848d [Michael Armbrust] turn off caching in the constructor
8074a80 [Michael Armbrust] fix tests
32d216f [Michael Armbrust] reynolds comments
c122cca [Michael Armbrust] Address comments, add tests
b2e8ef3 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
fae38f4 [Michael Armbrust] Fix style
fdca896 [Michael Armbrust] cleanup
93d0f64 [Michael Armbrust] metastore concurrency fix.
db44a30 [Michael Armbrust] JIT hax.
3868f6c [Michael Armbrust] Merge pull request #9 from GregOwen/countDistinctPartial
c9e67de [Gregory Owen] Made SpecificRow and types serializable by Kryo
2b46c4b [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
8ff6402 [Michael Armbrust] Add specific row.
58d15f1 [Michael Armbrust] disable codegen logging
87d101d [Michael Armbrust] Fix isNullAt bug
abee26d [Michael Armbrust] WIP
27984d0 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
57ae3b1 [Michael Armbrust] Fix order dependent test
b3d0f64 [Michael Armbrust] Add golden files.
c1f7114 [Michael Armbrust] Improve tests / fix serialization.
f31b8ad [Michael Armbrust] more fixes
38c7449 [Michael Armbrust] comments and style
9153652 [Michael Armbrust] better toString
d494598 [Michael Armbrust] Fix tests now that the planner is better
41fbd1d [Michael Armbrust] Never try and create an empty hash set.
050bb97 [Michael Armbrust] Skip no-arg constructors for kryo,
bd08239 [Michael Armbrust] WIP
213ada8 [Michael Armbrust] First draft of partially aggregated and code generated count distinct / max
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala                 | 344
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala                        |   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala                | 307
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala                 |  93
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala                 |  31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala      |  93
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala |   9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala                       | 129
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala                         |   3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala  |  10
10 files changed, 1006 insertions(+), 15 deletions(-)
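
Before the per-file diffs, the shape of the change in one picture: COUNT(DISTINCT ...) becomes partially aggregatable by building a hash set of distinct value tuples on each partition (CollectHashSet) and then merging those sets and taking the merged size on the reduce side (CombineSetsAndCount). A minimal sketch of that two-phase shape in plain Scala collections (not the Catalyst classes themselves):

    // Two input partitions; the distinct values across both are {1, 2, 3}.
    val partitions: Seq[Seq[Any]] = Seq(Seq(1, 2, 2), Seq(2, 3))
    val partialSets = partitions.map(_.toSet)              // map side: CollectHashSet
    val countDistinct = partialSets.reduce(_ union _).size // reduce side: CombineSetsAndCount
    assert(countDistinct == 3)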
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 8fc5896974..ef1d12531f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -27,7 +27,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
- protected val exprArray = expressions.toArray
+ // The null check is required because Kryo invokes the no-arg constructor during deserialization.
+ protected val exprArray = if (expressions != null) expressions.toArray else null
def apply(input: Row): Row = {
val outputArray = new Array[Any](exprArray.length)
@@ -109,7 +110,346 @@ class JoinedRow extends Row {
def apply(i: Int) =
if (i < row1.size) row1(i) else row2(i - row1.size)
- def isNullAt(i: Int) = apply(i) == null
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ * The `JoinedRow` class is used in many performance-critical situations. Unfortunately, since
+ * there are multiple different types of `Row` that could be stored as `row1` and `row2`, most of
+ * the calls in the critical path are polymorphic. By creating special versions of this class that
+ * are each used in only a single location of the code, we increase the chance that only a single
+ * type of Row will be referenced, increasing the opportunity for the JIT to play tricks. This
+ * sounds crazy, but in benchmarks it had noticeable effects.
+ */
+class JoinedRow2 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow3 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow4 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow5 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
def getInt(i: Int): Int =
if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
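
The duplicated JoinedRow2 through JoinedRow5 classes above trade source-level repetition for JIT-friendly dispatch: each copy is used at exactly one place in the physical operators, so each call site tends to see a single concrete Row implementation and stays monomorphic. A toy illustration of the idea with hypothetical types (not Spark code):

    trait R { def get(i: Int): Any }
    // Same body, but each class is instantiated at only one call site, so the
    // JIT profiles one receiver type per site and can inline aggressively.
    final class JoinedA(r1: R, r2: R, split: Int) extends R {
      def get(i: Int) = if (i < split) r1.get(i) else r2.get(i - split)
    }
    final class JoinedB(r1: R, r2: R, split: Int) extends R {
      def get(i: Int) = if (i < split) r1.get(i) else r2.get(i - split)
    }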
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index c9a63e201e..d68a4fabea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -127,7 +127,7 @@ object EmptyRow extends Row {
* the array is not copied, and thus could technically be mutated after creation, this is not
* allowed.
*/
-class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
+class GenericRow(protected[sql] val values: Array[Any]) extends Row {
/** No-arg constructor for serialization. */
def this() = this(null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
new file mode 100644
index 0000000000..75ea0e8459
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * A parent class for mutable container objects that are reused when the values are changed,
+ * resulting in less garbage. These values are held by a [[SpecificMutableRow]].
+ *
+ * The following code was roughly used to generate these objects:
+ * {{{
+ * val types = "Int,Float,Boolean,Double,Short,Long,Byte,Any".split(",")
+ * types.map {tpe =>
+ * s"""
+ * final class Mutable$tpe extends MutableValue {
+ * var value: $tpe = 0
+ * def boxed = if (isNull) null else value
+ * def update(v: Any) = value = {
+ * isNull = false
+ * v.asInstanceOf[$tpe]
+ * }
+ * def copy() = {
+ * val newCopy = new Mutable$tpe
+ * newCopy.isNull = isNull
+ * newCopy.value = value
+ * newCopy.asInstanceOf[this.type]
+ * }
+ * }"""
+ * }.foreach(println)
+ *
+ * types.map { tpe =>
+ * s"""
+ * override def set$tpe(ordinal: Int, value: $tpe): Unit = {
+ * val currentValue = values(ordinal).asInstanceOf[Mutable$tpe]
+ * currentValue.isNull = false
+ * currentValue.value = value
+ * }
+ *
+ * override def get$tpe(i: Int): $tpe = {
+ * values(i).asInstanceOf[Mutable$tpe].value
+ * }"""
+ * }.foreach(println)
+ * }}}
+ */
+abstract class MutableValue extends Serializable {
+ var isNull: Boolean = true
+ def boxed: Any
+ def update(v: Any)
+ def copy(): this.type
+}
+
+final class MutableInt extends MutableValue {
+ var value: Int = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Int]
+ }
+ def copy() = {
+ val newCopy = new MutableInt
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableFloat extends MutableValue {
+ var value: Float = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Float]
+ }
+ def copy() = {
+ val newCopy = new MutableFloat
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableBoolean extends MutableValue {
+ var value: Boolean = false
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Boolean]
+ }
+ def copy() = {
+ val newCopy = new MutableBoolean
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableDouble extends MutableValue {
+ var value: Double = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Double]
+ }
+ def copy() = {
+ val newCopy = new MutableDouble
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableShort extends MutableValue {
+ var value: Short = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Short]
+ }
+ def copy() = {
+ val newCopy = new MutableShort
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableLong extends MutableValue {
+ var value: Long = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Long]
+ }
+ def copy() = {
+ val newCopy = new MutableLong
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableByte extends MutableValue {
+ var value: Byte = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Byte]
+ }
+ def copy() = {
+ val newCopy = new MutableByte
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableAny extends MutableValue {
+ var value: Any = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Any]
+ }
+ def copy() = {
+ val newCopy = new MutableAny
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+/**
+ * A row type that holds an array of specialized container objects, of type [[MutableValue]], chosen
+ * based on the dataTypes of each column. The intent is to decrease garbage when modifying the
+ * values of primitive columns.
+ */
+final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
+
+ def this(dataTypes: Seq[DataType]) =
+ this(
+ dataTypes.map {
+ case IntegerType => new MutableInt
+ case ByteType => new MutableByte
+ case FloatType => new MutableFloat
+ case ShortType => new MutableShort
+ case DoubleType => new MutableDouble
+ case BooleanType => new MutableBoolean
+ case LongType => new MutableLong
+ case _ => new MutableAny
+ }.toArray)
+
+ def this() = this(Seq.empty)
+
+ override def length: Int = values.length
+
+ override def setNullAt(i: Int): Unit = {
+ values(i).isNull = true
+ }
+
+ override def apply(i: Int): Any = values(i).boxed
+
+ override def isNullAt(i: Int): Boolean = values(i).isNull
+
+ override def copy(): Row = {
+ val newValues = new Array[MutableValue](values.length)
+ var i = 0
+ while (i < values.length) {
+ newValues(i) = values(i).copy()
+ i += 1
+ }
+ new SpecificMutableRow(newValues)
+ }
+
+ override def update(ordinal: Int, value: Any): Unit = values(ordinal).update(value)
+
+ override def iterator: Iterator[Any] = values.map(_.boxed).iterator
+
+ def setString(ordinal: Int, value: String) = update(ordinal, value)
+
+ def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+
+ override def setInt(ordinal: Int, value: Int): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableInt]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getInt(i: Int): Int = {
+ values(i).asInstanceOf[MutableInt].value
+ }
+
+ override def setFloat(ordinal: Int, value: Float): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableFloat]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getFloat(i: Int): Float = {
+ values(i).asInstanceOf[MutableFloat].value
+ }
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableBoolean]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getBoolean(i: Int): Boolean = {
+ values(i).asInstanceOf[MutableBoolean].value
+ }
+
+ override def setDouble(ordinal: Int, value: Double): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableDouble]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getDouble(i: Int): Double = {
+ values(i).asInstanceOf[MutableDouble].value
+ }
+
+ override def setShort(ordinal: Int, value: Short): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableShort]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getShort(i: Int): Short = {
+ values(i).asInstanceOf[MutableShort].value
+ }
+
+ override def setLong(ordinal: Int, value: Long): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableLong]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getLong(i: Int): Long = {
+ values(i).asInstanceOf[MutableLong].value
+ }
+
+ override def setByte(ordinal: Int, value: Byte): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableByte]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getByte(i: Int): Byte = {
+ values(i).asInstanceOf[MutableByte].value
+ }
+}
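
A short usage sketch of the new SpecificMutableRow, using only the constructor and accessors defined above (assumes the catalyst classes are importable):

    import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

    val row = new SpecificMutableRow(Seq(IntegerType, StringType))
    row.setInt(0, 42)        // writes into a reused MutableInt; no boxing
    row.setString(1, "a")    // non-primitive columns fall back to MutableAny
    assert(row.getInt(0) == 42 && !row.isNullAt(0))
    row.setNullAt(0)         // flips the holder's isNull flag; nothing is allocated
    assert(row.isNullAt(0))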
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 01947273b6..613b87ca98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -22,6 +22,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.util.collection.OpenHashSet
abstract class AggregateExpression extends Expression {
self: Product =>
@@ -161,13 +162,88 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def newInstance() = new CountFunction(child, this)
}
-case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
+case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
+ def this() = this(null)
+
override def children = expressions
override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = LongType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
override def newInstance() = new CountDistinctFunction(expressions, this)
+
+ override def asPartial = {
+ val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
+ SplitEvaluation(
+ CombineSetsAndCount(partialSet.toAttribute),
+ partialSet :: Nil)
+ }
+}
+
+case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children = expressions
+ override def references = expressions.flatMap(_.references).toSet
+ override def nullable = false
+ override def dataType = ArrayType(expressions.head.dataType)
+ override def toString = s"AddToHashSet(${expressions.mkString(",")})"
+ override def newInstance() = new CollectHashSetFunction(expressions, this)
+}
+
+case class CollectHashSetFunction(
+ @transient expr: Seq[Expression],
+ @transient base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new OpenHashSet[Any]()
+
+ @transient
+ val distinctValue = new InterpretedProjection(expr)
+
+ override def update(input: Row): Unit = {
+ val evaluatedExpr = distinctValue(input)
+ if (!evaluatedExpr.anyNull) {
+ seen.add(evaluatedExpr)
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ seen
+ }
+}
+
+case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children = inputSet :: Nil
+ override def references = inputSet.references
+ override def nullable = false
+ override def dataType = LongType
+ override def toString = s"CombineAndCount($inputSet)"
+ override def newInstance() = new CombineSetsAndCountFunction(inputSet, this)
+}
+
+case class CombineSetsAndCountFunction(
+ @transient inputSet: Expression,
+ @transient base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new OpenHashSet[Any]()
+
+ override def update(input: Row): Unit = {
+ val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+ val inputIterator = inputSetEval.iterator
+ while (inputIterator.hasNext) {
+ seen.add(inputIterator.next)
+ }
+ }
+
+ override def eval(input: Row): Any = seen.size.toLong
}
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
@@ -379,17 +455,22 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}
-case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
+case class CountDistinctFunction(
+ @transient expr: Seq[Expression],
+ @transient base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- val seen = new scala.collection.mutable.HashSet[Any]()
+ val seen = new OpenHashSet[Any]()
+
+ @transient
+ val distinctValue = new InterpretedProjection(expr)
override def update(input: Row): Unit = {
- val evaluatedExpr = expr.map(_.eval(input))
- if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
- seen += evaluatedExpr
+ val evaluatedExpr = distinctValue(input)
+ if (!evaluatedExpr.anyNull) {
+ seen.add(evaluatedExpr)
}
}
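
The update/eval pair in CombineSetsAndCountFunction reduces to a set-merge loop over the per-partition OpenHashSets produced by CollectHashSetFunction. A standalone sketch of that loop (OpenHashSet is private[spark], so this assumes code compiled inside an org.apache.spark package):

    import org.apache.spark.util.collection.OpenHashSet

    // Per-partition partial results, as CollectHashSetFunction would build them.
    val partials = Seq(Seq(1, 2, 2), Seq(2, 3)).map { part =>
      val s = new OpenHashSet[Any](); part.foreach(s.add); s
    }
    // CombineSetsAndCountFunction.update: fold every partial set into `seen`.
    val seen = new OpenHashSet[Any]()
    partials.foreach { s =>
      val it = s.iterator
      while (it.hasNext) seen.add(it.next())
    }
    assert(seen.size.toLong == 3L)  // CombineSetsAndCountFunction.eval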
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index c79c1847ce..8d90614e45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -85,3 +85,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}
+
+case class MaxOf(left: Expression, right: Expression) extends Expression {
+ type EvaluatedType = Any
+
+ override def nullable = left.nullable && right.nullable
+
+ override def children = left :: right :: Nil
+
+ override def references = left.references ++ right.references
+
+ override def dataType = left.dataType
+
+ override def eval(input: Row): Any = {
+ val leftEval = left.eval(input)
+ val rightEval = right.eval(input)
+ if (leftEval == null) {
+ rightEval
+ } else if (rightEval == null) {
+ leftEval
+ } else {
+ val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
+ if (numeric.compare(leftEval, rightEval) < 0) {
+ rightEval
+ } else {
+ leftEval
+ }
+ }
+ }
+
+ override def toString = s"MaxOf($left, $right)"
+}
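
Note MaxOf's null semantics: a null operand is skipped rather than propagated, which is why nullable requires both children to be nullable. The same decision table as a hypothetical standalone helper:

    // Mirrors MaxOf.eval; Numeric[T] extends Ordering[T], so max is available.
    def maxOf[T](left: Option[T], right: Option[T])(implicit num: Numeric[T]): Option[T] =
      (left, right) match {
        case (None, r)          => r                   // left null: take right
        case (l, None)          => l                   // right null: take left
        case (Some(l), Some(r)) => Some(num.max(l, r)) // both present: larger wins
      }

    assert(maxOf(Option.empty[Int], Some(2)) == Some(2))
    assert(maxOf(Some(3), Some(2)) == Some(3))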
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index de2d67ce82..5a3f013c34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
+// These classes are here to avoid issues with serialization and integration with quasiquotes.
+class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
+class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
+
/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
@@ -51,6 +55,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private val javaSeparator = "$"
/**
+ * Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
+ */
+ var debugLogging = false
+
+ /**
* Generates a class for a given input expression. Called when there is not cached code
* already available.
*/
@@ -71,7 +80,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
* fundamental difference is that a ConcurrentMap persists all elements that are added to it until
* they are explicitly removed. A Cache on the other hand is generally configured to evict entries
- * automatically, in order to constrain its memory footprint
+ * automatically, in order to constrain its memory footprint. Note that this cache does not use
+ * weak keys/values and thus does not respond to memory pressure.
*/
protected val cache = CacheBuilder.newBuilder()
.maximumSize(1000)
@@ -403,6 +413,78 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
$primitiveTerm = ${falseEval.primitiveTerm}
}
""".children
+
+ case NewSet(elementType) =>
+ q"""
+ val $nullTerm = false
+ val $primitiveTerm = new ${hashSetForType(elementType)}()
+ """.children
+
+ case AddItemToSet(item, set) =>
+ val itemEval = expressionEvaluator(item)
+ val setEval = expressionEvaluator(set)
+
+ val ArrayType(elementType, _) = set.dataType
+
+ itemEval.code ++ setEval.code ++
+ q"""
+ if (!${itemEval.nullTerm}) {
+ ${setEval.primitiveTerm}
+ .asInstanceOf[${hashSetForType(elementType)}]
+ .add(${itemEval.primitiveTerm})
+ }
+
+ val $nullTerm = false
+ val $primitiveTerm = ${setEval.primitiveTerm}
+ """.children
+
+ case CombineSets(left, right) =>
+ val leftEval = expressionEvaluator(left)
+ val rightEval = expressionEvaluator(right)
+
+ val ArrayType(elementType, _) = left.dataType
+
+ leftEval.code ++ rightEval.code ++
+ q"""
+ val $nullTerm = false
+ var $primitiveTerm: ${hashSetForType(elementType)} = null
+
+ {
+ val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
+ val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
+ val iterator = rightSet.iterator
+ while (iterator.hasNext) {
+ leftSet.add(iterator.next())
+ }
+ $primitiveTerm = leftSet
+ }
+ """.children
+
+ case MaxOf(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+
+ if (${eval1.nullTerm}) {
+ $nullTerm = ${eval2.nullTerm}
+ $primitiveTerm = ${eval2.primitiveTerm}
+ } else if (${eval2.nullTerm}) {
+ $nullTerm = ${eval1.nullTerm}
+ $primitiveTerm = ${eval1.primitiveTerm}
+ } else {
+ $nullTerm = false
+ if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
+ $primitiveTerm = ${eval1.primitiveTerm}
+ } else {
+ $primitiveTerm = ${eval2.primitiveTerm}
+ }
+ }
+ """.children
+
}
// If there was no match in the partial function above, we fall back on calling the interpreted
@@ -420,7 +502,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
// Only inject debugging code if debugging is turned on.
val debugCode =
- if (log.isDebugEnabled) {
+ if (debugLogging) {
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
@@ -454,6 +536,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+ protected def hashSetForType(dt: DataType) = dt match {
+ case IntegerType => typeOf[IntegerHashSet]
+ case LongType => typeOf[LongHashSet]
+ case unsupportedType =>
+ sys.error(s"Code generation not support for hashset of type $unsupportedType")
+ }
+
protected def primitiveForType(dt: DataType) = dt match {
case IntegerType => "Int"
case LongType => "Long"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 77fa02c13d..7871a62620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -69,8 +69,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
..${evaluatedExpression.code}
if(${evaluatedExpression.nullTerm})
setNullAt($iLit)
- else
+ else {
+ nullBits($iLit) = false
$elementName = ${evaluatedExpression.primitiveTerm}
+ }
}
""".children : Seq[Tree]
}
@@ -106,9 +108,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if(value == null) {
setNullAt(i)
} else {
+ nullBits(i) = false
$elementName = value.asInstanceOf[${termForType(e.dataType)}]
- return
}
+ return
}"""
}
q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
@@ -137,7 +140,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
- q"if(i == $i) { $elementName = value; return }" :: Nil
+ q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}
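
All three GenerateProjection hunks above fix the same latent bug: the generated mutable row set its null bit in setNullAt but never cleared it when a non-null value was written. A hand-written miniature of the corrected pattern (illustrative, not actual generator output):

    class ExampleMutableRow {
      private val nullBits = new Array[Boolean](1)
      private var c0: Int = 0
      def setNullAt(i: Int): Unit = nullBits(i) = true
      def isNullAt(i: Int): Boolean = nullBits(i)
      // The fix: clear the null bit on every non-null write.
      def setInt(i: Int, v: Int): Unit = { nullBits(i) = false; c0 = v }
    }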
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
new file mode 100644
index 0000000000..e6c570b47b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.util.collection.OpenHashSet
+
+/**
+ * Creates a new set of the specified type
+ */
+case class NewSet(elementType: DataType) extends LeafExpression {
+ type EvaluatedType = Any
+
+ def references = Set.empty
+
+ def nullable = false
+
+ // We are currently only using these Expressions internally for aggregation. However, if we ever
+ // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
+ def dataType = ArrayType(elementType)
+
+ def eval(input: Row): Any = {
+ new OpenHashSet[Any]()
+ }
+
+ override def toString = s"new Set($dataType)"
+}
+
+/**
+ * Adds an item to a set.
+ * For performance, this expression mutates its input during evaluation.
+ */
+case class AddItemToSet(item: Expression, set: Expression) extends Expression {
+ type EvaluatedType = Any
+
+ def children = item :: set :: Nil
+
+ def nullable = set.nullable
+
+ def dataType = set.dataType
+
+ def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
+
+ def eval(input: Row): Any = {
+ val itemEval = item.eval(input)
+ val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
+
+ if (itemEval != null) {
+ if (setEval != null) {
+ setEval.add(itemEval)
+ setEval
+ } else {
+ null
+ }
+ } else {
+ setEval
+ }
+ }
+
+ override def toString = s"$set += $item"
+}
+
+/**
+ * Combines the elements of two sets.
+ * For performance, this expression mutates its left input set during evaluation.
+ */
+case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
+ type EvaluatedType = Any
+
+ def nullable = left.nullable || right.nullable
+
+ def dataType = left.dataType
+
+ def symbol = "++="
+
+ def eval(input: Row): Any = {
+ val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
+ if(leftEval != null) {
+ val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
+ if (rightEval != null) {
+ val iterator = rightEval.iterator
+ while(iterator.hasNext) {
+ val rightValue = iterator.next()
+ leftEval.add(rightValue)
+ }
+ leftEval
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+}
+
+/**
+ * Returns the number of elements in the input set.
+ */
+case class CountSet(child: Expression) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def nullable = child.nullable
+
+ def dataType = LongType
+
+ def eval(input: Row): Any = {
+ val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
+ if (childEval != null) {
+ childEval.size.toLong
+ }
+ }
+
+ override def toString = s"$child.count()"
+}
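
The four expressions in this file compose into the pipeline the new aggregates drive. A small sketch evaluating one chain directly; EmptyRow suffices because every input is a literal:

    // NewSet allocates, AddItemToSet mutates, CountSet measures.
    val set   = NewSet(IntegerType)
    val added = AddItemToSet(Literal(2), AddItemToSet(Literal(1), set))
    assert(CountSet(added).eval(EmptyRow) == 2L)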
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index cd04bdf02c..96ce35939e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -280,7 +280,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
try {
- val defaultCtor = getClass.getConstructors.head
+ // Skip no-arg constructors that are just there for kryo.
+ val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else {
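
With Kryo-only no-arg constructors now present on many case classes, makeCopy has to skip them; the selection is plain reflection, as in this hypothetical miniature:

    // A node with a serialization-only no-arg constructor, like the patched classes.
    class Node(val child: String) { def this() = this(null) }

    val ctor = classOf[Node].getConstructors.find(_.getParameterTypes.nonEmpty).get
    assert(ctor.newInstance("newChild").asInstanceOf[Node].child == "newChild")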
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 999c9fff38..f1df817c41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -136,6 +136,16 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}
+ test("MaxOf") {
+ checkEvaluation(MaxOf(1, 2), 2)
+ checkEvaluation(MaxOf(2, 1), 2)
+ checkEvaluation(MaxOf(1L, 2L), 2L)
+ checkEvaluation(MaxOf(2L, 1L), 2L)
+
+ checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2)
+ checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2)
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)