path: root/sql/core
author     Cheng Lian <lian.cs.zju@gmail.com>         2014-09-03 18:59:26 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-09-03 18:59:26 -0700
commit     248067adbe90f93c7d5e23aa61b3072dfdf48a8a (patch)
tree       8662a77cbd0847d63d43dc69fae11c16dcc71388 /sql/core
parent     f48420fde58d554480cc8830d2f8c4d17618f283 (diff)
[SPARK-2961][SQL] Use statistics to prune batches within cached partitions
This PR is based on #1883 authored by marmbrus. Key differences:

1. Batch pruning instead of partition pruning

   When #1883 was authored, batched column buffer building (#1880) hadn't been introduced yet. This PR combines the two and provides partition batch level pruning, which leads to a smaller memory footprint and can generally skip more elements. The cost is that the pruning predicates are evaluated more frequently (the number of partitions multiplied by the number of batches per partition).

2. More filters are supported

   Filter predicates consisting of `=`, `<`, `<=`, `>`, `>=`, and their conjunctions and disjunctions, are supported.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2188 from liancheng/in-mem-batch-pruning and squashes the following commits:

68cf019 [Cheng Lian] Marked sqlContext as @transient
4254f6c [Cheng Lian] Enables in-memory partition pruning in PartitionBatchPruningSuite
3784105 [Cheng Lian] Overrides InMemoryColumnarTableScan.sqlContext
d2a1d66 [Cheng Lian] Disables in-memory partition pruning by default
062c315 [Cheng Lian] HiveCompatibilitySuite code cleanup
16b77bf [Cheng Lian] Fixed pruning predication conjunctions and disjunctions
16195c5 [Cheng Lian] Enabled both disjunction and conjunction
89950d0 [Cheng Lian] Worked around Scala style check
9c167f6 [Cheng Lian] Minor code cleanup
3c4d5c7 [Cheng Lian] Minor code cleanup
ea59ee5 [Cheng Lian] Renamed PartitionSkippingSuite to PartitionBatchPruningSuite
fc517d0 [Cheng Lian] More test cases
1868c18 [Cheng Lian] Code cleanup, bugfix, and adding tests
cb76da4 [Cheng Lian] Added more predicate filters, fixed table scan stats for testing purposes
385474a [Cheng Lian] Merge branch 'inMemStats' into in-mem-batch-pruning
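As a quick orientation before the diff, here is a minimal usage sketch of the feature, modelled directly on the new PartitionBatchPruningSuite further below (it assumes the implicits of org.apache.spark.sql.test.TestSQLContext are in scope; nothing here goes beyond what that suite itself does):

    import org.apache.spark.sql._
    import org.apache.spark.sql.test.TestSQLContext._

    case class IntegerData(i: Int)

    // Use small batches so that each partition holds several batches,
    // and enable the (default-off) pruning flag added by this patch.
    setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
    setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")

    sparkContext.makeRDD(1 to 100, 5).map(IntegerData).registerTempTable("intData")
    cacheTable("intData")

    // Only cached batches whose [lowerBound, upperBound] statistics can
    // possibly contain 1 are decompressed and scanned.
    sql("SELECT * FROM intData WHERE i = 1").collect()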
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala | 434
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala | 131
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala | 39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala | 4
14 files changed, 387 insertions, 352 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 64d49354da..4137ac7663 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -26,6 +26,7 @@ import java.util.Properties
private[spark] object SQLConf {
val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed"
val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize"
+ val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning"
val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold"
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
@@ -124,6 +125,12 @@ trait SQLConf {
private[spark] def isParquetBinaryAsString: Boolean =
getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean
+ /**
+ * When set to true, partition pruning for in-memory columnar tables is enabled.
+ */
+ private[spark] def inMemoryPartitionPruning: Boolean =
+ getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
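The new flag stays off unless set explicitly. A brief sketch of toggling it on an existing context (the key string is the one defined in the hunk above; the planner-side check shown is an assumption about how the accessor is consumed, illustrated concretely in the InMemoryColumnarTableScan diff further below):

    // Programmatic toggle of the new configuration key.
    sqlContext.setConf("spark.sql.inMemoryColumnarStorage.partitionPruning", "true")

    // Planner-side code reads the typed accessor added in this hunk.
    if (sqlContext.inMemoryPartitionPruning) {
      // evaluate pruning predicates against per-batch statistics
    }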
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 247337a875..b3ec5ded22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder {
/**
* Column statistics information
*/
- def columnStats: ColumnStats[_, _]
+ def columnStats: ColumnStats
/**
* Returns the final columnar byte buffer.
@@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder {
}
private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
- val columnStats: ColumnStats[T, JvmType],
+ val columnStats: ColumnStats,
val columnType: ColumnType[T, JvmType])
extends ColumnBuilder {
@@ -81,18 +81,18 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType])
- extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
with NullableColumnBuilder
private[sql] abstract class NativeColumnBuilder[T <: NativeType](
- override val columnStats: NativeColumnStats[T],
+ override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T])
extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
with NullableColumnBuilder
with AllCompressionSchemes
with CompressibleColumnBuilder[T]
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN)
private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 6502110e90..fc343ccb99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -17,381 +17,193 @@
package org.apache.spark.sql.columnar
+import java.sql.Timestamp
+
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.types._
+private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
+ val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)()
+ val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)()
+ val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+
+ val schema = Seq(lowerBound, upperBound, nullCount)
+}
+
+private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
+ val (forAttribute, schema) = {
+ val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a))
+ (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _))
+ }
+}
+
/**
* Used to collect statistical information when building in-memory columns.
*
* NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]`
* brings significant performance penalty.
*/
-private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
- /**
- * Closed lower bound of this column.
- */
- def lowerBound: JvmType
-
- /**
- * Closed upper bound of this column.
- */
- def upperBound: JvmType
-
+private[sql] sealed trait ColumnStats extends Serializable {
/**
* Gathers statistics information from `row(ordinal)`.
*/
- def gatherStats(row: Row, ordinal: Int)
-
- /**
- * Returns `true` if `lower <= row(ordinal) <= upper`.
- */
- def contains(row: Row, ordinal: Int): Boolean
+ def gatherStats(row: Row, ordinal: Int): Unit
/**
- * Returns `true` if `row(ordinal) < upper` holds.
+ * Column statistics represented as a single row, currently including closed lower bound, closed
+ * upper bound and null count.
*/
- def isAbove(row: Row, ordinal: Int): Boolean
-
- /**
- * Returns `true` if `lower < row(ordinal)` holds.
- */
- def isBelow(row: Row, ordinal: Int): Boolean
-
- /**
- * Returns `true` if `row(ordinal) <= upper` holds.
- */
- def isAtOrAbove(row: Row, ordinal: Int): Boolean
-
- /**
- * Returns `true` if `lower <= row(ordinal)` holds.
- */
- def isAtOrBelow(row: Row, ordinal: Int): Boolean
-}
-
-private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
- extends ColumnStats[T, T#JvmType] {
-
- type JvmType = T#JvmType
-
- protected var (_lower, _upper) = initialBounds
-
- def initialBounds: (JvmType, JvmType)
-
- protected def columnType: NativeColumnType[T]
-
- override def lowerBound: T#JvmType = _lower
-
- override def upperBound: T#JvmType = _upper
-
- override def isAtOrAbove(row: Row, ordinal: Int) = {
- contains(row, ordinal) || isAbove(row, ordinal)
- }
-
- override def isAtOrBelow(row: Row, ordinal: Int) = {
- contains(row, ordinal) || isBelow(row, ordinal)
- }
+ def collectedStatistics: Row
}
-private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
- override def isAtOrBelow(row: Row, ordinal: Int) = true
-
- override def isAtOrAbove(row: Row, ordinal: Int) = true
-
- override def isBelow(row: Row, ordinal: Int) = true
-
- override def isAbove(row: Row, ordinal: Int) = true
+private[sql] class NoopColumnStats extends ColumnStats {
- override def contains(row: Row, ordinal: Int) = true
+ override def gatherStats(row: Row, ordinal: Int): Unit = {}
- override def gatherStats(row: Row, ordinal: Int) {}
-
- override def upperBound = null.asInstanceOf[JvmType]
-
- override def lowerBound = null.asInstanceOf[JvmType]
+ override def collectedStatistics = Row()
}
-private[sql] abstract class BasicColumnStats[T <: NativeType](
- protected val columnType: NativeColumnType[T])
- extends NativeColumnStats[T]
-
-private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
- override def initialBounds = (true, false)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class ByteColumnStats extends ColumnStats {
+ var upper = Byte.MinValue
+ var lower = Byte.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
-}
-
-private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
- override def initialBounds = (Byte.MaxValue, Byte.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getByte(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
- override def initialBounds = (Short.MaxValue, Short.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class ShortColumnStats extends ColumnStats {
+ var upper = Short.MinValue
+ var lower = Short.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
-}
-
-private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
- override def initialBounds = (Long.MaxValue, Long.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getShort(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
- override def initialBounds = (Double.MaxValue, Double.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
-
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class LongColumnStats extends ColumnStats {
+ var upper = Long.MinValue
+ var lower = Long.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
- }
-}
-
-private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
- override def initialBounds = (Float.MaxValue, Float.MinValue)
-
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getLong(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
+}
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class DoubleColumnStats extends ColumnStats {
+ var upper = Double.MinValue
+ var lower = Double.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getDouble(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
-}
-private[sql] object IntColumnStats {
- val UNINITIALIZED = 0
- val INITIALIZED = 1
- val ASCENDING = 2
- val DESCENDING = 3
- val UNORDERED = 4
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-/**
- * Statistical information for `Int` columns. More information is collected since `Int` is
- * frequently used. Extra information include:
- *
- * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
- * is applicable when searching elements.
- * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
- * scheme.
- *
- * (This two kinds of information are not used anywhere yet and might be removed later.)
- */
-private[sql] class IntColumnStats extends BasicColumnStats(INT) {
- import IntColumnStats._
-
- private var orderedState = UNINITIALIZED
- private var lastValue: Int = _
- private var _maxDelta: Int = _
-
- def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
- def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
- def isOrdered = isAscending || isDescending
- def maxDelta = _maxDelta
-
- override def initialBounds = (Int.MaxValue, Int.MinValue)
+private[sql] class FloatColumnStats extends ColumnStats {
+ var upper = Float.MinValue
+ var lower = Float.MaxValue
+ var nullCount = 0
- override def isBelow(row: Row, ordinal: Int) = {
- lowerBound < columnType.getField(row, ordinal)
+ override def gatherStats(row: Row, ordinal: Int) {
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getFloat(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
+ }
}
- override def isAbove(row: Row, ordinal: Int) = {
- columnType.getField(row, ordinal) < upperBound
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
+}
- override def contains(row: Row, ordinal: Int) = {
- val field = columnType.getField(row, ordinal)
- lowerBound <= field && field <= upperBound
- }
+private[sql] class IntColumnStats extends ColumnStats {
+ var upper = Int.MinValue
+ var lower = Int.MaxValue
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
-
- if (field > upperBound) _upper = field
- if (field < lowerBound) _lower = field
-
- orderedState = orderedState match {
- case UNINITIALIZED =>
- lastValue = field
- INITIALIZED
-
- case INITIALIZED =>
- // If all the integers in the column are the same, ordered state is set to Ascending.
- // TODO (lian) Confirm whether this is the standard behaviour.
- val nextState = if (field >= lastValue) ASCENDING else DESCENDING
- _maxDelta = math.abs(field - lastValue)
- lastValue = field
- nextState
-
- case ASCENDING if field < lastValue =>
- UNORDERED
-
- case DESCENDING if field > lastValue =>
- UNORDERED
-
- case state @ (ASCENDING | DESCENDING) =>
- _maxDelta = _maxDelta.max(field - lastValue)
- lastValue = field
- state
-
- case _ =>
- orderedState
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getInt(ordinal)
+ if (value > upper) upper = value
+ if (value < lower) lower = value
+ } else {
+ nullCount += 1
}
}
+
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
- override def initialBounds = (null, null)
+private[sql] class StringColumnStats extends ColumnStats {
+ var upper: String = null
+ var lower: String = null
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
- if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
- }
-
- override def contains(row: Row, ordinal: Int) = {
- (upperBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
- }
- }
-
- override def isAbove(row: Row, ordinal: Int) = {
- (upperBound ne null) && {
- val field = columnType.getField(row, ordinal)
- field.compareTo(upperBound) < 0
+ if (!row.isNullAt(ordinal)) {
+ val value = row.getString(ordinal)
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ } else {
+ nullCount += 1
}
}
- override def isBelow(row: Row, ordinal: Int) = {
- (lowerBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) < 0
- }
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
-private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) {
- override def initialBounds = (null, null)
+private[sql] class TimestampColumnStats extends ColumnStats {
+ var upper: Timestamp = null
+ var lower: Timestamp = null
+ var nullCount = 0
override def gatherStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
- if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
- }
-
- override def contains(row: Row, ordinal: Int) = {
- (upperBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+ if (!row.isNullAt(ordinal)) {
+ val value = row(ordinal).asInstanceOf[Timestamp]
+ if (upper == null || value.compareTo(upper) > 0) upper = value
+ if (lower == null || value.compareTo(lower) < 0) lower = value
+ } else {
+ nullCount += 1
}
}
- override def isAbove(row: Row, ordinal: Int) = {
- (lowerBound ne null) && {
- val field = columnType.getField(row, ordinal)
- field.compareTo(upperBound) < 0
- }
- }
-
- override def isBelow(row: Row, ordinal: Int) = {
- (lowerBound ne null) && {
- val field = columnType.getField(row, ordinal)
- lowerBound.compareTo(field) < 0
- }
- }
+ def collectedStatistics = Row(lower, upper, nullCount)
}
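All of the concrete stats classes above follow the same accumulation pattern: track a closed lower bound, a closed upper bound, and a null count, then expose them as a single row. A standalone, Spark-free sketch of that pattern (the class name is illustrative, not part of the patch):

    // Hypothetical, simplified analogue of IntColumnStats: min/max/nullCount
    // accumulated value by value, then reported in one shot (the patch uses Row).
    class SimpleIntStats {
      private var upper = Int.MinValue
      private var lower = Int.MaxValue
      private var nullCount = 0

      def gatherStats(value: Option[Int]): Unit = value match {
        case Some(v) =>
          if (v > upper) upper = v
          if (v < lower) lower = v
        case None =>
          nullCount += 1
      }

      // Mirrors collectedStatistics = Row(lower, upper, nullCount)
      def collectedStatistics: (Int, Int, Int) = (lower, upper, nullCount)
    }

    val stats = new SimpleIntStats
    Seq(Some(3), None, Some(42), Some(-7)).foreach(stats.gatherStats)
    val expectedStats = (-7, 42, 1)
    assert(stats.collectedStatistics == expectedStats)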
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index cb055cd74a..dc668e7dc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
@@ -31,23 +33,27 @@ object InMemoryRelation {
new InMemoryRelation(child.output, useCompression, batchSize, child)()
}
+private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
+
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
useCompression: Boolean,
batchSize: Int,
child: SparkPlan)
- (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null)
+ (private var _cachedColumnBuffers: RDD[CachedBatch] = null)
extends LogicalPlan with MultiInstanceRelation {
override lazy val statistics =
Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+ val partitionStatistics = new PartitionStatistics(output)
+
// If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
val output = child.output
val cached = child.execute().mapPartitions { baseIterator =>
- new Iterator[Array[ByteBuffer]] {
+ new Iterator[CachedBatch] {
def next() = {
val columnBuilders = output.map { attribute =>
val columnType = ColumnType(attribute.dataType)
@@ -68,7 +74,10 @@ private[sql] case class InMemoryRelation(
rowCount += 1
}
- columnBuilders.map(_.build())
+ val stats = Row.fromSeq(
+ columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+
+ CachedBatch(columnBuilders.map(_.build()), stats)
}
def hasNext = baseIterator.hasNext
@@ -79,7 +88,6 @@ private[sql] case class InMemoryRelation(
_cachedColumnBuffers = cached
}
-
override def children = Seq.empty
override def newInstance() = {
@@ -96,13 +104,98 @@ private[sql] case class InMemoryRelation(
private[sql] case class InMemoryColumnarTableScan(
attributes: Seq[Attribute],
+ predicates: Seq[Expression],
relation: InMemoryRelation)
extends LeafNode {
+ @transient override val sqlContext = relation.child.sqlContext
+
override def output: Seq[Attribute] = attributes
+ // Returned filter predicate should return false iff it is impossible for the input expression
+ // to evaluate to `true' based on statistics collected about this partition batch.
+ val buildFilter: PartialFunction[Expression, Expression] = {
+ case And(lhs: Expression, rhs: Expression)
+ if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+ buildFilter(lhs) && buildFilter(rhs)
+
+ case Or(lhs: Expression, rhs: Expression)
+ if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
+ buildFilter(lhs) || buildFilter(rhs)
+
+ case EqualTo(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l && l <= aStats.upperBound
+
+ case EqualTo(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l && l <= aStats.upperBound
+
+ case LessThan(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound < l
+
+ case LessThan(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l < aStats.upperBound
+
+ case LessThanOrEqual(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l
+
+ case LessThanOrEqual(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l <= aStats.upperBound
+
+ case GreaterThan(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l < aStats.upperBound
+
+ case GreaterThan(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound < l
+
+ case GreaterThanOrEqual(a: AttributeReference, l: Literal) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ l <= aStats.upperBound
+
+ case GreaterThanOrEqual(l: Literal, a: AttributeReference) =>
+ val aStats = relation.partitionStatistics.forAttribute(a)
+ aStats.lowerBound <= l
+ }
+
+ val partitionFilters = {
+ predicates.flatMap { p =>
+ val filter = buildFilter.lift(p)
+ val boundFilter =
+ filter.map(
+ BindReferences.bindReference(
+ _,
+ relation.partitionStatistics.schema,
+ allowFailures = true))
+
+ boundFilter.foreach(_ =>
+ filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
+
+ // If the filter can't be resolved then we are missing required statistics.
+ boundFilter.filter(_.resolved)
+ }
+ }
+
+ val readPartitions = sparkContext.accumulator(0)
+ val readBatches = sparkContext.accumulator(0)
+
+ private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning
+
override def execute() = {
+ readPartitions.setValue(0)
+ readBatches.setValue(0)
+
relation.cachedColumnBuffers.mapPartitions { iterator =>
+ val partitionFilter = newPredicate(
+ partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+ relation.partitionStatistics.schema)
+
// Find the ordinals of the requested columns. If none are requested, use the first.
val requestedColumns = if (attributes.isEmpty) {
Seq(0)
@@ -110,8 +203,26 @@ private[sql] case class InMemoryColumnarTableScan(
attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
}
- iterator
- .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_)))
+ val rows = iterator
+ // Skip pruned batches
+ .filter { cachedBatch =>
+ if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
+ def statsString = relation.partitionStatistics.schema
+ .zip(cachedBatch.stats)
+ .map { case (a, s) => s"${a.name}: $s" }
+ .mkString(", ")
+ logInfo(s"Skipping partition based on stats $statsString")
+ false
+ } else {
+ readBatches += 1
+ true
+ }
+ }
+ // Build column accessors
+ .map { cachedBatch =>
+ requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+ }
+ // Extract rows via column accessors
.flatMap { columnAccessors =>
val nextRow = new GenericMutableRow(columnAccessors.length)
new Iterator[Row] {
@@ -127,6 +238,12 @@ private[sql] case class InMemoryColumnarTableScan(
override def hasNext = columnAccessors.head.hasNext
}
}
+
+ if (rows.hasNext) {
+ readPartitions += 1
+ }
+
+ rows
}
}
}
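The contract of buildFilter above is conservative: a batch may only be skipped when the statistics prove the predicate cannot hold for any value between the recorded bounds. A standalone sketch of that rule for the supported comparisons and their conjunctions and disjunctions (the types and names here are illustrative, not Catalyst's):

    // Hypothetical model of the pruning rule: return true when a batch MIGHT
    // contain matching rows, false only when its [lower, upper] bounds rule it out.
    case class BatchStats(lower: Int, upper: Int)

    sealed trait Pred
    case class Eq(v: Int) extends Pred   // column = v
    case class Lt(v: Int) extends Pred   // column < v
    case class Gt(v: Int) extends Pred   // column > v
    case class AndPred(l: Pred, r: Pred) extends Pred
    case class OrPred(l: Pred, r: Pred) extends Pred

    def mightMatch(p: Pred, s: BatchStats): Boolean = p match {
      case Eq(v)         => s.lower <= v && v <= s.upper
      case Lt(v)         => s.lower < v
      case Gt(v)         => v < s.upper
      case AndPred(l, r) => mightMatch(l, s) && mightMatch(r, s)
      case OrPred(l, r)  => mightMatch(l, s) || mightMatch(r, s)
    }

    // A batch covering 11..20 cannot satisfy `i = 1`, so it is pruned ...
    assert(!mightMatch(Eq(1), BatchStats(11, 20)))
    // ... but it might satisfy `i > 8 AND i < 15`, so it must still be scanned.
    assert(mightMatch(AndPred(Gt(8), Lt(15)), BatchStats(11, 20)))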
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index f631ee76fc..a72970eef7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
}
abstract override def appendFrom(row: Row, ordinal: Int) {
+ columnStats.gatherStats(row, ordinal)
if (row.isNullAt(ordinal)) {
nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
nulls.putInt(pos)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 8dacb84c8a..7943d6e1b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -243,8 +243,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
pruneFilterProject(
projectList,
filters,
- identity[Seq[Expression]], // No filters are pushed down.
- InMemoryColumnarTableScan(_, mem)) :: Nil
+ identity[Seq[Expression]], // All filters still need to be evaluated.
+ InMemoryColumnarTableScan(_, filters, mem)) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 5f61fb5e16..cde91ceb68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.types._
class ColumnStatsSuite extends FunSuite {
- testColumnStats(classOf[BooleanColumnStats], BOOLEAN)
- testColumnStats(classOf[ByteColumnStats], BYTE)
- testColumnStats(classOf[ShortColumnStats], SHORT)
- testColumnStats(classOf[IntColumnStats], INT)
- testColumnStats(classOf[LongColumnStats], LONG)
- testColumnStats(classOf[FloatColumnStats], FLOAT)
- testColumnStats(classOf[DoubleColumnStats], DOUBLE)
- testColumnStats(classOf[StringColumnStats], STRING)
- testColumnStats(classOf[TimestampColumnStats], TIMESTAMP)
-
- def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+ testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0))
+ testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0))
+ testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0))
+ testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
+ testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+ testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
+ testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
+
+ def testColumnStats[T <: NativeType, U <: ColumnStats](
columnStatsClass: Class[U],
- columnType: NativeColumnType[T]) {
+ columnType: NativeColumnType[T],
+ initialStatistics: Row) {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
- assertResult(columnStats.initialBounds, "Wrong initial bounds") {
- (columnStats.lowerBound, columnStats.upperBound)
+ columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) =>
+ assert(actual === expected)
}
}
@@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite {
import ColumnarTestUtils._
val columnStats = columnStatsClass.newInstance()
- val rows = Seq.fill(10)(makeRandomRow(columnType))
+ val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
- val values = rows.map(_.head.asInstanceOf[T#JvmType])
+ val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+ val stats = columnStats.collectedStatistics
- assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
- assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+ assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+ assertResult(10, "Wrong null count")(stats(2))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index dc813fe146..a77262534a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer
class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
- extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType)
with NullableColumnBuilder
object TestNullableColumnBuilder {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
new file mode 100644
index 0000000000..5d2fd49591
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class IntegerData(i: Int)
+
+class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
+ val originalColumnBatchSize = columnBatchSize
+ val originalInMemoryPartitionPruning = inMemoryPartitionPruning
+
+ override protected def beforeAll() {
+ // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
+ setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
+ val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
+ rawData.registerTempTable("intData")
+
+ // Enable in-memory partition pruning
+ setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ }
+
+ override protected def afterAll() {
+ setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
+ setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
+ }
+
+ before {
+ cacheTable("intData")
+ }
+
+ after {
+ uncacheTable("intData")
+ }
+
+ // Comparisons
+ checkBatchPruning("i = 1", Seq(1), 1, 1)
+ checkBatchPruning("1 = i", Seq(1), 1, 1)
+ checkBatchPruning("i < 12", 1 to 11, 1, 2)
+ checkBatchPruning("i <= 11", 1 to 11, 1, 2)
+ checkBatchPruning("i > 88", 89 to 100, 1, 2)
+ checkBatchPruning("i >= 89", 89 to 100, 1, 2)
+ checkBatchPruning("12 > i", 1 to 11, 1, 2)
+ checkBatchPruning("11 >= i", 1 to 11, 1, 2)
+ checkBatchPruning("88 < i", 89 to 100, 1, 2)
+ checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+
+ // Conjunction and disjunction
+ checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
+ checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
+ checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
+
+ // With unsupported predicate
+ checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
+ checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10)
+
+ def checkBatchPruning(
+ filter: String,
+ expectedQueryResult: Seq[Int],
+ expectedReadPartitions: Int,
+ expectedReadBatches: Int) {
+
+ test(filter) {
+ val query = sql(s"SELECT * FROM intData WHERE $filter")
+ assertResult(expectedQueryResult.toArray, "Wrong query result") {
+ query.collect().map(_.head).toArray
+ }
+
+ val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+ case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
+ }.head
+
+ assert(readBatches === expectedReadBatches, "Wrong number of read batches")
+ assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions")
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index 5fba004809..e01cc8b4d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
import org.apache.spark.sql.Row
-import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats}
+import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
class BooleanBitSetSuite extends FunSuite {
@@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite {
// Tests encoder
// -------------
- val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet)
+ val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet)
val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN))
val values = rows.map(_.head)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index d8ae2a2677..d2969d906c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite {
testDictionaryEncoding(new StringColumnStats, STRING)
def testDictionaryEncoding[T <: NativeType](
- columnStats: NativeColumnStats[T],
+ columnStats: ColumnStats,
columnType: NativeColumnType[T]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 17619dcf97..322f447c24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite {
testIntegralDelta(new LongColumnStats, LONG, LongDelta)
def testIntegralDelta[I <: IntegralType](
- columnStats: NativeColumnStats[I],
+ columnStats: ColumnStats,
columnType: NativeColumnType[I],
scheme: IntegralDelta[I]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 40115beb98..218c09ac26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
class RunLengthEncodingSuite extends FunSuite {
- testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+ testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
testRunLengthEncoding(new ByteColumnStats, BYTE)
testRunLengthEncoding(new ShortColumnStats, SHORT)
testRunLengthEncoding(new IntColumnStats, INT)
@@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite {
testRunLengthEncoding(new StringColumnStats, STRING)
def testRunLengthEncoding[T <: NativeType](
- columnStats: NativeColumnStats[T],
+ columnStats: ColumnStats,
columnType: NativeColumnType[T]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index 72c19fa31d..7db723d648 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
class TestCompressibleColumnBuilder[T <: NativeType](
- override val columnStats: NativeColumnStats[T],
+ override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T],
override val schemes: Seq[CompressionScheme])
extends NativeColumnBuilder(columnStats, columnType)
@@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType](
object TestCompressibleColumnBuilder {
def apply[T <: NativeType](
- columnStats: NativeColumnStats[T],
+ columnStats: ColumnStats,
columnType: NativeColumnType[T],
scheme: CompressionScheme) = {