author     Cheng Lian <lian@databricks.com>            2014-10-26 16:10:09 -0700
committer  Michael Armbrust <michael@databricks.com>   2014-10-26 16:10:09 -0700
commit     2838bf8aadd5228829c1a869863bc4da7877fdfb (patch)
tree       474e9dc739631b81c20c812c38413d969fe47f2c
parent     879a16585808e8fe34bdede741565efc4c9f9bb3 (diff)
[SPARK-3537][SPARK-3914][SQL] Refines in-memory columnar table statistics
This PR refines in-memory columnar table statistics:

1. adds two more statistics for in-memory table columns: `count` and `sizeInBytes`
2. adds filter pushdown support for `IS NULL` and `IS NOT NULL`
3. caches and propagates statistics in `InMemoryRelation` once the underlying cached RDD is materialized

Statistics are collected on the driver side with an accumulator.

This PR also fixes SPARK-3914 by properly propagating in-memory statistics.

Author: Cheng Lian <lian@databricks.com>

Closes #2860 from liancheng/propagates-in-mem-stats and squashes the following commits:

0cc5271 [Cheng Lian] Restricts visibility of o.a.s.s.c.p.l.Statistics
c5ff904 [Cheng Lian] Fixes test table name conflict
a8c818d [Cheng Lian] Refines tests
1d01074 [Cheng Lian] Bug fix: shouldn't call STRING.actualSize on null string value
7dc6a34 [Cheng Lian] Adds more in-memory table statistics and propagates them properly
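For readers skimming the diff below, here is a minimal, self-contained sketch (not the patch itself) of the per-column bookkeeping the change introduces: every value seen bumps `count`, nulls bump `nullCount`, and `sizeInBytes` grows by the value's encoded size, with a 4-byte charge per null position assumed to match the comment in `ColumnStats.gatherStats`. The class name and the `Option[Int]` input are invented for the example.

```scala
// Minimal sketch of an Int column's statistics collector with the fields this
// patch adds (count, sizeInBytes) alongside the existing bounds and null count.
class SketchIntColumnStats {
  private var count = 0
  private var nullCount = 0
  private var sizeInBytes = 0L
  private var lower = Int.MaxValue
  private var upper = Int.MinValue

  def gather(value: Option[Int]): Unit = {
    count += 1
    value match {
      case None =>
        nullCount += 1
        sizeInBytes += 4                 // assumed 4 bytes to record the null position
      case Some(v) =>
        sizeInBytes += 4                 // assumed INT default size of 4 bytes
        if (v < lower) lower = v
        if (v > upper) upper = v
    }
  }

  // Assumed row layout, matching ColumnStatisticsSchema in the diff:
  // (lowerBound, upperBound, nullCount, count, sizeInBytes)
  def collected: (Int, Int, Int, Int, Long) = (lower, upper, nullCount, count, sizeInBytes)
}
```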
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala  |  10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala |  31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala                   | 122
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala     | 101
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala                  |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala                |   3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala                       |  11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                               |  16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala              |   6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala    |  76
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                 |  20
11 files changed, 240 insertions, 167 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 8364379644..82e760b6c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -23,8 +23,7 @@ package org.apache.spark.sql.catalyst.expressions
* of the name, or the expected nullability).
*/
object AttributeMap {
- def apply[A](kvs: Seq[(Attribute, A)]) =
- new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+ def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
@@ -32,10 +31,9 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
- override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
- (baseMap.map(_._2) + kv).toMap
+ override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
- override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+ override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
- override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+ override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
}
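The `AttributeMap` cleanup above keeps the map keyed by expression id, as the (truncated) Scaladoc notes: lookups succeed regardless of cosmetic differences such as name or nullability. Below is a self-contained sketch of that pattern, with invented `Key`/`IdKeyedMap` names and assuming the pre-2.13 Scala collections API that the patch itself targets.

```scala
// Illustrative only: a Map keyed by an identity field, ignoring the rest of the key,
// mirroring how AttributeMap resolves attributes by exprId.
case class Key(id: Long, name: String, nullable: Boolean)

class IdKeyedMap[A](baseMap: Map[Long, (Key, A)]) extends Map[Key, A] {
  override def get(k: Key): Option[A] = baseMap.get(k.id).map(_._2)
  override def +[B1 >: A](kv: (Key, B1)): Map[Key, B1] = baseMap.values.toMap + kv
  override def -(key: Key): Map[Key, A] = baseMap.values.toMap - key
  override def iterator: Iterator[(Key, A)] = baseMap.valuesIterator
}

object IdKeyedMap {
  def apply[A](kvs: Seq[(Key, A)]): IdKeyedMap[A] =
    new IdKeyedMap(kvs.map(kv => (kv._1.id, kv)).toMap)
}

object IdKeyedMapDemo extends App {
  // Two keys that differ only cosmetically resolve to the same entry.
  val k1 = Key(1L, "a", nullable = false)
  val k2 = k1.copy(name = "A", nullable = true)
  val m  = IdKeyedMap(Seq(k1 -> "stats"))
  assert(m.get(k2) == Some("stats"))
}
```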
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 882e9c6110..ed578e081b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -26,25 +26,24 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees
+/**
+ * Estimates of various statistics. The default estimation logic simply lazily multiplies the
+ * corresponding statistic produced by the children. To override this behavior, override
+ * `statistics` and assign it an overriden version of `Statistics`.
+ *
+ * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the
+ * performance of the implementations. The reason is that estimations might get triggered in
+ * performance-critical processes, such as query plan planning.
+ *
+ * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
+ * defaults to the product of children's `sizeInBytes`.
+ */
+private[sql] case class Statistics(sizeInBytes: BigInt)
+
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
self: Product =>
- /**
- * Estimates of various statistics. The default estimation logic simply lazily multiplies the
- * corresponding statistic produced by the children. To override this behavior, override
- * `statistics` and assign it an overriden version of `Statistics`.
- *
- * '''NOTE''': concrete and/or overriden versions of statistics fields should pay attention to the
- * performance of the implementations. The reason is that estimations might get triggered in
- * performance-critical processes, such as query plan planning.
- *
- * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
- * defaults to the product of children's `sizeInBytes`.
- */
- case class Statistics(
- sizeInBytes: BigInt
- )
- lazy val statistics: Statistics = {
+ def statistics: Statistics = {
if (children.size == 0) {
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
}
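The relocated Scaladoc above describes the default estimation as the product of the corresponding statistic of the children; the body that presumably carries this out sits just below the shown context. Here is a self-contained sketch of that default, with invented `PlanNode`/`Leaf`/`Join` names rather than the catalyst classes.

```scala
// Sketch of the default estimation described above: a non-leaf plan's sizeInBytes
// defaults to the product of its children's sizeInBytes, a deliberately pessimistic
// bound (think of a join whose output could be as large as the cross product).
case class Statistics(sizeInBytes: BigInt)

trait PlanNode {
  def children: Seq[PlanNode]
  def statistics: Statistics =
    if (children.isEmpty) {
      throw new UnsupportedOperationException(
        s"Leaf node ${getClass.getSimpleName} must implement statistics.")
    } else {
      Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
    }
}

case class Leaf(size: BigInt) extends PlanNode {
  override def children: Seq[PlanNode] = Nil
  override def statistics: Statistics = Statistics(sizeInBytes = size)
}

case class Join(left: PlanNode, right: PlanNode) extends PlanNode {
  override def children: Seq[PlanNode] = Seq(left, right)
}

// Join(Leaf(100), Leaf(1000)).statistics == Statistics(sizeInBytes = 100000)
```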
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index b34ab255d0..b9f9f82700 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, Attri
import org.apache.spark.sql.catalyst.types._
private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
- val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)()
- val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)()
- val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+ val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
+ val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)()
+ val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+ val count = AttributeReference(a.name + ".count", IntegerType, nullable = false)()
+ val sizeInBytes = AttributeReference(a.name + ".sizeInBytes", LongType, nullable = false)()
- val schema = Seq(lowerBound, upperBound, nullCount)
+ val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes)
}
private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
@@ -45,10 +47,21 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri
* brings significant performance penalty.
*/
private[sql] sealed trait ColumnStats extends Serializable {
+ protected var count = 0
+ protected var nullCount = 0
+ protected var sizeInBytes = 0L
+
/**
* Gathers statistics information from `row(ordinal)`.
*/
- def gatherStats(row: Row, ordinal: Int): Unit
+ def gatherStats(row: Row, ordinal: Int): Unit = {
+ if (row.isNullAt(ordinal)) {
+ nullCount += 1
+ // 4 bytes for null position
+ sizeInBytes += 4
+ }
+ count += 1
+ }
/**
* Column statistics represented as a single row, currently including closed lower bound, closed
@@ -65,163 +78,154 @@ private[sql] class NoopColumnStats extends ColumnStats {
}
private[sql] class ByteColumnStats extends ColumnStats {
- var upper = Byte.MinValue
- var lower = Byte.MaxValue
- var nullCount = 0
+ protected var upper = Byte.MinValue
+ protected var lower = Byte.MaxValue
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += BYTE.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class ShortColumnStats extends ColumnStats {
- var upper = Short.MinValue
- var lower = Short.MaxValue
- var nullCount = 0
+ protected var upper = Short.MinValue
+ protected var lower = Short.MaxValue
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += SHORT.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class LongColumnStats extends ColumnStats {
- var upper = Long.MinValue
- var lower = Long.MaxValue
- var nullCount = 0
+ protected var upper = Long.MinValue
+ protected var lower = Long.MaxValue
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += LONG.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class DoubleColumnStats extends ColumnStats {
- var upper = Double.MinValue
- var lower = Double.MaxValue
- var nullCount = 0
+ protected var upper = Double.MinValue
+ protected var lower = Double.MaxValue
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += DOUBLE.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class FloatColumnStats extends ColumnStats {
- var upper = Float.MinValue
- var lower = Float.MaxValue
- var nullCount = 0
+ protected var upper = Float.MinValue
+ protected var lower = Float.MaxValue
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += FLOAT.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class IntColumnStats extends ColumnStats {
- var upper = Int.MinValue
- var lower = Int.MaxValue
- var nullCount = 0
+ protected var upper = Int.MinValue
+ protected var lower = Int.MaxValue
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += INT.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class StringColumnStats extends ColumnStats {
- var upper: String = null
- var lower: String = null
- var nullCount = 0
+ protected var upper: String = null
+ protected var lower: String = null
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getString(ordinal)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += STRING.actualSize(row, ordinal)
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class DateColumnStats extends ColumnStats {
- var upper: Date = null
- var lower: Date = null
- var nullCount = 0
+ protected var upper: Date = null
+ protected var lower: Date = null
override def gatherStats(row: Row, ordinal: Int) {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row(ordinal).asInstanceOf[Date]
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += DATE.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
private[sql] class TimestampColumnStats extends ColumnStats {
- var upper: Timestamp = null
- var lower: Timestamp = null
- var nullCount = 0
+ protected var upper: Timestamp = null
+ protected var lower: Timestamp = null
override def gatherStats(row: Row, ordinal: Int): Unit = {
+ super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row(ordinal).asInstanceOf[Timestamp]
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
- } else {
- nullCount += 1
+ sizeInBytes += TIMESTAMP.defaultSize
}
}
- def collectedStatistics = Row(lower, upper, nullCount)
+ def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}
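With the refactoring above, every concrete `ColumnStats` now emits a five-field row in the order declared by `ColumnStatisticsSchema` (lowerBound, upperBound, nullCount, count, sizeInBytes), and the batch-level statistics row built later in `InMemoryRelation` is the concatenation of these per-column rows in output order. A plain-Scala sketch with made-up numbers, not the catalyst `Row` type:

```scala
// Hypothetical statistics for two columns of a cached batch, in the assumed field order
// (lowerBound, upperBound, nullCount, count, sizeInBytes).
val keyStats: Seq[Any]   = Seq(1, 100, 0, 100, 400L)      // an Int column, no nulls
val valueStats: Seq[Any] = Seq("1", "99", 50, 100, 492L)  // a String column with 50 nulls

// The per-batch statistics row concatenates the per-column rows, so its width is
// 5 * (number of columns) and it lines up with PartitionStatistics(tableSchema).schema.
val batchStatsRow: Seq[Any] = Seq(keyStats, valueStats).foldLeft(Seq.empty[Any])(_ ++ _)
// batchStatsRow.size == 10
```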
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 22ab0e2613..ee63134f56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
import org.apache.spark.storage.StorageLevel
@@ -45,15 +47,51 @@ private[sql] case class InMemoryRelation(
useCompression: Boolean,
batchSize: Int,
storageLevel: StorageLevel,
- child: SparkPlan)
- (private var _cachedColumnBuffers: RDD[CachedBatch] = null)
+ child: SparkPlan)(
+ private var _cachedColumnBuffers: RDD[CachedBatch] = null,
+ private var _statistics: Statistics = null)
extends LogicalPlan with MultiInstanceRelation {
- override lazy val statistics =
- Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+ private val batchStats =
+ child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
val partitionStatistics = new PartitionStatistics(output)
+ private def computeSizeInBytes = {
+ val sizeOfRow: Expression =
+ BindReferences.bindReference(
+ output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
+ partitionStatistics.schema)
+
+ batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
+ }
+
+ // Statistics propagation contracts:
+ // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
+ // 2. Only propagate statistics when `_statistics` is non-null
+ private def statisticsToBePropagated = if (_statistics == null) {
+ val updatedStats = statistics
+ if (_statistics == null) null else updatedStats
+ } else {
+ _statistics
+ }
+
+ override def statistics = if (_statistics == null) {
+ if (batchStats.value.isEmpty) {
+ // Underlying columnar RDD hasn't been materialized, no useful statistics information
+ // available, return the default statistics.
+ Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes)
+ } else {
+ // Underlying columnar RDD has been materialized, required information has also been collected
+ // via the `batchStats` accumulator, compute the final statistics, and update `_statistics`.
+ _statistics = Statistics(sizeInBytes = computeSizeInBytes)
+ _statistics
+ }
+ } else {
+ // Pre-computed statistics
+ _statistics
+ }
+
// If the cached column buffers were not passed in, we calculate them in the constructor.
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
@@ -91,6 +129,7 @@ private[sql] case class InMemoryRelation(
val stats = Row.fromSeq(
columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
+ batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
}
@@ -104,7 +143,8 @@ private[sql] case class InMemoryRelation(
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
- newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers)
+ newOutput, useCompression, batchSize, storageLevel, child)(
+ _cachedColumnBuffers, statisticsToBePropagated)
}
override def children = Seq.empty
@@ -116,7 +156,8 @@ private[sql] case class InMemoryRelation(
batchSize,
storageLevel,
child)(
- _cachedColumnBuffers).asInstanceOf[this.type]
+ _cachedColumnBuffers,
+ statisticsToBePropagated).asInstanceOf[this.type]
}
def cachedColumnBuffers = _cachedColumnBuffers
@@ -132,6 +173,8 @@ private[sql] case class InMemoryColumnarTableScan(
override def output: Seq[Attribute] = attributes
+ private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
+
// Returned filter predicate should return false iff it is impossible for the input expression
// to evaluate to `true' based on statistics collected about this partition batch.
val buildFilter: PartialFunction[Expression, Expression] = {
@@ -144,44 +187,24 @@ private[sql] case class InMemoryColumnarTableScan(
buildFilter(lhs) || buildFilter(rhs)
case EqualTo(a: AttributeReference, l: Literal) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- aStats.lowerBound <= l && l <= aStats.upperBound
-
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
case EqualTo(l: Literal, a: AttributeReference) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- aStats.lowerBound <= l && l <= aStats.upperBound
-
- case LessThan(a: AttributeReference, l: Literal) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- aStats.lowerBound < l
-
- case LessThan(l: Literal, a: AttributeReference) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- l < aStats.upperBound
-
- case LessThanOrEqual(a: AttributeReference, l: Literal) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- aStats.lowerBound <= l
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case LessThanOrEqual(l: Literal, a: AttributeReference) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- l <= aStats.upperBound
+ case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
+ case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
- case GreaterThan(a: AttributeReference, l: Literal) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- l < aStats.upperBound
+ case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
+ case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
- case GreaterThan(l: Literal, a: AttributeReference) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- aStats.lowerBound < l
+ case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
+ case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
- case GreaterThanOrEqual(a: AttributeReference, l: Literal) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- l <= aStats.upperBound
+ case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
+ case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
- case GreaterThanOrEqual(l: Literal, a: AttributeReference) =>
- val aStats = relation.partitionStatistics.forAttribute(a)
- aStats.lowerBound <= l
+ case IsNull(a: Attribute) => statsFor(a).nullCount > 0
+ case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
}
val partitionFilters = {
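The `buildFilter` comment above states the contract: the generated predicate may only evaluate to false when the batch cannot possibly contain a matching row, and the two new cases use the `count` and `nullCount` statistics for `IS NULL` / `IS NOT NULL`. A plain-Scala restatement of those rules over one batch's column statistics (Int columns assumed, names invented; the `literal` argument is ignored for the null checks):

```scala
// One batch's statistics for a single column, in the same spirit as the patch.
case class BatchColumnStats(lower: Int, upper: Int, nullCount: Int, count: Int)

// Returns false only when the batch provably contains no matching row.
def mightMatch(op: String, literal: Int, s: BatchColumnStats): Boolean = op match {
  case "="           => s.lower <= literal && literal <= s.upper
  case "<"           => s.lower < literal
  case "<="          => s.lower <= literal
  case ">"           => literal < s.upper
  case ">="          => literal <= s.upper
  case "IS NULL"     => s.nullCount > 0             // some position in the batch is null
  case "IS NOT NULL" => s.count - s.nullCount > 0   // some position holds a non-null value
  case _             => true                        // unsupported predicate: never prune
}
```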
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 2ddf513b6f..04c51a1ee4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -17,16 +17,13 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
-import scala.reflect.runtime.universe.TypeTag
-
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.{Row, SQLContext}
/**
* :: DeveloperApi ::
@@ -100,7 +97,7 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQ
override final def newInstance(): this.type = {
SparkLogicalPlan(
alreadyPlanned match {
- case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
+ case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance()), rdd)
case _ => sys.error("Multiple instance of the same relation detected.")
})(sqlContext).asInstanceOf[this.type]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 5ae768293a..82130b5459 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -22,7 +22,6 @@ import java.io.IOException
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.permission.FsAction
-
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.metadata.CompressionCodecName
import parquet.schema.MessageType
@@ -30,7 +29,7 @@ import parquet.schema.MessageType
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
/**
* Relation that consists of data stored in a Parquet columnar format.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index da5a358df3..1a5d87d524 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
+import org.apache.spark.sql.columnar._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
@@ -234,4 +234,13 @@ class CachedTableSuite extends QueryTest {
uncacheTable("testData")
assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted")
}
+
+ test("InMemoryRelation statistics") {
+ sql("CACHE TABLE testData")
+ table("testData").queryExecution.withCachedData.collect {
+ case cached: InMemoryRelation =>
+ val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum
+ assert(cached.statistics.sizeInBytes === actualSizeInBytes)
+ }
+ }
}
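The new assertion computes the expected cached size directly from the test data. Assuming `INT.defaultSize` is 4 bytes and taking the 4-byte string overhead from the expression in the test itself, the expected total works out as below (a back-of-the-envelope check, not part of the patch):

```scala
// 100 Int keys at 4 bytes each, plus each key's decimal string at (length + 4) bytes:
//   100 * 4  +  (9 * 1 + 90 * 2 + 1 * 3)  +  100 * 4  =  400 + 192 + 400  =  992 bytes
val expectedSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum  // 992
```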
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 10b7979df7..1c21afc17e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -28,40 +28,40 @@ import org.apache.spark.sql.test.TestSQLContext._
case class TestData(key: Int, value: String)
object TestData {
- val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString)))
+ val testData = TestSQLContext.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
testData.registerTempTable("testData")
case class LargeAndSmallInts(a: Int, b: Int)
- val largeAndSmallInts: SchemaRDD =
+ val largeAndSmallInts =
TestSQLContext.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil)
+ LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
- val testData2: SchemaRDD =
+ val testData2 =
TestSQLContext.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil)
+ TestData2(3, 2) :: Nil).toSchemaRDD
testData2.registerTempTable("testData2")
case class BinaryData(a: Array[Byte], b: Int)
- val binaryData: SchemaRDD =
+ val binaryData =
TestSQLContext.sparkContext.parallelize(
BinaryData("12".getBytes(), 1) ::
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil)
+ BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD
binaryData.registerTempTable("binaryData")
// TODO: There is no way to express null primitives as case classes currently...
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 6bdf741134..a9f0851f88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -61,6 +61,12 @@ class ColumnStatsSuite extends FunSuite {
assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
assertResult(10, "Wrong null count")(stats(2))
+ assertResult(20, "Wrong row count")(stats(3))
+ assertResult(stats(4), "Wrong size in bytes") {
+ rows.map { row =>
+ if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+ }.sum
+ }
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index f53acc8c9f..9ba3c21017 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -22,8 +22,6 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
-case class IntegerData(i: Int)
-
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = columnBatchSize
val originalInMemoryPartitionPruning = inMemoryPartitionPruning
@@ -31,8 +29,12 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
- val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
- rawData.registerTempTable("intData")
+
+ val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
+ val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
+ TestData(key, string)
+ }, 5)
+ pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
@@ -44,48 +46,64 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
}
before {
- cacheTable("intData")
+ cacheTable("pruningData")
}
after {
- uncacheTable("intData")
+ uncacheTable("pruningData")
}
// Comparisons
- checkBatchPruning("i = 1", Seq(1), 1, 1)
- checkBatchPruning("1 = i", Seq(1), 1, 1)
- checkBatchPruning("i < 12", 1 to 11, 1, 2)
- checkBatchPruning("i <= 11", 1 to 11, 1, 2)
- checkBatchPruning("i > 88", 89 to 100, 1, 2)
- checkBatchPruning("i >= 89", 89 to 100, 1, 2)
- checkBatchPruning("12 > i", 1 to 11, 1, 2)
- checkBatchPruning("11 >= i", 1 to 11, 1, 2)
- checkBatchPruning("88 < i", 89 to 100, 1, 2)
- checkBatchPruning("89 <= i", 89 to 100, 1, 2)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key >= 89", 1, 2)(89 to 100)
+ checkBatchPruning("SELECT key FROM pruningData WHERE 12 > key", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
+ checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
+ checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+
+ // IS NULL
+ checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
+ (1 to 10) ++ (21 to 30) ++ (41 to 50) ++ (61 to 70) ++ (81 to 90)
+ }
+
+ // IS NOT NULL
+ checkBatchPruning("SELECT key FROM pruningData WHERE value IS NOT NULL", 5, 5) {
+ (11 to 20) ++ (31 to 40) ++ (51 to 60) ++ (71 to 80) ++ (91 to 100)
+ }
// Conjunction and disjunction
- checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3)
- checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2)
- checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4)
- checkBatchPruning("NOT (i < 88)", 88 to 100, 1, 2)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key > 8 AND key <= 21", 2, 3)(9 to 21)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR key > 99", 2, 2)(Seq(1, 100))
+ checkBatchPruning("SELECT key FROM pruningData WHERE key < 2 OR (key > 78 AND key < 92)", 3, 4) {
+ Seq(1) ++ (79 to 91)
+ }
// With unsupported predicate
- checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2)
- checkBatchPruning(s"NOT (i in (${(1 to 30).mkString(",")}))", 31 to 100, 5, 10)
+ checkBatchPruning("SELECT key FROM pruningData WHERE NOT (key < 88)", 1, 2)(88 to 100)
+ checkBatchPruning("SELECT key FROM pruningData WHERE key < 12 AND key IS NOT NULL", 1, 2)(1 to 11)
+
+ {
+ val seq = (1 to 30).mkString(", ")
+ checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
+ }
def checkBatchPruning(
- filter: String,
- expectedQueryResult: Seq[Int],
+ query: String,
expectedReadPartitions: Int,
- expectedReadBatches: Int): Unit = {
+ expectedReadBatches: Int)(
+ expectedQueryResult: => Seq[Int]): Unit = {
- test(filter) {
- val query = sql(s"SELECT * FROM intData WHERE $filter")
+ test(query) {
+ val schemaRdd = sql(query)
assertResult(expectedQueryResult.toArray, "Wrong query result") {
- query.collect().map(_.head).toArray
+ schemaRdd.collect().map(_.head).toArray
}
- val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect {
+ val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect {
case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value)
}.head
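The rewritten generator above nulls out `value` for every other batch of ten keys, which is what the new `IS NULL` / `IS NOT NULL` expectations encode: each of those queries still touches all 5 partitions but can skip the 5 batches whose statistics rule them out. A small sketch reproducing the expected key sets (variable names invented):

```scala
// Keys whose `value` is null: batches 1-10, 21-30, 41-50, 61-70, 81-90.
val nullKeys    = (1 to 100).filter(k => ((k - 1) / 10) % 2 == 0)
// Keys whose `value` is non-null: the remaining five batches.
val nonNullKeys = (1 to 100).filterNot(k => ((k - 1) / 10) % 2 == 0)
// With 5 partitions x 2 batches of 10 rows, IS NULL prunes the batches where
// nullCount == 0 and IS NOT NULL prunes those where count - nullCount == 0,
// leaving 5 partitions and 5 batches read in both cases.
```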
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f14ffca0e4..a5af71acfc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -76,4 +76,24 @@ class PlannerSuite extends FunSuite {
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
}
+
+ test("InMemoryRelation statistics propagation") {
+ val origThreshold = autoBroadcastJoinThreshold
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
+
+ testData.limit(3).registerTempTable("tiny")
+ sql("CACHE TABLE tiny")
+
+ val a = testData.as('a)
+ val b = table("tiny").as('b)
+ val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+
+ val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+ val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+ assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+ assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
+ }
}
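The new planner test ties the pieces together: once `tiny` is cached and scanned, `InMemoryRelation` reports its real (small) `sizeInBytes` instead of the pessimistic default, so the estimate falls under the 81920-byte auto-broadcast threshold and the planner picks a broadcast join. A simplified sketch of that decision (the real logic lives in the planner's join strategies; the function name here is invented):

```scala
// Simplified: broadcast whenever either side's estimated size fits under the threshold.
def chooseHashJoin(leftSize: BigInt, rightSize: BigInt, threshold: BigInt): String =
  if (leftSize <= threshold || rightSize <= threshold) "BroadcastHashJoin"
  else "ShuffledHashJoin"

// Before SPARK-3914, the cached `tiny` table kept the default (large) size estimate and
// fell through to a shuffled join; with propagated statistics it is broadcast instead.
```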