From 2838bf8aadd5228829c1a869863bc4da7877fdfb Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sun, 26 Oct 2014 16:10:09 -0700
Subject: [SPARK-3537][SPARK-3914][SQL] Refines in-memory columnar table statistics

This PR refines in-memory columnar table statistics:

1. adds 2 more statistics for in-memory table columns: `count` and `sizeInBytes`
1. adds filter pushdown support for `IS NULL` and `IS NOT NULL`.
1. caches and propagates statistics in `InMemoryRelation` once the underlying cached RDD is materialized.

Statistics are collected on the driver side with an accumulator.

This PR also fixes SPARK-3914 by properly propagating in-memory statistics.

Author: Cheng Lian

Closes #2860 from liancheng/propagates-in-mem-stats and squashes the following commits:

0cc5271 [Cheng Lian] Restricts visibility of o.a.s.s.c.p.l.Statistics
c5ff904 [Cheng Lian] Fixes test table name conflict
a8c818d [Cheng Lian] Refines tests
1d01074 [Cheng Lian] Bug fix: shouldn't call STRING.actualSize on null string value
7dc6a34 [Cheng Lian] Adds more in-memory table statistics and propagates them properly
---
 .../sql/catalyst/expressions/AttributeMap.scala    |  10 +-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  31 +++---
 .../apache/spark/sql/columnar/ColumnStats.scala    | 122 +++++++++++----
 .../sql/columnar/InMemoryColumnarTableScan.scala   | 101 ++++++++++-------
 .../apache/spark/sql/execution/ExistingRDD.scala   |  11 +-
 .../apache/spark/sql/parquet/ParquetRelation.scala |   3 +-
 .../org/apache/spark/sql/CachedTableSuite.scala    |  11 +-
 .../test/scala/org/apache/spark/sql/TestData.scala |  16 +-
 .../spark/sql/columnar/ColumnStatsSuite.scala      |   6 +
 .../sql/columnar/PartitionBatchPruningSuite.scala  |  76 ++++++-----
 .../apache/spark/sql/execution/PlannerSuite.scala  |  20 ++++
 11 files changed, 240 insertions(+), 167 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 8364379644..82e760b6c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,8 +23,7 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
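 * For example, a lookup still succeeds when the querying attribute's nullability differs
 * from that of the key it was stored under, because entries are keyed by `exprId` alone.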
 */
object AttributeMap {
-  def apply[A](kvs: Seq[(Attribute, A)]) =
-    new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+  def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
@@ -32,10 +31,9 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
 
-  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
-    (baseMap.map(_._2) + kv).toMap
+  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
 
-  override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+  override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
 
-  override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+  override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 882e9c6110..ed578e081b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -26,25 +26,24 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.catalyst.trees
 
+/**
+ * Estimates of various statistics. The default estimation logic simply lazily multiplies the
+ * corresponding statistic produced by the children. To override this behavior, override
+ * `statistics` and assign it an overridden version of `Statistics`.
+ *
+ * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
+ * performance of the implementations. The reason is that estimations might get triggered in
+ * performance-critical processes, such as query planning.
+ *
+ * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
+ *                    defaults to the product of children's `sizeInBytes`.
+ */
+private[sql] case class Statistics(sizeInBytes: BigInt)
+
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   self: Product =>
 
-  /**
-   * Estimates of various statistics. The default estimation logic simply lazily multiplies the
-   * corresponding statistic produced by the children. To override this behavior, override
-   * `statistics` and assign it an overridden version of `Statistics`.
-   *
-   * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
-   * performance of the implementations. The reason is that estimations might get triggered in
-   * performance-critical processes, such as query planning.
-   *
-   * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
-   *                    defaults to the product of children's `sizeInBytes`.
- */ - case class Statistics( - sizeInBytes: BigInt - ) - lazy val statistics: Statistics = { + def statistics: Statistics = { if (children.size == 0) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b34ab255d0..b9f9f82700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, Attri import org.apache.spark.sql.catalyst.types._ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { - val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)() - val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)() - val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)() + val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)() + val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + val count = AttributeReference(a.name + ".count", IntegerType, nullable = false)() + val sizeInBytes = AttributeReference(a.name + ".sizeInBytes", LongType, nullable = false)() - val schema = Seq(lowerBound, upperBound, nullCount) + val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes) } private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { @@ -45,10 +47,21 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri * brings significant performance penalty. */ private[sql] sealed trait ColumnStats extends Serializable { + protected var count = 0 + protected var nullCount = 0 + protected var sizeInBytes = 0L + /** * Gathers statistics information from `row(ordinal)`. 
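   * Every call increments `count`; for a null value, `nullCount` is incremented and 4 bytes
   * (the recorded null position) are added to `sizeInBytes`, while non-null values have their
   * type-specific sizes added by the concrete subclasses.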
*/ - def gatherStats(row: Row, ordinal: Int): Unit + def gatherStats(row: Row, ordinal: Int): Unit = { + if (row.isNullAt(ordinal)) { + nullCount += 1 + // 4 bytes for null position + sizeInBytes += 4 + } + count += 1 + } /** * Column statistics represented as a single row, currently including closed lower bound, closed @@ -65,163 +78,154 @@ private[sql] class NoopColumnStats extends ColumnStats { } private[sql] class ByteColumnStats extends ColumnStats { - var upper = Byte.MinValue - var lower = Byte.MaxValue - var nullCount = 0 + protected var upper = Byte.MinValue + protected var lower = Byte.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += BYTE.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class ShortColumnStats extends ColumnStats { - var upper = Short.MinValue - var lower = Short.MaxValue - var nullCount = 0 + protected var upper = Short.MinValue + protected var lower = Short.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += SHORT.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class LongColumnStats extends ColumnStats { - var upper = Long.MinValue - var lower = Long.MaxValue - var nullCount = 0 + protected var upper = Long.MinValue + protected var lower = Long.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += LONG.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DoubleColumnStats extends ColumnStats { - var upper = Double.MinValue - var lower = Double.MaxValue - var nullCount = 0 + protected var upper = Double.MinValue + protected var lower = Double.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += DOUBLE.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class FloatColumnStats extends ColumnStats { - var upper = Float.MinValue - var lower = Float.MaxValue - var nullCount = 0 + protected var upper = Float.MinValue + protected var lower = Float.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += FLOAT.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = 
Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class IntColumnStats extends ColumnStats { - var upper = Int.MinValue - var lower = Int.MaxValue - var nullCount = 0 + protected var upper = Int.MinValue + protected var lower = Int.MaxValue override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - } else { - nullCount += 1 + sizeInBytes += INT.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class StringColumnStats extends ColumnStats { - var upper: String = null - var lower: String = null - var nullCount = 0 + protected var upper: String = null + protected var lower: String = null override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getString(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - } else { - nullCount += 1 + sizeInBytes += STRING.actualSize(row, ordinal) } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class DateColumnStats extends ColumnStats { - var upper: Date = null - var lower: Date = null - var nullCount = 0 + protected var upper: Date = null + protected var lower: Date = null override def gatherStats(row: Row, ordinal: Int) { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Date] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - } else { - nullCount += 1 + sizeInBytes += DATE.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } private[sql] class TimestampColumnStats extends ColumnStats { - var upper: Timestamp = null - var lower: Timestamp = null - var nullCount = 0 + protected var upper: Timestamp = null + protected var lower: Timestamp = null override def gatherStats(row: Row, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Timestamp] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - } else { - nullCount += 1 + sizeInBytes += TIMESTAMP.defaultSize } } - def collectedStatistics = Row(lower, upper, nullCount) + def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 22ab0e2613..ee63134f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.storage.StorageLevel @@ -45,15 +47,51 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan) - (private var _cachedColumnBuffers: RDD[CachedBatch] = null) + child: SparkPlan)( + private var _cachedColumnBuffers: RDD[CachedBatch] = null, + private var _statistics: Statistics = null) extends LogicalPlan with MultiInstanceRelation { - override lazy val statistics = - Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + private val batchStats = + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) val partitionStatistics = new PartitionStatistics(output) + private def computeSizeInBytes = { + val sizeOfRow: Expression = + BindReferences.bindReference( + output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add), + partitionStatistics.schema) + + batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum + } + + // Statistics propagation contracts: + // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data + // 2. Only propagate statistics when `_statistics` is non-null + private def statisticsToBePropagated = if (_statistics == null) { + val updatedStats = statistics + if (_statistics == null) null else updatedStats + } else { + _statistics + } + + override def statistics = if (_statistics == null) { + if (batchStats.value.isEmpty) { + // Underlying columnar RDD hasn't been materialized, no useful statistics information + // available, return the default statistics. + Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + } else { + // Underlying columnar RDD has been materialized, required information has also been collected + // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`. + _statistics = Statistics(sizeInBytes = computeSizeInBytes) + _statistics + } + } else { + // Pre-computed statistics + _statistics + } + // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. 
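  // Each batch built below contributes its per-column statistics row to the `batchStats`
  // accumulator, which is what lets `computeSizeInBytes` above derive exact statistics on
  // the driver once the cached RDD has been materialized.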
if (_cachedColumnBuffers == null) { @@ -91,6 +129,7 @@ private[sql] case class InMemoryRelation( val stats = Row.fromSeq( columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) } @@ -104,7 +143,8 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers) + newOutput, useCompression, batchSize, storageLevel, child)( + _cachedColumnBuffers, statisticsToBePropagated) } override def children = Seq.empty @@ -116,7 +156,8 @@ private[sql] case class InMemoryRelation( batchSize, storageLevel, child)( - _cachedColumnBuffers).asInstanceOf[this.type] + _cachedColumnBuffers, + statisticsToBePropagated).asInstanceOf[this.type] } def cachedColumnBuffers = _cachedColumnBuffers @@ -132,6 +173,8 @@ private[sql] case class InMemoryColumnarTableScan( override def output: Seq[Attribute] = attributes + private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) + // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. val buildFilter: PartialFunction[Expression, Expression] = { @@ -144,44 +187,24 @@ private[sql] case class InMemoryColumnarTableScan( buildFilter(lhs) || buildFilter(rhs) case EqualTo(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l && l <= aStats.upperBound - + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound case EqualTo(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l && l <= aStats.upperBound - - case LessThan(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound < l - - case LessThan(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - l < aStats.upperBound - - case LessThanOrEqual(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l + statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound - case LessThanOrEqual(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - l <= aStats.upperBound + case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l + case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound - case GreaterThan(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - l < aStats.upperBound + case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l + case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound - case GreaterThan(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound < l + case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound + case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l - case GreaterThanOrEqual(a: AttributeReference, l: Literal) => - val aStats = relation.partitionStatistics.forAttribute(a) - l <= aStats.upperBound + case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound + case 
GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l - case GreaterThanOrEqual(l: Literal, a: AttributeReference) => - val aStats = relation.partitionStatistics.forAttribute(a) - aStats.lowerBound <= l + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } val partitionFilters = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 2ddf513b6f..04c51a1ee4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -17,16 +17,13 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -import scala.reflect.runtime.universe.TypeTag - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -100,7 +97,7 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ override final def newInstance(): this.type = { SparkLogicalPlan( alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd) case _ => sys.error("Multiple instance of the same relation detected.") })(sqlContext).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 5ae768293a..82130b5459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -22,7 +22,6 @@ import java.io.IOException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction - import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType @@ -30,7 +29,7 @@ import parquet.schema.MessageType import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} /** * Relation that consists of data stored in a Parquet columnar format. 
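The `buildFilter` cases added to `InMemoryColumnarTableScan` above all implement one contract: given per-batch column statistics, return a predicate that evaluates to false only when no row in the batch can possibly satisfy the original filter. Below is a minimal, standalone sketch of that idea in plain Scala; the types and names are illustrative only, not Spark APIs:

case class BatchStats(lower: Int, upper: Int, nullCount: Int, count: Int)

sealed trait Pred
case class EqualTo(v: Int) extends Pred
case class LessThan(v: Int) extends Pred
case class GreaterThan(v: Int) extends Pred
case object IsNull extends Pred
case object IsNotNull extends Pred

object BatchPruningSketch {
  // Returns false iff the predicate cannot hold for any row in the batch.
  def mightMatch(s: BatchStats, p: Pred): Boolean = p match {
    case EqualTo(v)     => s.lower <= v && v <= s.upper
    case LessThan(v)    => s.lower < v
    case GreaterThan(v) => v < s.upper
    case IsNull         => s.nullCount > 0
    case IsNotNull      => s.count - s.nullCount > 0
  }

  def main(args: Array[String]): Unit = {
    val batch = BatchStats(lower = 11, upper = 20, nullCount = 0, count = 10)
    println(mightMatch(batch, EqualTo(1)))    // false: batch can be skipped
    println(mightMatch(batch, IsNull))        // false: batch holds no nulls
    println(mightMatch(batch, LessThan(12)))  // true: batch must be scanned
  }
}

The last two cases mirror the new pushdowns in this PR: `IS NULL` can only match a batch whose `nullCount` is positive, and `IS NOT NULL` only one that contains at least one non-null value.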
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index da5a358df3..1a5d87d524 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -234,4 +234,13 @@ class CachedTableSuite extends QueryTest { uncacheTable("testData") assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } + + test("InMemoryRelation statistics") { + sql("CACHE TABLE testData") + table("testData").queryExecution.withCachedData.collect { + case cached: InMemoryRelation => + val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum + assert(cached.statistics.sizeInBytes === actualSizeInBytes) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 10b7979df7..1c21afc17e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -28,40 +28,40 @@ import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) object TestData { - val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))) + val testData = TestSQLContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD testData.registerTempTable("testData") case class LargeAndSmallInts(a: Int, b: Int) - val largeAndSmallInts: SchemaRDD = + val largeAndSmallInts = TestSQLContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil) + LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) - val testData2: SchemaRDD = + val testData2 = TestSQLContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil) + TestData2(3, 2) :: Nil).toSchemaRDD testData2.registerTempTable("testData2") case class BinaryData(a: Array[Byte], b: Int) - val binaryData: SchemaRDD = + val binaryData = TestSQLContext.sparkContext.parallelize( BinaryData("12".getBytes(), 1) :: BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil) + BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD binaryData.registerTempTable("binaryData") // TODO: There is no way to express null primitives as case classes currently... 
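The expected value in the `InMemoryRelation statistics` test of `CachedTableSuite` above is straight per-row accounting: each of the 100 cached rows contributes `INT.defaultSize` bytes for the key, plus the value string's length, plus 4 bytes of per-string overhead. A quick standalone sketch of that arithmetic; treat the exact constants as assumptions of this illustration:

object ExpectedCacheSize {
  def main(args: Array[String]): Unit = {
    val intDefaultSize = 4  // stand-in for INT.defaultSize
    val stringOverhead = 4  // assumed per-string length-header bytes, as in the test
    val expected = (1 to 100).map(i => intDefaultSize + i.toString.length + stringOverhead).sum
    // 9 one-digit keys at 9 bytes + 90 two-digit keys at 10 bytes + "100" at 11 bytes = 992
    println(expected)  // prints 992
  }
}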
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 6bdf741134..a9f0851f88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite { assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index f53acc8c9f..9ba3c21017 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -22,8 +22,6 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ -case class IntegerData(i: Int) - class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning @@ -31,8 +29,12 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") - val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) - rawData.registerTempTable("intData") + + val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val string = if (((key - 1) / 10) % 2 == 0) null else key.toString + TestData(key, string) + }, 5) + pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") @@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be } before { - cacheTable("intData") + cacheTable("pruningData") } after { - uncacheTable("intData") + uncacheTable("pruningData") } // Comparisons - checkBatchPruning("i = 1", Seq(1), 1, 1) - checkBatchPruning("1 = i", Seq(1), 1, 1) - checkBatchPruning("i < 12", 1 to 11, 1, 2) - checkBatchPruning("i <= 11", 1 to 11, 1, 2) - checkBatchPruning("i > 88", 89 to 100, 1, 2) - checkBatchPruning("i >= 89", 89 to 100, 1, 2) - checkBatchPruning("12 > i", 1 to 11, 1, 2) - checkBatchPruning("11 >= i", 1 to 11, 1, 2) - checkBatchPruning("88 < i", 89 to 100, 1, 2) - checkBatchPruning("89 <= i", 89 to 100, 1, 2) + checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1)) + checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1)) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE key >= 89", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE 12 > 
key", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11) + checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100) + + // IS NULL + checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) { + (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90) + } + + // IS NOT NULL + checkBatchPruning("SELECT key FROM pruningData WHERE value IS NOT NULL", 5, 5) { + (11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100) + } // Conjunction and disjunction - checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) - checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) - checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) - checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2) + checkBatchPruning("SELECT key FROM pruningData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100)) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) { + Seq(1) ++ (79 to 91) + } // With unsupported predicate - checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) - checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10) + checkBatchPruning("SELECT key FROM pruningData WHERE NOT (key < 88)", 1, 2)(88 to 100) + checkBatchPruning("SELECT key FROM pruningData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11) + + { + val seq = (1 to 30).mkString(", ") + checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100) + } def checkBatchPruning( - filter: String, - expectedQueryResult: Seq[Int], + query: String, expectedReadPartitions: Int, - expectedReadBatches: Int): Unit = { + expectedReadBatches: Int)( + expectedQueryResult: => Seq[Int]): Unit = { - test(filter) { - val query = sql(s"SELECT * FROM intData WHERE $filter") + test(query) { + val schemaRdd = sql(query) assertResult(expectedQueryResult.toArray, "Wrong query result") { - query.collect().map(_.head).toArray + schemaRdd.collect().map(_.head).toArray } - val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f14ffca0e4..a5af71acfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -76,4 +76,24 @@ class PlannerSuite extends FunSuite { setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) } + + test("InMemoryRelation statistics propagation") { + val origThreshold = autoBroadcastJoinThreshold + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") + + val a = testData.as('a) + val b = table("tiny").as('b) + val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + 
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + } } -- cgit v1.2.3
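The `PlannerSuite` test above is the end-to-end payoff of SPARK-3914: once the cached `tiny` table reports its real, small `sizeInBytes` rather than the conservative `defaultSizeInBytes`, the planner can choose a broadcast join. A simplified sketch of that threshold decision, not Spark's actual planner, with illustrative sizes:

case class Relation(name: String, sizeInBytes: BigInt)

object JoinSelectionSketch {
  // Broadcast the smaller side when its estimated size fits under the threshold,
  // otherwise fall back to a shuffled join.
  def choose(left: Relation, right: Relation, broadcastThreshold: BigInt): String =
    if (right.sizeInBytes <= broadcastThreshold) s"BroadcastHashJoin(build = ${right.name})"
    else if (left.sizeInBytes <= broadcastThreshold) s"BroadcastHashJoin(build = ${left.name})"
    else "ShuffledHashJoin"

  def main(args: Array[String]): Unit = {
    val testData = Relation("testData", sizeInBytes = 992)  // materialized statistics
    val tiny     = Relation("tiny", sizeInBytes = 30)       // 3 cached rows (illustrative)
    println(choose(testData, tiny, broadcastThreshold = 81920))  // picks the broadcast join
  }
}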