author     Adrian Ionescu <adrian@databricks.com>      2017-04-03 08:48:49 -0700
committer  Xiao Li <gatorsmile@gmail.com>              2017-04-03 08:48:49 -0700
commit     703c42c398fefd3f7f60e1c503c4df50251f8dcf
tree       54616a4a243da8ab077d871a3ec0d5c0f057394b
parent     4fa1a43af6b5a6abaef7e04cacb2617a2e92d816
[SPARK-20194] Add support for partition pruning to in-memory catalog
## What changes were proposed in this pull request?
This patch implements `listPartitionsByFilter()` for `InMemoryCatalog`, resolving an outstanding TODO that prevented the `PruneFileSourcePartitions` optimizer rule from applying when `spark.sql.catalogImplementation` is set to `in-memory` (which is the default).
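To make the effect concrete, here is a minimal sketch (table name and data are hypothetical, not part of this patch) of a query that can now be pruned under the default in-memory catalog:

```scala
// spark.sql.catalogImplementation defaults to "in-memory"; no Hive support needed.
spark.sql("CREATE TABLE events (value INT, day STRING) USING parquet PARTITIONED BY (day)")
spark.sql("INSERT INTO events PARTITION (day = '2017-04-01') VALUES (1)")
spark.sql("INSERT INTO events PARTITION (day = '2017-04-02') VALUES (2)")
// With this patch, PruneFileSourcePartitions can narrow the scan to one partition.
spark.sql("SELECT value FROM events WHERE day = '2017-04-01'").show()
```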
The change is straightforward: the code that `HiveExternalCatalog` used to further filter the list of partitions returned by the metastore's `getPartitionsByFilter()` is extracted into `ExternalCatalogUtils.prunePartitionsByFilter()`, and `InMemoryCatalog` now calls this new function on its full list of partitions.
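As a rough illustration of the new helper's contract (the `table` and `partitions` arguments are placeholders for whatever a catalog returns):

```scala
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.dsl.expressions._

// Keeps only partitions whose int partition column "a" equals 1; the timezone id
// is used when converting partition values to a row for predicate evaluation.
def prune(table: CatalogTable, partitions: Seq[CatalogTablePartition]): Seq[CatalogTablePartition] = {
  ExternalCatalogUtils.prunePartitionsByFilter(table, partitions, Seq('a.int === 1), "GMT")
}
```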
Now that this method is implemented, we can always pass the `CatalogTable` to the `DataSource` in `FindDataSourceTable`, so that the latter resolves to a relation backed by a `CatalogFileIndex`, which is what the `PruneFileSourcePartitions` rule matches on.
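One way to check this from a REPL (hypothetical table name; this leans on Spark-internal classes of this era, so treat it as a debugging sketch rather than stable API):

```scala
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation}

// True iff the resolved relation is backed by a CatalogFileIndex, i.e. the shape
// that PruneFileSourcePartitions matches on.
val prunable = spark.table("events").queryExecution.analyzed.collect {
  case l: LogicalRelation => l.relation
}.exists {
  case fs: HadoopFsRelation => fs.location.isInstanceOf[CatalogFileIndex]
  case _ => false
}
```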
## How was this patch tested?
Ran existing tests and added a new test for `listPartitionsByFilter` in `ExternalCatalogSuite`, which is subclassed by both `InMemoryCatalogSuite` and `HiveExternalCatalogSuite`.
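For reference, the shared suite can be exercised for both implementations locally with e.g. `build/sbt "catalyst/testOnly *InMemoryCatalogSuite"` and `build/sbt "hive/testOnly *HiveExternalCatalogSuite"` (project names assumed from the standard Spark sbt build).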
Author: Adrian Ionescu <adrian@databricks.com>
Closes #17510 from adrian-ionescu/InMemoryCatalog.
7 files changed, 85 insertions(+), 45 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index a8693dcca5..254eedfe77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}
 
 object ExternalCatalogUtils {
   // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -125,6 +126,38 @@ object ExternalCatalogUtils {
     }
     escapePathName(col) + "=" + partitionString
   }
+
+  def prunePartitionsByFilter(
+      catalogTable: CatalogTable,
+      inputPartitions: Seq[CatalogTablePartition],
+      predicates: Seq[Expression],
+      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
+    if (predicates.isEmpty) {
+      inputPartitions
+    } else {
+      val partitionSchema = catalogTable.partitionSchema
+      val partitionColumnNames = catalogTable.partitionColumnNames.toSet
+
+      val nonPartitionPruningPredicates = predicates.filterNot {
+        _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+      }
+      if (nonPartitionPruningPredicates.nonEmpty) {
+        throw new AnalysisException("Expected only partition pruning predicates: " +
+          nonPartitionPruningPredicates)
+      }
+
+      val boundPredicate =
+        InterpretedPredicate.create(predicates.reduce(And).transform {
+          case att: AttributeReference =>
+            val index = partitionSchema.indexWhere(_.name == att.name)
+            BoundReference(index, partitionSchema(index).dataType, nullable = true)
+        })
+
+      inputPartitions.filter { p =>
+        boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
+      }
+    }
+  }
 }
 
 object CatalogUtils {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index cdf618aef9..9ca1c71d1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.types.StructType
@@ -556,9 +556,9 @@ class InMemoryCatalog(
       table: String,
       predicates: Seq[Expression],
       defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
-    // TODO: Provide an implementation
-    throw new UnsupportedOperationException(
-      "listPartitionsByFilter is not implemented for InMemoryCatalog")
+    val catalogTable = getTable(db, table)
+    val allPartitions = listPartitions(db, table)
+    prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId)
   }
 
   // --------------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 7820f39d96..42db4398e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.catalog
 
 import java.net.URI
+import java.util.TimeZone
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
@@ -28,6 +29,8 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -436,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEach
     assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty)
   }
 
+  test("list partitions by filter") {
+    val tz = TimeZone.getDefault.getID
+    val catalog = newBasicCatalog()
+
+    def checkAnswer(
+        table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition])
+      : Unit = {
+
+      assertResult(expected.map(_.spec)) {
+        catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz)
+          .map(_.spec).toSet
+      }
+    }
+
+    val tbl2 = catalog.getTable("db2", "tbl2")
+
+    checkAnswer(tbl2, Seq.empty, Set(part1, part2))
+    checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
+    checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
+    checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
+    checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
+    checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
+    checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
+    checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
+    checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1))
+
+    intercept[AnalysisException] {
+      try {
+        checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty)
+      } catch {
+        // HiveExternalCatalog may be the first one to notice and throw an exception, which will
+        // then be caught and converted to a RuntimeException with a descriptive message.
+        case ex: RuntimeException if ex.getMessage.contains("MetaException") =>
+          throw new AnalysisException(ex.getMessage)
+      }
+    }
+  }
+
   test("drop partitions") {
     val catalog = newBasicCatalog()
     assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index bddf5af23e..c350d8bcba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
       val table = r.tableMeta
       val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
       val cache = sparkSession.sessionState.catalog.tableRelationCache
-      val withHiveSupport =
-        sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive"
 
       val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
         override def call(): LogicalPlan = {
@@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
             bucketSpec = table.bucketSpec,
             className = table.provider.get,
             options = table.storage.properties ++ pathOption,
-            // TODO: improve `InMemoryCatalog` and remove this limitation.
-            catalogTable = if (withHiveSupport) Some(table) else None)
+            catalogTable = Some(table))
 
           LogicalRelation(
             dataSource.resolveRelation(checkFilesExist = false),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 33b21be372..f0e35dff57 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration)
       defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient {
     val rawTable = getRawTable(db, table)
     val catalogTable = restoreTableMetadata(rawTable)
-    val partitionColumnNames = catalogTable.partitionColumnNames.toSet
-    val nonPartitionPruningPredicates = predicates.filterNot {
-      _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
-    }
-    if (nonPartitionPruningPredicates.nonEmpty) {
-      sys.error("Expected only partition pruning predicates: " +
-        predicates.reduceLeft(And))
-    }
+    val partColNameMap = buildLowerCasePartColNameMap(catalogTable)
-    val partitionSchema = catalogTable.partitionSchema
-    val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))
-
-    if (predicates.nonEmpty) {
-      val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part =>
+    val clientPrunedPartitions =
+      client.getPartitionsByFilter(rawTable, predicates).map { part =>
         part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
       }
-      val boundPredicate =
-        InterpretedPredicate.create(predicates.reduce(And).transform {
-          case att: AttributeReference =>
-            val index = partitionSchema.indexWhere(_.name == att.name)
-            BoundReference(index, partitionSchema(index).dataType, nullable = true)
-        })
-      clientPrunedPartitions.filter { p =>
-        boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
-      }
-    } else {
-      client.getPartitions(catalogTable).map { part =>
-        part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
-      }
-    }
+    prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId)
   }
 
   // --------------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index d55c41e5c9..2e35f39839 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -584,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
    */
   def convertFilters(table: Table, filters: Seq[Expression]): String = {
     // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
-    val varcharKeys = table.getPartitionKeys.asScala
+    lazy val varcharKeys = table.getPartitionKeys.asScala
       .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) ||
         col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
       .map(col => col.getName).toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
index 4349f1aa23..bd54c043c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.types.StructType
 
@@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {
 
   import utils._
 
-  test("list partitions by filter") {
-    val catalog = newBasicCatalog()
-    val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT")
-    assert(selectedPartitions.length == 1)
-    assert(selectedPartitions.head.spec == part1.spec)
-  }
-
   test("SPARK-18647: do not put provider in table properties for Hive serde table") {
     val catalog = newBasicCatalog()
     val hiveTable = CatalogTable(