path: root/sql
author     gatorsmile <gatorsmile@gmail.com>     2016-02-04 18:37:58 -0800
committer  Reynold Xin <rxin@databricks.com>     2016-02-04 18:37:58 -0800
commit     e3c75c6398b1241500343ff237e9bcf78b5396f9 (patch)
tree       a807084a9441b4d8b341f0905790343669c59622 /sql
parent     6dbfc40776514c3a5667161ebe7829f4cc9c7529 (diff)
[SPARK-12850][SQL] Support Bucket Pruning (Predicate Pushdown for Bucketed Tables)
JIRA: https://issues.apache.org/jira/browse/SPARK-12850

This PR adds support for bucket pruning when the predicates are `EqualTo`, `EqualNullSafe`, `IsNull`, `In`, or `InSet`. As in Hive, bucket pruning in this PR only works when the bucketing key consists of one and only one column.

So far, I have not found a way to verify how many buckets are actually scanned, although I did verify it while debugging. Could you suggest how to do this properly? Thank you! cloud-fan yhuai rxin marmbrus

BTW, we can add more cases to support complex predicates, including `Or` and `And`. Please let me know if I should do that in this PR. We may also need test cases to verify that bucket pruning works correctly for each data type.

Author: gatorsmile <gatorsmile@gmail.com>

Closes #10942 from gatorsmile/pruningBuckets.
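For context, the feature is exercised entirely through the existing DataFrameWriter bucketing API. A minimal sketch of the user-facing flow (mirroring the tests added in this patch; the table name and the HiveContext value `sqlContext` are illustrative):

    // Assumes a HiveContext named sqlContext with its implicits imported.
    import sqlContext.implicits._

    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")

    // Write a table bucketed into 8 buckets by column "j".
    df.write
      .format("json")
      .bucketBy(8, "j")
      .saveAsTable("bucketed_table")

    // With this patch, the EqualTo predicate on the bucketing column prunes buckets:
    // only the bucket that can contain j == 3 is scanned.
    sqlContext.table("bucketed_table").filter($"j" === 3).collect()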
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 78
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala | 15
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala | 162
3 files changed, 245 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index da9320ffb6..c24967abeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.collection.BitSet
/**
* A Strategy for planning scans over data sources defined using the sources API.
@@ -97,10 +99,15 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(partitionAndNormalColumnAttrs ++ projects).toSeq
}
+ // Prune the buckets based on the pushed filters; these never reference the partitioning key,
+ // since the bucketing key is not allowed to use any of the partitioning columns.
+ val bucketSet = getBuckets(pushedFilters, t.getBucketSpec)
+
val scan = buildPartitionedTableScan(
l,
partitionAndNormalColumnProjs,
pushedFilters,
+ bucketSet,
t.partitionSpec.partitionColumns,
selectedPartitions)
@@ -124,11 +131,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val sharedHadoopConf = SparkHadoopUtil.get.conf
val confBroadcast =
t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
+ // Prune the buckets based on the filters
+ val bucketSet = getBuckets(filters, t.getBucketSpec)
pruneFilterProject(
l,
projects,
filters,
- (a, f) => t.buildInternalScan(a.map(_.name).toArray, f, t.paths, confBroadcast)) :: Nil
+ (a, f) =>
+ t.buildInternalScan(a.map(_.name).toArray, f, bucketSet, t.paths, confBroadcast)) :: Nil
case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
execution.PhysicalRDD.createFromDataSource(
@@ -150,6 +160,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
logicalRelation: LogicalRelation,
projections: Seq[NamedExpression],
filters: Seq[Expression],
+ buckets: Option[BitSet],
partitionColumns: StructType,
partitions: Array[Partition]): SparkPlan = {
val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
@@ -174,7 +185,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// assuming partition columns data stored in data files are always consistent with those
// partition values encoded in partition directory paths.
val dataRows = relation.buildInternalScan(
- requiredDataColumns.map(_.name).toArray, filters, Array(dir), confBroadcast)
+ requiredDataColumns.map(_.name).toArray, filters, buckets, Array(dir), confBroadcast)
// Merges data values with partition values.
mergeWithPartitionValues(
@@ -251,6 +262,69 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
}
}
+ // Compute the bucket ID for a given value of the bucketing column.
+ // Restriction: bucket pruning only works when the bucketing key consists of a single column.
+ def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
+ val mutableRow = new SpecificMutableRow(Seq(bucketColumn.dataType))
+ mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null)
+ val bucketIdGeneration = UnsafeProjection.create(
+ HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
+ bucketColumn :: Nil)
+
+ bucketIdGeneration(mutableRow).getInt(0)
+ }
+
+ // Build the bucket BitSet from the filters that reference only the bucketing key.
+ // Note: when the returned value is None, no pruning is possible.
+ // Restriction: bucket pruning only works when the bucketing key consists of a single column.
+ private def getBuckets(
+ filters: Seq[Expression],
+ bucketSpec: Option[BucketSpec]): Option[BitSet] = {
+
+ if (bucketSpec.isEmpty ||
+ bucketSpec.get.numBuckets == 1 ||
+ bucketSpec.get.bucketColumnNames.length != 1) {
+ // None means all the buckets need to be scanned
+ return None
+ }
+
+ // Just take the head, because bucket pruning only works when the bucketing key has exactly one column
+ val bucketColumnName = bucketSpec.get.bucketColumnNames.head
+ val numBuckets = bucketSpec.get.numBuckets
+ val matchedBuckets = new BitSet(numBuckets)
+ matchedBuckets.clear()
+
+ filters.foreach {
+ case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, v))
+ // The Optimizer converts In to InSet only when the value list exceeds a certain
+ // size, so it is still possible to get an In expression here that needs to be
+ // pushed down.
+ case expressions.In(a: Attribute, list)
+ if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
+ val hSet = list.map(e => e.eval(EmptyRow))
+ hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e)))
+ case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
+ matchedBuckets.set(getBucketId(a, numBuckets, null))
+ case _ =>
+ }
+
+ logInfo {
+ val selected = matchedBuckets.cardinality()
+ val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100
+ s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions."
+ }
+
+ // None means all the buckets need to be scanned
+ if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
+ }
+
protected def prunePartitions(
predicates: Seq[Expression],
partitionSpec: PartitionSpec): Seq[Partition] = {
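To restate the contract established above: getBuckets returns None when pruning is impossible (no bucket spec, a single bucket, a multi-column bucketing key, or no usable literal), and otherwise a BitSet marking exactly the bucket ids the pushed-down literals hash to. A simplified, self-contained model of that logic, with a stand-in hash in place of Spark's HashPartitioning.partitionIdExpression (so the bucket ids it produces are illustrative only):

    import scala.collection.mutable

    // Stand-in for Spark's real bucket-id computation; NOT the hash Spark actually uses.
    def hashToBucket(value: Any, numBuckets: Int): Int = {
      val h = if (value == null) 0 else value.hashCode
      ((h % numBuckets) + numBuckets) % numBuckets
    }

    // None when pruning is impossible; otherwise the ids of buckets that may contain matches.
    def pruneBuckets(
        numBuckets: Int,
        bucketColumnNames: Seq[String],
        literalsForColumn: Map[String, Seq[Any]]): Option[mutable.BitSet] = {
      if (numBuckets <= 1 || bucketColumnNames.length != 1) return None
      val matched = mutable.BitSet.empty
      literalsForColumn.getOrElse(bucketColumnNames.head, Nil)
        .foreach(v => matched += hashToBucket(v, numBuckets))
      if (matched.isEmpty) None else Some(matched)
    }

For example, pruneBuckets(8, Seq("j"), Map("j" -> Seq(3, 4))) sets at most two bits, while pruneBuckets(8, Seq("i", "j"), Map.empty) returns None and all eight buckets are scanned.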
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 299fc6efbb..737be7dfd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection.BitSet
/**
* ::DeveloperApi::
@@ -722,6 +723,7 @@ abstract class HadoopFsRelation private[sql](
final private[sql] def buildInternalScan(
requiredColumns: Array[String],
filters: Array[Filter],
+ bucketSet: Option[BitSet],
inputPaths: Array[String],
broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = {
val inputStatuses = inputPaths.flatMap { input =>
@@ -743,9 +745,16 @@ abstract class HadoopFsRelation private[sql](
// id from file name. Then read these files into a RDD(use one-partition empty RDD for empty
// bucket), and coalesce it to one partition. Finally union all bucket RDDs to one result.
val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
- groupedBucketFiles.get(bucketId).map { inputStatuses =>
- buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
- }.getOrElse(sqlContext.emptyResult)
+ // If the current bucketId is not set in the bucket BitSet, skip scanning it.
+ if (bucketSet.nonEmpty && !bucketSet.get.get(bucketId)) {
+ sqlContext.emptyResult
+ } else {
+ // When all the buckets need a scan (i.e., bucketSet is None),
+ // or when the current bucket needs a scan (i.e., the bit for bucketId is set to true).
+ groupedBucketFiles.get(bucketId).map { inputStatuses =>
+ buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
+ }.getOrElse(sqlContext.emptyResult)
+ }
}
new UnionRDD(sqlContext.sparkContext, perBucketRows)
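The Option[BitSet] threaded through buildInternalScan thus has a simple meaning: None scans every bucket, and Some(bits) replaces each unset bucket with an empty partition before the per-bucket RDDs are unioned. A hedged, Spark-free sketch of that selection step (readBucketFiles is a hypothetical reader standing in for the real per-bucket scan):

    import scala.collection.mutable

    // Keep the rows of buckets that survive pruning; pruned buckets contribute nothing.
    def scanBuckets[T](
        numBuckets: Int,
        bucketSet: Option[mutable.BitSet],
        readBucketFiles: Int => Seq[T]): Seq[T] = {
      (0 until numBuckets).flatMap { bucketId =>
        val pruned = bucketSet.exists(bits => !bits.contains(bucketId))
        if (pruned) Seq.empty[T] else readBucketFiles(bucketId)
      }
    }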
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 150d0c7486..9ba645626f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -19,22 +19,28 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.Exchange
-import org.apache.spark.sql.execution.datasources.BucketSpec
+import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
+import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.BitSet
class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
+ private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+ private val nullDF = (for {
+ i <- 0 to 50
+ s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
+ } yield (i % 5, s, i % 13)).toDF("i", "j", "k")
+
test("read bucketed data") {
- val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
withTable("bucketed_table") {
df.write
.format("parquet")
@@ -59,6 +65,152 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}
+ // To verify that bucket pruning works, this function checks two conditions:
+ // 1) the pruned buckets are empty (checked before the Filter operator runs);
+ // 2) the final result matches the expected answer.
+ private def checkPrunedAnswers(
+ bucketSpec: BucketSpec,
+ bucketValues: Seq[Integer],
+ filterCondition: Column,
+ originalDataFrame: DataFrame): Unit = {
+
+ val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k")
+ val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
+ // Limitation: bucket pruning only works when the bucketing key consists of a single column
+ assert(bucketColumnNames.length == 1)
+ val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
+ val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
+ val matchedBuckets = new BitSet(numBuckets)
+ bucketValues.foreach { value =>
+ matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value))
+ }
+
+ // The Filter operator could hide bugs in bucket pruning, so check the scan output before filtering
+ val rdd = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
+ .find(_.isInstanceOf[PhysicalRDD])
+ assert(rdd.isDefined)
+
+ val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
+ if (matchedBuckets.get(index % numBuckets)) Iterator(true) else Iterator(iter.isEmpty)
+ }
+ // Check that all the pruned buckets are empty
+ assert(checkedResult.collect().forall(_ == true))
+
+ checkAnswer(
+ bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
+ originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))
+ }
+
+ test("read partitioning bucketed tables with bucket pruning filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // JSON is used here because it does not support predicate push-down
+ df.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ // Case 1: EqualTo
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j,
+ df)
+
+ // Case 2: EqualNullSafe
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" <=> j,
+ df)
+
+ // Case 3: In
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = Seq(j, j + 1, j + 2, j + 3),
+ filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),
+ df)
+ }
+ }
+ }
+
+ test("read non-partitioning bucketed tables with bucket pruning filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // JSON is used here because it does not support predicate push-down
+ df.write
+ .format("json")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j,
+ df)
+ }
+ }
+ }
+
+ test("read partitioning bucketed tables having null in bucketing key") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // JSON is used here because it does not support predicate push-down
+ nullDF.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ // Case 1: isNull
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = null :: Nil,
+ filterCondition = $"j".isNull,
+ nullDF)
+
+ // Case 2: <=> null
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = null :: Nil,
+ filterCondition = $"j" <=> null,
+ nullDF)
+ }
+ }
+
+ test("read partitioning bucketed tables having composite filters") {
+ withTable("bucketed_table") {
+ val numBuckets = 8
+ val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
+ // JSON is used here because it does not support predicate push-down
+ df.write
+ .format("json")
+ .partitionBy("i")
+ .bucketBy(numBuckets, "j")
+ .saveAsTable("bucketed_table")
+
+ for (j <- 0 until 13) {
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j && $"k" > $"j",
+ df)
+
+ checkPrunedAnswers(
+ bucketSpec,
+ bucketValues = j :: Nil,
+ filterCondition = $"j" === j && $"i" > j % 5,
+ df)
+ }
+ }
+ }
+
private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")