-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala  41
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala  12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala  434
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala  131
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala  39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala  95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala  4
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala  13
17 files changed, 446 insertions, 359 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
new file mode 100644
index 0000000000..8364379644
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+/**
+ * Builds a map that is keyed by an Attribute's expression id. Using the expression id allows values
+ * to be looked up even when the attributes used differ cosmetically (e.g., the capitalization
+ * of the name, or the expected nullability).
+ */
+object AttributeMap {
+ def apply[A](kvs: Seq[(Attribute, A)]) =
+ new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+}
+
+class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
+ extends Map[Attribute, A] with Serializable {
+
+ override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
+
+ override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
+ (baseMap.map(_._2) + kv).toMap
+
+ override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+
+ override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+}
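
For illustration only (not part of the patch), here is a minimal sketch of the behaviour described in the comment above. It assumes Catalyst's AttributeReference and its withNullability helper: two attributes that share an expression id but differ cosmetically resolve to the same entry.

    import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
    import org.apache.spark.sql.catalyst.types.IntegerType

    val key      = AttributeReference("key", IntegerType, nullable = false)()
    val cosmetic = key.withNullability(true)   // assumed helper: same exprId, different nullability

    val map = AttributeMap(Seq(key -> "stats for key"))
    map.get(cosmetic)                          // Some("stats for key"), matched via exprId
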
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 54c6baf1af..fa80b07f8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -38,12 +38,20 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
}
object BindReferences extends Logging {
- def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
+
+ def bindReference[A <: Expression](
+ expression: A,
+ input: Seq[Attribute],
+ allowFailures: Boolean = false): A = {
expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId)
if (ordinal == -1) {
- sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+ if (allowFailures) {
+ a
+ } else {
+ sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+ }
} else {
BoundReference(ordinal, a.dataType, a.nullable)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 64d49354da..4137ac7663 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -26,6 +26,7 @@ import java.util.Properties
private[spark] object SQLConf {
val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
+ val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
@@ -124,6 +125,12 @@ trait SQLConf {
private[spark] def isParquetBinaryAsString: Boolean =
getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean
+ /**
+ * When set to true, partition pruning for in-memory columnar tables is enabled.
+ */
+ private[spark] def inMemoryPartitionPruning: Boolean =
+ getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
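
For context, a minimal way to exercise the new flag (the sqlContext instance and the table name are assumptions, not from the patch):

    // Keep batches small so each RDD partition yields several stats rows, then enable pruning.
    sqlContext.setConf("spark.sql.inMemoryColumnarStorage.batchSize", "10000")
    sqlContext.setConf("spark.sql.inMemoryColumnarStorage.partitionPruning", "true")
    sqlContext.cacheTable("events")
    sqlContext.sql("SELECT * FROM events WHERE id > 100").collect()
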
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 247337a875..b3ec5ded22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder {
/**
* Column statistics information
*/
- def columnStats: ColumnStats[_, _]
+ def columnStats: ColumnStats
/**
* Returns the final columnar byte buffer.
@@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder {
}
private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
- val columnStats: ColumnStats[T, JvmType],
+ val columnStats: ColumnStats,
val columnType: ColumnType[T, JvmType])
extends ColumnBuilder {
@@ -81,18 +81,18 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType])
- extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
with NullableColumnBuilder
private[sql] abstract class NativeColumnBuilder[T <: NativeType](
- override val columnStats: NativeColumnStats[T],
+ override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T])
extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
with NullableColumnBuilder
with AllCompressionSchemes
with CompressibleColumnBuilder[T]
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN)
private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
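
A rough sketch (assuming the existing ColumnBuilder API of initialize, appendFrom, and build, and code living inside org.apache.spark.sql.columnar since the builders are private[sql]) of how a builder's ColumnStats feeds the per-batch statistics row that InMemoryRelation collects below:

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

    val builder = new IntColumnBuilder
    builder.initialize(1024, "i")

    val row = new GenericMutableRow(1)
    (1 to 100).foreach { v => row.setInt(0, v); builder.appendFrom(row, 0) }

    val buffer = builder.build()                          // columnar bytes for this batch
    val stats  = builder.columnStats.collectedStatistics  // Row(1, 100, 0): lower, upper, nullCount
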
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 6502110e90..fc343ccb99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -17,381 +17,193 @@
package org.apache.spark.sql.columnar
+import java.sql.Timestamp
+
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.types._
+private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
+ val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)()
+ val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)()
+ val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+
+ val schema = Seq(lowerBound, upperBound, nullCount)
+}
+
+private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
+ val (forAttribute, schema) = {
+ val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a))
+ (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _))
+ }
+}
+
/**
* Used to collect statistical information when building in-memory columns.
*
* NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]`
 * brings a significant performance penalty.
*/
-private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
- /**
- * Closed lower bound of this column.
- */
- def lowerBound: JvmType
-
- /**
- * Closed upper bound of this column.
- */
- def upperBound: JvmType
-
+private[sql] sealed trait ColumnStats extends Serializable {
/**
* Gathers statistics information from `row(ordinal)`.
*/
- def gatherStats(row: Row, ordinal: Int)
-
- /**
- * Returns `true` if `lower <= row(ordinal) <= upper`.
- */
- def contains(row: Row, ordinal: Int): Boolean
+ def gatherStats(row: Row, ordinal: Int): Unit
/**
- * Returns `true` if `row(ordinal) < upper` holds.
+ * Column statistics represented as a single row, currently including closed lower bound, closed
+ * upper bound and null count.
*/
- def isAbove(row: Row, ordinal: Int): Boolean
-
- /**
- * Returns `true` if `lower < row(ordinal)` holds.
- */
- def isBelow(row: Row, ordinal: Int): Boolean
-
- /**
- * Returns `true` if `row(ordinal) <= upper` holds.
- */
- def isAtOrAbove(row: Row, ordinal: Int): Boolean
-
- /**
- * Returns `true` if `lower <= row(ordinal)` holds.
- */
- def isAtOrBelow(row: Row, ordinal: Int): Boolean
-}
-
-private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
- extends ColumnStats[T, T#JvmType] {
-
- type JvmType = T#JvmType
-
- protected var (_lower, _upper) = initialBounds
-
- def initialBounds: (JvmType, JvmType)
-
- protected def columnType: NativeColumnType[T]
-
- override def lowerBound: T#JvmType = _lower
-
- override def upperBound: T#JvmType = _upper
-
- override def isAtOrAbove(row: Row, ordinal: Int) = {
- contains(row, ordinal) || isAbove(row, ordinal)
- }
-
- override def isAtOrBelow(row: Row, ordinal: Int) = {
- contains(row, ordinal) || isBelow(row, ordinal)
- }
+ def collectedStatistics: Row
}
-private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
- override def isAtOrBelow(row: Row, ordinal: Int) = true
-
- override def isAtOrAbove(row: Row, ordinal: Int) = true
-
- override def isBelow(row: Row, ordinal: Int) = true
-
- override def isAbove(row: Row, ordinal: Int) = true
+private[sql] class NoopColumnStats extends ColumnStats {
- override def contains(row: Row, ordinal: Int) = true
+ override def gatherStats(row: Row, ordinal: Int): Unit = {}
- override def gatherStats(row: Row, ordinal: Int) {}
-
- override def upperBound = null.asInstanceOf[JvmType]
-
- override def lowerBound = null.asInstanceOf[JvmType]
+ override def collectedStatistics = Row()
}
-private[sql] abstract class BasicColumnStats[T <: NativeType](
- protected val columnType: NativeColumnType[T])
- extends NativeColumnStats[T]
-
-private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
- override def initialBounds = (true, false)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class ByteColumnStats extends ColumnStats {
+ var upper = Byte.MinValue
+ var lower = Byte.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
-}
-
-private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
- override def initialBounds = (Byte.MaxValue, Byte.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getByte(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
- override def initialBounds = (Short.MaxValue, Short.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class ShortColumnStats extends ColumnStats {
+ var upper = Short.MinValue
+ var lower = Short.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
-}
-
-private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
- override def initialBounds = (Long.MaxValue, Long.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getShort(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
- override def initialBounds = (Double.MaxValue, Double.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class LongColumnStats extends ColumnStats {
+ var upper = Long.MinValue
+ var lower = Long.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
-}
-
-private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
- override def initialBounds = (Float.MaxValue, Float.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getLong(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
+}
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class DoubleColumnStats extends ColumnStats {
+ var upper = Double.MinValue
+ var lower = Double.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getDouble(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
-}
-private[sql] object IntColumnStats {
- val UNINITIALIZED = 0
- val INITIALIZED = 1
- val ASCENDING = 2
- val DESCENDING = 3
- val UNORDERED = 4
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-/**
- * Statistical information for `Int` columns. More information is collected since `Int` is
- * frequently used. Extra information include:
- *
- * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
- * is applicable when searching elements.
- * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
- * scheme.
- *
- * (This two kinds of information are not used anywhere yet and might be removed later.)
- */
-private[sql] class IntColumnStats extends BasicColumnStats(INT) {
- import IntColumnStats._
-
- private var orderedState = UNINITIALIZED
- private var lastValue: Int = _
- private var _maxDelta: Int = _
-
- def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
- def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
- def isOrdered = isAscending || isDescending
- def maxDelta = _maxDelta
-
- override def initialBounds = (Int.MaxValue, Int.MinValue)
+private[sql] class FloatColumnStats extends ColumnStats {
+ var upper = Float.MinValue
+ var lower = Float.MaxValue
+ var nullCount = 0
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
+ override def gatherStats(row: Row, ordinal: Int) {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getFloat(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
+}
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class IntColumnStats extends ColumnStats {
+ var upper = Int.MinValue
+ var lower = Int.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
-
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
-
- orderedState = orderedState match {
- case UNINITIALIZED =>
- lastValue = field
- INITIALIZED
-
- case INITIALIZED =>
- // If all the integers in the column are the same, ordered state is set to Ascending.
- // TODO (lian) Confirm whether this is the standard behaviour.
- val nextState = if (field >= lastValue) ASCENDING else DESCENDING
- _maxDelta = math.abs(field - lastValue)
- lastValue = field
- nextState
-
- case ASCENDING if field < lastValue =>
- UNORDERED
-
- case DESCENDING if field > lastValue =>
- UNORDERED
-
- case state @ (ASCENDING | DESCENDING) =>
- _maxDelta = _maxDelta.max(field - lastValue)
- lastValue = field
- state
-
- case _ =>
- orderedState
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getInt(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
}
}
+
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
- override def initialBounds = (null, null)
+private[sql] class StringColumnStats extends ColumnStats {
+ var upper: String = null
+ var lower: String = null
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
- if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
- }
-
- override def contains(row: Row, ordinal: Int) = {
- (upperBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
- }
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- (upperBound ne null) && {
- val field = columnType.getField(row, ordinal)
- field.compareTo(upperBound) < 0
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getString(ordinal)
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ } else {
+ nullCount += 1
}
}
- override def isBelow(row: Row, ordinal: Int) = {
- (lowerBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) < 0
- }
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) {
- override def initialBounds = (null, null)
+private[sql] class TimestampColumnStats extends ColumnStats {
+ var upper: Timestamp = null
+ var lower: Timestamp = null
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
- if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
- }
-
- override def contains(row: Row, ordinal: Int) = {
- (upperBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+ if (!row.isNullAt(ordinal)) {
+ val value = row(ordinal).asInstanceOf[Timestamp]
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ } else {
+ nullCount += 1
}
}
- override def isAbove(row: Row, ordinal: Int) = {
- (lowerBound ne null) && {
- val field = columnType.getField(row, ordinal)
- field.compareTo(upperBound) < 0
- }
- }
-
- override def isBelow(row: Row, ordinal: Int) = {
- (lowerBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) < 0
- }
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
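
For orientation (an illustrative sketch, not code from the patch, and again assuming package-private visibility): PartitionStatistics flattens one ColumnStatisticsSchema per table column into a single statistics schema, and forAttribute recovers the bound attributes for any data column, which is what the partition filter below binds against.

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.types.IntegerType

    val i     = AttributeReference("i", IntegerType, nullable = true)()
    val stats = new PartitionStatistics(Seq(i))

    stats.schema.map(_.name)          // Seq("i.lowerBound", "i.upperBound", "i.nullCount")
    stats.forAttribute(i).nullCount   // the AttributeReference for this column's null count
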
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index cb055cd74a..dc668e7dc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
@@ -31,23 +33,27 @@ object InMemoryRelation {
new InMemoryRelation(child.output, useCompression, batchSize, child)()
}
+private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
+
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
batchSize: Int,
child: SparkPlan)
- (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null)
+ (private var _cachedColumnBuffers: RDD[CachedBatch] = null)
extends LogicalPlan with MultiInstanceRelation {
override lazy val statistics =
Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+ val partitionStatistics = new PartitionStatistics(output)
+
// If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
val output = child.output
val cached = child.execute().mapPartitions { baseIterator =>
- new Iterator[Array[ByteBuffer]] {
+ new Iterator[CachedBatch] {
def next() = {
val columnBuilders = output.map { attribute =>
val columnType = ColumnType(attribute.dataType)
@@ -68,7 +74,10 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
- columnBuilders.map(_.build())
+ val stats = Row.fromSeq(
+ columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+
+ CachedBatch(columnBuilders.map(_.build()), stats)
}
def hasNext = baseIterator.hasNext
@@ -79,7 +88,6 @@ private[sql] case class InMemoryRelation(
_cachedColumnBuffers = cached
}
-
override def children = Seq.empty
override def newInstance() = {
@@ -96,13 +104,98 @@ private[sql] case class InMemoryRelation(
private[sql] case class InMemoryColumnarTableScan(
attributes: Seq[Attribute],
+ predicates: Seq[Expression],
relation: InMemoryRelation)
extends LeafNode {
+ @transient override val sqlContext = relation.child.sqlContext
+
override def output: Seq[Attribute] = attributes
+ // The returned filter predicate should return false iff it is impossible for the input expression
+ // to evaluate to `true` based on statistics collected about this partition batch.
+ val buildFilter: PartialFunction[Expression, Expression] = {
+ case And(lhs: Expression, rhs: Expression)
+ if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+ buildFilter(lhs) && buildFilter(rhs)
+
+ case Or(lhs: Expression, rhs: Expression)
+ if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+ buildFilter(lhs) || buildFilter(rhs)
+
+ case EqualTo(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l && l <= aStats.upperBound
+
+ case EqualTo(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l && l <= aStats.upperBound
+
+ case LessThan(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound < l
+
+ case LessThan(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l < aStats.upperBound
+
+ case LessThanOrEqual(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l
+
+ case LessThanOrEqual(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l <= aStats.upperBound
+
+ case GreaterThan(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l < aStats.upperBound
+
+ case GreaterThan(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound < l
+
+ case GreaterThanOrEqual(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l <= aStats.upperBound
+
+ case GreaterThanOrEqual(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l
+ }
+
+ val partitionFilters = {
+ predicates.flatMap { p =>
+ val filter = buildFilter.lift(p)
+ val boundFilter =
+ filter.map(
+ BindReferences.bindReference(
+ _,
+ relation.partitionStatistics.schema,
+ allowFailures = true))
+
+ boundFilter.foreach(_ =>
+ filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
+
+ // If the filter can't be resolved then we are missing required statistics.
+ boundFilter.filter(_.resolved)
+ }
+ }
+
+ val readPartitions = sparkContext.accumulator(0)
+ val readBatches = sparkContext.accumulator(0)
+
+ private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning
+
override def execute() = {
+ readPartitions.setValue(0)
+ readBatches.setValue(0)
+
relation.cachedColumnBuffers.mapPartitions { iterator =>
+ val partitionFilter = newPredicate(
+ partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+ relation.partitionStatistics.schema)
+
// Find the ordinals of the requested columns. If none are requested, use the first.
val requestedColumns = if (attributes.isEmpty) {
Seq(0)
@@ -110,8 +203,26 @@ private[sql] case class InMemoryColumnarTableScan(
attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
}
- iterator
- .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_)))
+ val rows = iterator
+ // Skip pruned batches
+ .filter { cachedBatch =>
+ if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
+ def statsString = relation.partitionStatistics.schema
+ .zip(cachedBatch.stats)
+ .map { case (a, s) => s"${a.name}: $s" }
+ .mkString(", ")
+ logInfo(s"Skipping partition based on stats $statsString")
+ false
+ } else {
+ readBatches += 1
+ true
+ }
+ }
+ // Build column accessors
+ .map { cachedBatch =>
+ requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+ }
+ // Extract rows via column accessors
.flatMap { columnAccessors =>
val nextRow = new GenericMutableRow(columnAccessors.length)
new Iterator[Row] {
@@ -127,6 +238,12 @@ private[sql] case class InMemoryColumnarTableScan(
override def hasNext = columnAccessors.head.hasNext
}
}
+
+ if (rows.hasNext) {
+ readPartitions += 1
+ }
+
+ rows
}
}
}
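
An illustrative expansion of what buildFilter produces (the attribute names below are placeholders; the real statistics attributes come from PartitionStatistics): a predicate over a data column is rewritten into a predicate over that column's per-batch bounds, so batches whose value range cannot contain a match are skipped before any column is decompressed.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.IntegerType

    // Stand-ins for the lower/upper bound attributes of data column `i`.
    val iLower = AttributeReference("i.lowerBound", IntegerType, nullable = false)()
    val iUpper = AttributeReference("i.upperBound", IntegerType, nullable = false)()

    // `i = 12`  is rewritten to:  i.lowerBound <= 12 && 12 <= i.upperBound
    val eq: Expression = iLower <= Literal(12) && Literal(12) <= iUpper
    // `i > 88`  is rewritten to:  88 < i.upperBound
    val gt: Expression = Literal(88) < iUpper
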
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index f631ee76fc..a72970eef7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
}
abstract override def appendFrom(row: Row, ordinal: Int) {
+ columnStats.gatherStats(row, ordinal)
if (row.isNullAt(ordinal)) {
nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
nulls.putInt(pos)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8dacb84c8a..7943d6e1b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -243,8 +243,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
pruneFilterProject(
projectList,
filters,
- identity[Seq[Expression]], // No filters are pushed down.
- InMemoryColumnarTableScan(_, mem)) :: Nil
+ identity[Seq[Expression]], // All filters still need to be evaluated.
+ InMemoryColumnarTableScan(_, filters, mem)) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 5f61fb5e16..cde91ceb68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.types._
class ColumnStatsSuite extends FunSuite {
- testColumnStats(classOf[BooleanColumnStats], BOOLEAN)
- testColumnStats(classOf[ByteColumnStats], BYTE)
- testColumnStats(classOf[ShortColumnStats], SHORT)
- testColumnStats(classOf[IntColumnStats], INT)
- testColumnStats(classOf[LongColumnStats], LONG)
- testColumnStats(classOf[FloatColumnStats], FLOAT)
- testColumnStats(classOf[DoubleColumnStats], DOUBLE)
- testColumnStats(classOf[StringColumnStats], STRING)
- testColumnStats(classOf[TimestampColumnStats], TIMESTAMP)
-
- def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+ testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0))
+ testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0))
+ testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
+ testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
+
+ def testColumnStats[T <: NativeType, U <: ColumnStats](
columnStatsClass: Class[U],
- columnType: NativeColumnType[T]) {
+ columnType: NativeColumnType[T],
+ initialStatistics: Row) {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- assertResult(columnStats.initialBounds, "Wrong initial bounds") {
- (columnStats.lowerBound, columnStats.upperBound)
+ columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) =>
+ assert(actual === expected)
}
}
@@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite {
import ColumnarTestUtils._
val columnStats = columnStatsClass.newInstance()
- val rows = Seq.fill(10)(makeRandomRow(columnType))
+ val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
- val values = rows.map(_.head.asInstanceOf[T#JvmType])
+ val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+ val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
- assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+ assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+ assertResult(10, "Wrong null count")(stats(2))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index dc813fe146..a77262534a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer
class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
- extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
with NullableColumnBuilder
object TestNullableColumnBuilder {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
new file mode 100644
index 0000000000..5d2fd49591
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class IntegerData(i: Int)
+
+class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
+ val originalColumnBatchSize = columnBatchSize
+ val originalInMemoryPartitionPruning = inMemoryPartitionPruning
+
+ override protected def beforeAll() {
+ // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
+ setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+ val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
+ rawData.registerTempTable("intData")
+
+ // Enable in-memory partition pruning
+ setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ }
+
+ override protected def afterAll() {
+ setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+ setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+ }
+
+ before {
+ cacheTable("intData")
+ }
+
+ after {
+ uncacheTable("intData")
+ }
+
+ // Comparisons
+ checkBatchPruning("i = 1", Seq(1), 1, 1)
+ checkBatchPruning("1 = i", Seq(1), 1, 1)
+ checkBatchPruning("i < 12", 1 to 11, 1, 2)
+ checkBatchPruning("i <= 11", 1 to 11, 1, 2)
+ checkBatchPruning("i > 88", 89 to 100, 1, 2)
+ checkBatchPruning("i >= 89", 89 to 100, 1, 2)
+ checkBatchPruning("12 > i", 1 to 11, 1, 2)
+ checkBatchPruning("11 >= i", 1 to 11, 1, 2)
+ checkBatchPruning("88 < i", 89 to 100, 1, 2)
+ checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+
+ // Conjunction and disjunction
+ checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
+ checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
+ checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
+
+ // With unsupported predicate
+ checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
+ checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10)
+
+ def checkBatchPruning(
+ filter: String,
+ expectedQueryResult: Seq[Int],
+ expectedReadPartitions: Int,
+ expectedReadBatches: Int) {
+
+ test(filter) {
+ val query = sql(s"SELECT * FROM intData WHERE $filter")
+ assertResult(expectedQueryResult.toArray, "Wrong query result") {
+ query.collect().map(_.head).toArray
+ }
+
+ val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+ case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
+ }.head
+
+ assert(readBatches === expectedReadBatches, "Wrong number of read batches")
+ assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
+ }
+ }
+}
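
A worked example of the expected counts above, under the layout the suite sets up (batchSize = 10, rows 1 to 100 spread evenly over 5 partitions):

    // Assumed data layout: partition 0 holds batches [1, 10] and [11, 20];
    // partition 1 holds [21, 30] and [31, 40]; and so on up to partition 4.
    // For "i < 12" the generated partition filter is `i.lowerBound < 12`, so only the two
    // batches of partition 0 pass (lower bounds 1 and 11); every other batch has a
    // lowerBound of at least 21 and is skipped, giving readPartitions = 1, readBatches = 2,
    // which is exactly checkBatchPruning("i < 12", 1 to 11, 1, 2).
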
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 5fba004809..e01cc8b4d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
import org.apache.spark.sql.Row
-import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats}
+import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
class BooleanBitSetSuite extends FunSuite {
@@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite {
// Tests encoder
// -------------
- val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet)
+ val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
val values = rows.map(_.head)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index d8ae2a2677..d2969d906c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite {
testDictionaryEncoding(new StringColumnStats, STRING)
def testDictionaryEncoding[T <: NativeType](
- columnStats: NativeColumnStats[T],
+ columnStats: ColumnStats,
columnType: NativeColumnType[T]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 17619dcf97..322f447c24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite {
testIntegralDelta(new LongColumnStats, LONG, LongDelta)
def testIntegralDelta[I <: IntegralType](
- columnStats: NativeColumnStats[I],
+ columnStats: ColumnStats,
columnType: NativeColumnType[I],
scheme: IntegralDelta[I]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 40115beb98..218c09ac26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
class RunLengthEncodingSuite extends FunSuite {
- testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+ testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
testRunLengthEncoding(new ByteColumnStats, BYTE)
testRunLengthEncoding(new ShortColumnStats, SHORT)
testRunLengthEncoding(new IntColumnStats, INT)
@@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite {
testRunLengthEncoding(new StringColumnStats, STRING)
def testRunLengthEncoding[T <: NativeType](
- columnStats: NativeColumnStats[T],
+ columnStats: ColumnStats,
columnType: NativeColumnType[T]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 72c19fa31d..7db723d648 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
class TestCompressibleColumnBuilder[T <: NativeType](
- override val columnStats: NativeColumnStats[T],
+ override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T],
override val schemes: Seq[CompressionScheme])
extends NativeColumnBuilder(columnStats, columnType)
@@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType](
object TestCompressibleColumnBuilder {
def apply[T <: NativeType](
- columnStats: NativeColumnStats[T],
+ columnStats: ColumnStats,
columnType: NativeColumnType[T],
scheme: CompressionScheme) = {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index b589994bd2..ab487d673e 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -35,26 +35,29 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalTimeZone = TimeZone.getDefault
private val originalLocale = Locale.getDefault
- private val originalUseCompression = TestHive.useCompression
+ private val originalColumnBatchSize = TestHive.columnBatchSize
+ private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning
def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
override def beforeAll() {
- // Enable in-memory columnar caching
TestHive.cacheTables = true
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
// Add Locale setting
Locale.setDefault(Locale.US)
- // Enable in-memory columnar compression
- TestHive.setConf(SQLConf.COMPRESS_CACHED, "true")
+ // Set a relatively small column batch size for testing purposes
+ TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5")
+ // Enable in-memory partition pruning for testing purposes
+ TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
}
override def afterAll() {
TestHive.cacheTables = false
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
- TestHive.setConf(SQLConf.COMPRESS_CACHED, originalUseCompression.toString)
+ TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+ TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}
/** A list of tests deemed out of scope currently and thus completely disregarded. */