author     Michael Armbrust <michael@databricks.com>    2014-08-23 16:19:10 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-08-23 16:19:10 -0700
commit     7e191fe29bb09a8560cd75d453c4f7f662dff406 (patch)
tree       38c9db34f2ff9dfd3042cacc3ca6fe351c576ece /sql
parent     2fb1c72ea21e137c8b60a72e5aecd554c71b16e1 (diff)
[SPARK-2554][SQL] CountDistinct partial aggregation and object allocation improvements
Author: Michael Armbrust <michael@databricks.com>
Author: Gregory Owen <greowen@gmail.com>

Closes #1935 from marmbrus/countDistinctPartial and squashes the following commits:

5c7848d [Michael Armbrust] turn off caching in the constructor
8074a80 [Michael Armbrust] fix tests
32d216f [Michael Armbrust] reynolds comments
c122cca [Michael Armbrust] Address comments, add tests
b2e8ef3 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
fae38f4 [Michael Armbrust] Fix style
fdca896 [Michael Armbrust] cleanup
93d0f64 [Michael Armbrust] metastore concurrency fix.
db44a30 [Michael Armbrust] JIT hax.
3868f6c [Michael Armbrust] Merge pull request #9 from GregOwen/countDistinctPartial
c9e67de [Gregory Owen] Made SpecificRow and types serializable by Kryo
2b46c4b [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
8ff6402 [Michael Armbrust] Add specific row.
58d15f1 [Michael Armbrust] disable codegen logging
87d101d [Michael Armbrust] Fix isNullAt bug
abee26d [Michael Armbrust] WIP
27984d0 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
57ae3b1 [Michael Armbrust] Fix order dependent test
b3d0f64 [Michael Armbrust] Add golden files.
c1f7114 [Michael Armbrust] Improve tests / fix serialization.
f31b8ad [Michael Armbrust] more fixes
38c7449 [Michael Armbrust] comments and style
9153652 [Michael Armbrust] better toString
d494598 [Michael Armbrust] Fix tests now that the planner is better
41fbd1d [Michael Armbrust] Never try and create an empty hash set.
050bb97 [Michael Armbrust] Skip no-arg constructors for kryo,
bd08239 [Michael Armbrust] WIP
213ada8 [Michael Armbrust] First draft of partially aggregated and code generated count distinct / max
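At a high level, this change turns COUNT(DISTINCT ...) into a partial aggregate: each partition collects the distinct values it sees into an OpenHashSet (CollectHashSet), and the final aggregation merges those partial sets and returns the size of the union (CombineSetsAndCount). A rough sketch of that two-phase evaluation in plain Scala, using illustrative in-memory partitions rather than anything from the diff below:

    import org.apache.spark.util.collection.OpenHashSet

    // Partial phase: each input partition builds its own set of observed values.
    val partition1 = Seq(1, 2, 2, 3)
    val partition2 = Seq(3, 4)
    val partialSets = Seq(partition1, partition2).map { part =>
      val s = new OpenHashSet[Any]()
      part.foreach(s.add)
      s
    }

    // Final phase: union the partial sets and count the merged elements.
    val merged = new OpenHashSet[Any]()
    for (s <- partialSets) {
      val it = s.iterator
      while (it.hasNext) merged.add(it.next())
    }
    val countDistinct: Long = merged.size.toLong  // 4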
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 344
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala | 307
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 93
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 93
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala | 129
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala | 86
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 | 1
-rw-r--r--  sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e | 1
-rw-r--r--  sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce | 0
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 8
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 65
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala | 11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala | 5
33 files changed, 1239 insertions, 34 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 8fc5896974..ef1d12531f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -27,7 +27,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
- protected val exprArray = expressions.toArray
+ // A null check is required because Kryo invokes the no-arg constructor during deserialization.
+ protected val exprArray = if (expressions != null) expressions.toArray else null
def apply(input: Row): Row = {
val outputArray = new Array[Any](exprArray.length)
@@ -109,7 +110,346 @@ class JoinedRow extends Row {
def apply(i: Int) =
if (i < row1.size) row1(i) else row2(i - row1.size)
- def isNullAt(i: Int) = apply(i) == null
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ * The `JoinedRow` class is used in many performance critical situations. Unfortunately, since there
+ * are multiple different types of `Rows` that could be stored as `row1` and `row2`, most of the
+ * calls in the critical path are polymorphic. By creating special versions of this class that are
+ * used in only a single location of the code, we increase the chance that only a single type of
+ * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds
+ * crazy but in benchmarks it had noticeable effects.
+ */
+class JoinedRow2 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow3 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow4 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
+
+ def getInt(i: Int): Int =
+ if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
+
+ def getLong(i: Int): Long =
+ if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
+
+ def getDouble(i: Int): Double =
+ if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
+
+ def getBoolean(i: Int): Boolean =
+ if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
+
+ def getShort(i: Int): Short =
+ if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
+
+ def getByte(i: Int): Byte =
+ if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
+
+ def getFloat(i: Int): Float =
+ if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
+
+ def getString(i: Int): String =
+ if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+
+ def copy() = {
+ val totalSize = row1.size + row2.size
+ val copiedValues = new Array[Any](totalSize)
+ var i = 0
+ while(i < totalSize) {
+ copiedValues(i) = apply(i)
+ i += 1
+ }
+ new GenericRow(copiedValues)
+ }
+
+ override def toString() = {
+ val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
+ s"[${row.mkString(",")}]"
+ }
+}
+
+/**
+ * JIT HACK: Replace with macros
+ */
+class JoinedRow5 extends Row {
+ private[this] var row1: Row = _
+ private[this] var row2: Row = _
+
+ def this(left: Row, right: Row) = {
+ this()
+ row1 = left
+ row2 = right
+ }
+
+ /** Updates this JoinedRow to point at two new base rows. Returns itself. */
+ def apply(r1: Row, r2: Row): Row = {
+ row1 = r1
+ row2 = r2
+ this
+ }
+
+ /** Updates this JoinedRow by updating its left base row. Returns itself. */
+ def withLeft(newLeft: Row): Row = {
+ row1 = newLeft
+ this
+ }
+
+ /** Updates this JoinedRow by updating its right base row. Returns itself. */
+ def withRight(newRight: Row): Row = {
+ row2 = newRight
+ this
+ }
+
+ def iterator = row1.iterator ++ row2.iterator
+
+ def length = row1.length + row2.length
+
+ def apply(i: Int) =
+ if (i < row1.size) row1(i) else row2(i - row1.size)
+
+ def isNullAt(i: Int) =
+ if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
def getInt(i: Int): Int =
if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index c9a63e201e..d68a4fabea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -127,7 +127,7 @@ object EmptyRow extends Row {
* the array is not copied, and thus could technically be mutated after creation, this is not
* allowed.
*/
-class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
+class GenericRow(protected[sql] val values: Array[Any]) extends Row {
/** No-arg constructor for serialization. */
def this() = this(null)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
new file mode 100644
index 0000000000..75ea0e8459
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * A parent class for mutable container objects that are reused when the values are changed,
+ * resulting in less garbage. These values are held by a [[SpecificMutableRow]].
+ *
+ * The following code was roughly used to generate these objects:
+ * {{{
+ * val types = "Int,Float,Boolean,Double,Short,Long,Byte,Any".split(",")
+ * types.map {tpe =>
+ * s"""
+ * final class Mutable$tpe extends MutableValue {
+ * var value: $tpe = 0
+ * def boxed = if (isNull) null else value
+ * def update(v: Any) = value = {
+ * isNull = false
+ * v.asInstanceOf[$tpe]
+ * }
+ * def copy() = {
+ * val newCopy = new Mutable$tpe
+ * newCopy.isNull = isNull
+ * newCopy.value = value
+ * newCopy.asInstanceOf[this.type]
+ * }
+ * }"""
+ * }.foreach(println)
+ *
+ * types.map { tpe =>
+ * s"""
+ * override def set$tpe(ordinal: Int, value: $tpe): Unit = {
+ * val currentValue = values(ordinal).asInstanceOf[Mutable$tpe]
+ * currentValue.isNull = false
+ * currentValue.value = value
+ * }
+ *
+ * override def get$tpe(i: Int): $tpe = {
+ * values(i).asInstanceOf[Mutable$tpe].value
+ * }"""
+ * }.foreach(println)
+ * }}}
+ */
+abstract class MutableValue extends Serializable {
+ var isNull: Boolean = true
+ def boxed: Any
+ def update(v: Any)
+ def copy(): this.type
+}
+
+final class MutableInt extends MutableValue {
+ var value: Int = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Int]
+ }
+ def copy() = {
+ val newCopy = new MutableInt
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableFloat extends MutableValue {
+ var value: Float = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Float]
+ }
+ def copy() = {
+ val newCopy = new MutableFloat
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableBoolean extends MutableValue {
+ var value: Boolean = false
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Boolean]
+ }
+ def copy() = {
+ val newCopy = new MutableBoolean
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableDouble extends MutableValue {
+ var value: Double = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Double]
+ }
+ def copy() = {
+ val newCopy = new MutableDouble
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableShort extends MutableValue {
+ var value: Short = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Short]
+ }
+ def copy() = {
+ val newCopy = new MutableShort
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableLong extends MutableValue {
+ var value: Long = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Long]
+ }
+ def copy() = {
+ val newCopy = new MutableLong
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableByte extends MutableValue {
+ var value: Byte = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Byte]
+ }
+ def copy() = {
+ val newCopy = new MutableByte
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+final class MutableAny extends MutableValue {
+ var value: Any = 0
+ def boxed = if (isNull) null else value
+ def update(v: Any) = value = {
+ isNull = false
+ v.asInstanceOf[Any]
+ }
+ def copy() = {
+ val newCopy = new MutableAny
+ newCopy.isNull = isNull
+ newCopy.value = value
+ newCopy.asInstanceOf[this.type]
+ }
+}
+
+/**
+ * A row type that holds an array of specialized container objects, of type [[MutableValue]], chosen
+ * based on the dataTypes of each column. The intent is to decrease garbage when modifying the
+ * values of primitive columns.
+ */
+final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
+
+ def this(dataTypes: Seq[DataType]) =
+ this(
+ dataTypes.map {
+ case IntegerType => new MutableInt
+ case ByteType => new MutableByte
+ case FloatType => new MutableFloat
+ case ShortType => new MutableShort
+ case DoubleType => new MutableDouble
+ case BooleanType => new MutableBoolean
+ case LongType => new MutableLong
+ case _ => new MutableAny
+ }.toArray)
+
+ def this() = this(Seq.empty)
+
+ override def length: Int = values.length
+
+ override def setNullAt(i: Int): Unit = {
+ values(i).isNull = true
+ }
+
+ override def apply(i: Int): Any = values(i).boxed
+
+ override def isNullAt(i: Int): Boolean = values(i).isNull
+
+ override def copy(): Row = {
+ val newValues = new Array[MutableValue](values.length)
+ var i = 0
+ while (i < values.length) {
+ newValues(i) = values(i).copy()
+ i += 1
+ }
+ new SpecificMutableRow(newValues)
+ }
+
+ override def update(ordinal: Int, value: Any): Unit = values(ordinal).update(value)
+
+ override def iterator: Iterator[Any] = values.map(_.boxed).iterator
+
+ def setString(ordinal: Int, value: String) = update(ordinal, value)
+
+ def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+
+ override def setInt(ordinal: Int, value: Int): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableInt]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getInt(i: Int): Int = {
+ values(i).asInstanceOf[MutableInt].value
+ }
+
+ override def setFloat(ordinal: Int, value: Float): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableFloat]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getFloat(i: Int): Float = {
+ values(i).asInstanceOf[MutableFloat].value
+ }
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableBoolean]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getBoolean(i: Int): Boolean = {
+ values(i).asInstanceOf[MutableBoolean].value
+ }
+
+ override def setDouble(ordinal: Int, value: Double): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableDouble]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getDouble(i: Int): Double = {
+ values(i).asInstanceOf[MutableDouble].value
+ }
+
+ override def setShort(ordinal: Int, value: Short): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableShort]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getShort(i: Int): Short = {
+ values(i).asInstanceOf[MutableShort].value
+ }
+
+ override def setLong(ordinal: Int, value: Long): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableLong]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getLong(i: Int): Long = {
+ values(i).asInstanceOf[MutableLong].value
+ }
+
+ override def setByte(ordinal: Int, value: Byte): Unit = {
+ val currentValue = values(ordinal).asInstanceOf[MutableByte]
+ currentValue.isNull = false
+ currentValue.value = value
+ }
+
+ override def getByte(i: Int): Byte = {
+ values(i).asInstanceOf[MutableByte].value
+ }
+}
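
The new SpecificMutableRow above keeps one reusable, type-specialized MutableValue container per column, so repeated updates to primitive columns avoid allocating boxed objects. A short usage sketch (illustrative, not part of the commit), assuming only the constructor and accessors shown above:

    import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
    import org.apache.spark.sql.catalyst.types.{IntegerType, LongType, StringType}

    // Primitive columns get MutableInt / MutableLong containers; StringType falls back to MutableAny.
    val row = new SpecificMutableRow(Seq(IntegerType, LongType, StringType))
    row.setInt(0, 42)            // updates the MutableInt in place, no boxing
    row.setLong(1, 123456789L)
    row.setString(2, "spark")
    assert(row.getInt(0) == 42)
    row.setNullAt(1)
    assert(row.isNullAt(1))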
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 01947273b6..613b87ca98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -22,6 +22,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.util.collection.OpenHashSet
abstract class AggregateExpression extends Expression {
self: Product =>
@@ -161,13 +162,88 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def newInstance() = new CountFunction(child, this)
}
-case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
+case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
+ def this() = this(null)
+
override def children = expressions
override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = LongType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
override def newInstance() = new CountDistinctFunction(expressions, this)
+
+ override def asPartial = {
+ val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
+ SplitEvaluation(
+ CombineSetsAndCount(partialSet.toAttribute),
+ partialSet :: Nil)
+ }
+}
+
+case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children = expressions
+ override def references = expressions.flatMap(_.references).toSet
+ override def nullable = false
+ override def dataType = ArrayType(expressions.head.dataType)
+ override def toString = s"AddToHashSet(${expressions.mkString(",")})"
+ override def newInstance() = new CollectHashSetFunction(expressions, this)
+}
+
+case class CollectHashSetFunction(
+ @transient expr: Seq[Expression],
+ @transient base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new OpenHashSet[Any]()
+
+ @transient
+ val distinctValue = new InterpretedProjection(expr)
+
+ override def update(input: Row): Unit = {
+ val evaluatedExpr = distinctValue(input)
+ if (!evaluatedExpr.anyNull) {
+ seen.add(evaluatedExpr)
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ seen
+ }
+}
+
+case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
+ def this() = this(null)
+
+ override def children = inputSet :: Nil
+ override def references = inputSet.references
+ override def nullable = false
+ override def dataType = LongType
+ override def toString = s"CombineAndCount($inputSet)"
+ override def newInstance() = new CombineSetsAndCountFunction(inputSet, this)
+}
+
+case class CombineSetsAndCountFunction(
+ @transient inputSet: Expression,
+ @transient base: AggregateExpression)
+ extends AggregateFunction {
+
+ def this() = this(null, null) // Required for serialization.
+
+ val seen = new OpenHashSet[Any]()
+
+ override def update(input: Row): Unit = {
+ val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+ val inputIterator = inputSetEval.iterator
+ while (inputIterator.hasNext) {
+ seen.add(inputIterator.next)
+ }
+ }
+
+ override def eval(input: Row): Any = seen.size.toLong
}
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
@@ -379,17 +455,22 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}
-case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
+case class CountDistinctFunction(
+ @transient expr: Seq[Expression],
+ @transient base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- val seen = new scala.collection.mutable.HashSet[Any]()
+ val seen = new OpenHashSet[Any]()
+
+ @transient
+ val distinctValue = new InterpretedProjection(expr)
override def update(input: Row): Unit = {
- val evaluatedExpr = expr.map(_.eval(input))
- if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
- seen += evaluatedExpr
+ val evaluatedExpr = distinctValue(input)
+ if (!evaluatedExpr.anyNull) {
+ seen.add(evaluatedExpr)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index c79c1847ce..8d90614e45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -85,3 +85,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}
+
+case class MaxOf(left: Expression, right: Expression) extends Expression {
+ type EvaluatedType = Any
+
+ override def nullable = left.nullable && right.nullable
+
+ override def children = left :: right :: Nil
+
+ override def references = left.references ++ right.references
+
+ override def dataType = left.dataType
+
+ override def eval(input: Row): Any = {
+ val leftEval = left.eval(input)
+ val rightEval = right.eval(input)
+ if (leftEval == null) {
+ rightEval
+ } else if (rightEval == null) {
+ leftEval
+ } else {
+ val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
+ if (numeric.compare(leftEval, rightEval) < 0) {
+ rightEval
+ } else {
+ leftEval
+ }
+ }
+ }
+
+ override def toString = s"MaxOf($left, $right)"
+}
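
MaxOf ignores a null operand instead of propagating null, which is what allows GeneratedAggregate (later in this diff) to seed a Max buffer with Literal(null, ...) and fold each row in with MaxOf(currentMax, expr). A minimal illustration mirroring the new MaxOf test in ExpressionEvaluationSuite, assuming only the expression defined above and Catalyst's Literal:

    import org.apache.spark.sql.catalyst.expressions.{Literal, MaxOf}
    import org.apache.spark.sql.catalyst.types.IntegerType

    assert(MaxOf(Literal(2, IntegerType), Literal(5, IntegerType)).eval(null) == 5)
    assert(MaxOf(Literal(5, IntegerType), Literal(2, IntegerType)).eval(null) == 5)
    // A null side is ignored, so the running maximum can start out as null.
    assert(MaxOf(Literal(null, IntegerType), Literal(5, IntegerType)).eval(null) == 5)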
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index de2d67ce82..5a3f013c34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
+// These classes are here to avoid issues with serialization and integration with quasiquotes.
+class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
+class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
+
/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
@@ -51,6 +55,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
private val javaSeparator = "$"
/**
+ * Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
+ */
+ var debugLogging = false
+
+ /**
* Generates a class for a given input expression. Called when there is not cached code
* already available.
*/
@@ -71,7 +80,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most
* fundamental difference is that a ConcurrentMap persists all elements that are added to it until
* they are explicitly removed. A Cache on the other hand is generally configured to evict entries
- * automatically, in order to constrain its memory footprint
+ * automatically, in order to constrain its memory footprint. Note that this cache does not use
+ * weak keys/values and thus does not respond to memory pressure.
*/
protected val cache = CacheBuilder.newBuilder()
.maximumSize(1000)
@@ -403,6 +413,78 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
$primitiveTerm = ${falseEval.primitiveTerm}
}
""".children
+
+ case NewSet(elementType) =>
+ q"""
+ val $nullTerm = false
+ val $primitiveTerm = new ${hashSetForType(elementType)}()
+ """.children
+
+ case AddItemToSet(item, set) =>
+ val itemEval = expressionEvaluator(item)
+ val setEval = expressionEvaluator(set)
+
+ val ArrayType(elementType, _) = set.dataType
+
+ itemEval.code ++ setEval.code ++
+ q"""
+ if (!${itemEval.nullTerm}) {
+ ${setEval.primitiveTerm}
+ .asInstanceOf[${hashSetForType(elementType)}]
+ .add(${itemEval.primitiveTerm})
+ }
+
+ val $nullTerm = false
+ val $primitiveTerm = ${setEval.primitiveTerm}
+ """.children
+
+ case CombineSets(left, right) =>
+ val leftEval = expressionEvaluator(left)
+ val rightEval = expressionEvaluator(right)
+
+ val ArrayType(elementType, _) = left.dataType
+
+ leftEval.code ++ rightEval.code ++
+ q"""
+ val $nullTerm = false
+ var $primitiveTerm: ${hashSetForType(elementType)} = null
+
+ {
+ val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
+ val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
+ val iterator = rightSet.iterator
+ while (iterator.hasNext) {
+ leftSet.add(iterator.next())
+ }
+ $primitiveTerm = leftSet
+ }
+ """.children
+
+ case MaxOf(e1, e2) =>
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+
+ if (${eval1.nullTerm}) {
+ $nullTerm = ${eval2.nullTerm}
+ $primitiveTerm = ${eval2.primitiveTerm}
+ } else if (${eval2.nullTerm}) {
+ $nullTerm = ${eval1.nullTerm}
+ $primitiveTerm = ${eval1.primitiveTerm}
+ } else {
+ $nullTerm = false
+ if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
+ $primitiveTerm = ${eval1.primitiveTerm}
+ } else {
+ $primitiveTerm = ${eval2.primitiveTerm}
+ }
+ }
+ """.children
+
}
// If there was no match in the partial function above, we fall back on calling the interpreted
@@ -420,7 +502,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
// Only inject debugging code if debugging is turned on.
val debugCode =
- if (log.isDebugEnabled) {
+ if (debugLogging) {
val localLogger = log
val localLoggerTree = reify { localLogger }
q"""
@@ -454,6 +536,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}")
protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}")
+ protected def hashSetForType(dt: DataType) = dt match {
+ case IntegerType => typeOf[IntegerHashSet]
+ case LongType => typeOf[LongHashSet]
+ case unsupportedType =>
+ sys.error(s"Code generation not supported for hashset of type $unsupportedType")
+ }
+
protected def primitiveForType(dt: DataType) = dt match {
case IntegerType => "Int"
case LongType => "Long"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 77fa02c13d..7871a62620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -69,8 +69,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
..${evaluatedExpression.code}
if(${evaluatedExpression.nullTerm})
setNullAt($iLit)
- else
+ else {
+ nullBits($iLit) = false
$elementName = ${evaluatedExpression.primitiveTerm}
+ }
}
""".children : Seq[Tree]
}
@@ -106,9 +108,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if(value == null) {
setNullAt(i)
} else {
+ nullBits(i) = false
$elementName = value.asInstanceOf[${termForType(e.dataType)}]
- return
}
+ return
}"""
}
q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
@@ -137,7 +140,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
- q"if(i == $i) { $elementName = value; return }" :: Nil
+ q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil
case _ => Nil
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
new file mode 100644
index 0000000000..e6c570b47b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.util.collection.OpenHashSet
+
+/**
+ * Creates a new set of the specified type
+ */
+case class NewSet(elementType: DataType) extends LeafExpression {
+ type EvaluatedType = Any
+
+ def references = Set.empty
+
+ def nullable = false
+
+ // We are currently only using these Expressions internally for aggregation. However, if we ever
+ // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
+ def dataType = ArrayType(elementType)
+
+ def eval(input: Row): Any = {
+ new OpenHashSet[Any]()
+ }
+
+ override def toString = s"new Set($dataType)"
+}
+
+/**
+ * Adds an item to a set.
+ * For performance, this expression mutates its input during evaluation.
+ */
+case class AddItemToSet(item: Expression, set: Expression) extends Expression {
+ type EvaluatedType = Any
+
+ def children = item :: set :: Nil
+
+ def nullable = set.nullable
+
+ def dataType = set.dataType
+
+ def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
+
+ def eval(input: Row): Any = {
+ val itemEval = item.eval(input)
+ val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
+
+ if (itemEval != null) {
+ if (setEval != null) {
+ setEval.add(itemEval)
+ setEval
+ } else {
+ null
+ }
+ } else {
+ setEval
+ }
+ }
+
+ override def toString = s"$set += $item"
+}
+
+/**
+ * Combines the elements of two sets.
+ * For performance, this expression mutates its left input set during evaluation.
+ */
+case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
+ type EvaluatedType = Any
+
+ def nullable = left.nullable || right.nullable
+
+ def dataType = left.dataType
+
+ def symbol = "++="
+
+ def eval(input: Row): Any = {
+ val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
+ if(leftEval != null) {
+ val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
+ if (rightEval != null) {
+ val iterator = rightEval.iterator
+ while(iterator.hasNext) {
+ val rightValue = iterator.next()
+ leftEval.add(rightValue)
+ }
+ leftEval
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ }
+}
+
+/**
+ * Returns the number of elements in the input set.
+ */
+case class CountSet(child: Expression) extends UnaryExpression {
+ type EvaluatedType = Any
+
+ def nullable = child.nullable
+
+ def dataType = LongType
+
+ def eval(input: Row): Any = {
+ val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
+ if (childEval != null) {
+ childEval.size.toLong
+ }
+ }
+
+ override def toString = s"$child.count()"
+}
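
The set expressions above mutate and return the same underlying OpenHashSet, so a distinct count can be traced step by step. A hand-wired sketch, assuming a Literal is used to hold the set (in real plans these expressions are wired up by GeneratedAggregate and CountDistinct.asPartial, not evaluated like this):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.{ArrayType, IntegerType}
    import org.apache.spark.util.collection.OpenHashSet

    // NewSet produces an empty OpenHashSet; AddItemToSet mutates and returns it.
    val set = NewSet(IntegerType).eval(null).asInstanceOf[OpenHashSet[Any]]
    val setExpr = Literal(set, ArrayType(IntegerType))        // expression that yields the set
    AddItemToSet(Literal(1, IntegerType), setExpr).eval(null)
    AddItemToSet(Literal(2, IntegerType), setExpr).eval(null)
    AddItemToSet(Literal(1, IntegerType), setExpr).eval(null) // duplicate, no effect
    assert(CountSet(setExpr).eval(null) == 2L)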
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index cd04bdf02c..96ce35939e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -280,7 +280,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
try {
- val defaultCtor = getClass.getConstructors.head
+ // Skip no-arg constructors that are just there for kryo.
+ val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 999c9fff38..f1df817c41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -136,6 +136,16 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}
+ test("MaxOf") {
+ checkEvaluation(MaxOf(1, 2), 2)
+ checkEvaluation(MaxOf(2, 1), 2)
+ checkEvaluation(MaxOf(1L, 2L), 2L)
+ checkEvaluation(MaxOf(2L, 1L), 2L)
+
+ checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2)
+ checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2)
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 463a1d32d7..be9f155253 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -175,7 +175,7 @@ case class Aggregate(
private[this] val resultProjection =
new InterpretedMutableProjection(
resultExpressions, computedSchema ++ namedGroups.map(_._2))
- private[this] val joinedRow = new JoinedRow
+ private[this] val joinedRow = new JoinedRow4
override final def hasNext: Boolean = hashTableIter.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 4a26934c49..31ad5e8aab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -103,6 +103,40 @@ case class GeneratedAggregate(
updateCount :: updateSum :: Nil,
result
)
+
+ case m @ Max(expr) =>
+ val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
+ val initialValue = Literal(null, expr.dataType)
+ val updateMax = MaxOf(currentMax, expr)
+
+ AggregateEvaluation(
+ currentMax :: Nil,
+ initialValue :: Nil,
+ updateMax :: Nil,
+ currentMax)
+
+ case CollectHashSet(Seq(expr)) =>
+ val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+ val initialValue = NewSet(expr.dataType)
+ val addToSet = AddItemToSet(expr, set)
+
+ AggregateEvaluation(
+ set :: Nil,
+ initialValue :: Nil,
+ addToSet :: Nil,
+ set)
+
+ case CombineSetsAndCount(inputSet) =>
+ val ArrayType(inputType, _) = inputSet.dataType
+ val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
+ val initialValue = NewSet(inputType)
+ val collectSets = CombineSets(set, inputSet)
+
+ AggregateEvaluation(
+ set :: Nil,
+ initialValue :: Nil,
+ collectSets :: Nil,
+ CountSet(set))
}
val computationSchema = computeFunctions.flatMap(_.schema)
@@ -151,7 +185,7 @@ case class GeneratedAggregate(
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
- val joinedRow = new JoinedRow
+ val joinedRow = new JoinedRow3
if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 34654447a5..077e6ebc5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -28,9 +28,13 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
+
private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
val kryo = new Kryo()
@@ -41,6 +45,13 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
+
+ // Specific hashsets must come first. TODO: Move to core.
+ kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
+ kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
+ kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
+ new OpenHashSetSerializer)
+
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
new AllScalaRegistrar().apply(kryo)
@@ -109,3 +120,78 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
HyperLogLog.Builder.build(bytes)
}
}
+
+private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
+ def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
+ val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
+ output.writeInt(hs.size)
+ val iterator = hs.iterator
+ while(iterator.hasNext) {
+ val row = iterator.next()
+ rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values)
+ }
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
+ val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
+ val numItems = input.readInt()
+ val set = new OpenHashSet[Any](numItems + 1)
+ var i = 0
+ while (i < numItems) {
+ val row =
+ new GenericRow(rowSerializer.read(
+ kryo,
+ input,
+ classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
+ set.add(row)
+ i += 1
+ }
+ set
+ }
+}
+
+private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
+ def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
+ output.writeInt(hs.size)
+ val iterator = hs.iterator
+ while(iterator.hasNext) {
+ val value: Int = iterator.next()
+ output.writeInt(value)
+ }
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
+ val numItems = input.readInt()
+ val set = new IntegerHashSet
+ var i = 0
+ while (i < numItems) {
+ val value = input.readInt()
+ set.add(value)
+ i += 1
+ }
+ set
+ }
+}
+
+private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
+ def write(kryo: Kryo, output: Output, hs: LongHashSet) {
+ output.writeInt(hs.size)
+ val iterator = hs.iterator
+ while(iterator.hasNext) {
+ val value = iterator.next()
+ output.writeLong(value)
+ }
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
+ val numItems = input.readInt()
+ val set = new LongHashSet
+ var i = 0
+ while (i < numItems) {
+ val value = input.readLong()
+ set.add(value)
+ i += 1
+ }
+ set
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f0c958fdb5..517b77804a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.parquet._
@@ -148,7 +149,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
- case _: Sum | _: Count => false
+ case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
+ // The generated set implementation is pretty limited ATM.
+ case CollectHashSet(exprs) if exprs.size == 1 &&
+ Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
case _ => true
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index b08f9aacc1..2890a563be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -92,7 +92,7 @@ trait HashJoin {
private[this] var currentMatchPosition: Int = -1
// Mutable per row objects.
- private[this] val joinRow = new JoinedRow
+ private[this] val joinRow = new JoinedRow2
private[this] val joinKeys = streamSideKeyGenerator()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 0a3b59cbc2..ef4526ec03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -23,7 +23,7 @@ import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
import parquet.schema.MessageType
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.parquet.CatalystConverter.FieldType
/**
@@ -278,14 +278,14 @@ private[parquet] class CatalystGroupConverter(
*/
private[parquet] class CatalystPrimitiveRowConverter(
protected[parquet] val schema: Array[FieldType],
- protected[parquet] var current: ParquetRelation.RowType)
+ protected[parquet] var current: MutableRow)
extends CatalystConverter {
// This constructor is used for the root converter only
def this(attributes: Array[Attribute]) =
this(
attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)),
- new ParquetRelation.RowType(attributes.length))
+ new SpecificMutableRow(attributes.map(_.dataType)))
protected [parquet] val converters: Array[Converter] =
schema.zipWithIndex.map {
@@ -299,7 +299,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
override val parent = null
// Should be only called in root group converter!
- override def getCurrentRecord: ParquetRelation.RowType = current
+ override def getCurrentRecord: Row = current
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f6cfab736d..a5a5d139a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -139,7 +139,7 @@ case class ParquetTableScan(
partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
new Iterator[Row] {
- private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null)
+ private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
def hasNext = iter.hasNext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 76b1724471..37d64f0de7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -45,16 +45,16 @@ class PlannerSuite extends FunSuite {
assert(aggregations.size === 2)
}
- test("count distinct is not partially aggregated") {
+ test("count distinct is partially aggregated") {
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
val planned = HashAggregation(query)
- assert(planned.isEmpty)
+ assert(planned.nonEmpty)
}
- test("mixed aggregates are not partially aggregated") {
+ test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
val planned = HashAggregation(query)
- assert(planned.isEmpty)
+ assert(planned.nonEmpty)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 3b371211e1..6571c35499 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -265,9 +265,9 @@ private[hive] case class MetastoreRelation
// org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException
// which indicates the SerDe we used is not Serializable.
- @transient lazy val hiveQlTable = new Table(table)
+ @transient val hiveQlTable = new Table(table)
- def hiveQlPartitions = partitions.map { p =>
+ @transient val hiveQlPartitions = partitions.map { p =>
new Partition(hiveQlTable, p)
}
diff --git a/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc b/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc
new file mode 100644
index 0000000000..573541ac97
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 b/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 b/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff b/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d b/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d
new file mode 100644
index 0000000000..0cfbf08886
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d
@@ -0,0 +1 @@
+2
diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 b/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e b/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 b/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1
new file mode 100644
index 0000000000..0cfbf08886
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1
@@ -0,0 +1 @@
+2
diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e b/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e
new file mode 100644
index 0000000000..0cfbf08886
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e
@@ -0,0 +1 @@
+2
diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 7c82964b5e..8d6ca9939a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive
+import org.scalatest.BeforeAndAfterAll
+
import scala.reflect.ClassTag
@@ -26,7 +28,9 @@ import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-class StatisticsSuite extends QueryTest {
+class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
+ TestHive.reset()
+ TestHive.cacheTables = false
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
@@ -126,7 +130,7 @@ class StatisticsSuite extends QueryTest {
val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
mr.statistics.sizeInBytes
}
- assert(sizes.size === 1)
+ assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}")
assert(sizes(0).equals(BigInt(5812)),
s"expected exact size 5812 for test table 'src', got: ${sizes(0)}")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index fdb2f41f5a..26e4ec6e6d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -32,6 +32,71 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {
+ createQueryTest("count distinct 0 values",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 'a' AS a FROM src LIMIT 0) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 1 value strings",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 'b' AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 1 value",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 1 AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 2 values",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 2 AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 2 values including null",
+ """
+ |SELECT COUNT(DISTINCT a, 1) FROM (
+ | SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+ | SELECT null AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 1 value + null",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+ | SELECT null AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 1 value long",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 1L AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 2 values long",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 2L AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
+ createQueryTest("count distinct 1 value + null long",
+ """
+ |SELECT COUNT(DISTINCT a) FROM (
+ | SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+ | SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+ | SELECT null AS a FROM src LIMIT 1) table
+ """.stripMargin)
+
createQueryTest("null case",
"SELECT case when(true) then 1 else null end FROM src LIMIT 1")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index df9bae9649..8bc72384a6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -17,10 +17,19 @@
package org.apache.spark.sql.hive.execution
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.hive.test.TestHive
+
/**
* A set of tests that validates support for Hive SerDe.
*/
-class HiveSerDeSuite extends HiveComparisonTest {
+class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
+
+ override def beforeAll() = {
+ TestHive.cacheTables = false
+ }
+
createQueryTest(
"Read and write with LazySimpleSerDe (tab separated)",
"SELECT * from serdeins")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 1a6dbc0ce0..8275e2d3bc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive.execution
+import org.scalatest.BeforeAndAfter
+
import org.apache.spark.sql.hive.test.TestHive
/* Implicit conversions */
@@ -25,9 +27,10 @@ import scala.collection.JavaConversions._
/**
* A set of test cases that validate partition and column pruning.
*/
-class PruningSuite extends HiveComparisonTest {
+class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
// MINOR HACK: You must run a query before calling reset the first time.
TestHive.sql("SHOW TABLES")
+ TestHive.cacheTables = false
// Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset
 // the environment to ensure all referenced tables in this suite are not cached in-memory.