author    Adrian Ionescu <adrian@databricks.com>  2017-04-03 08:48:49 -0700
committer Xiao Li <gatorsmile@gmail.com>          2017-04-03 08:48:49 -0700
commit    703c42c398fefd3f7f60e1c503c4df50251f8dcf (patch)
tree      54616a4a243da8ab077d871a3ec0d5c0f057394b
parent    4fa1a43af6b5a6abaef7e04cacb2617a2e92d816 (diff)
download  spark-703c42c398fefd3f7f60e1c503c4df50251f8dcf.tar.gz
          spark-703c42c398fefd3f7f60e1c503c4df50251f8dcf.tar.bz2
          spark-703c42c398fefd3f7f60e1c503c4df50251f8dcf.zip
[SPARK-20194] Add support for partition pruning to in-memory catalog
## What changes were proposed in this pull request?

This patch implements `listPartitionsByFilter()` for `InMemoryCatalog` and thus resolves an outstanding TODO that caused the `PruneFileSourcePartitions` optimizer rule not to apply when "spark.sql.catalogImplementation" is set to "in-memory" (which is the default).

The change is straightforward: it extracts the code that further filters the list of partitions returned by the metastore's `getPartitionsByFilter()` out of `HiveExternalCatalog` into `ExternalCatalogUtils`, and calls this new function from `InMemoryCatalog` on the whole list of partitions.

Now that this method is implemented, we can always pass the `CatalogTable` to the `DataSource` in `FindDataSourceTable`, so that the latter is resolved to a relation with a `CatalogFileIndex`, which is what the `PruneFileSourcePartitions` rule matches on.

## How was this patch tested?

Ran existing tests and added a new test for `listPartitionsByFilter` in `ExternalCatalogSuite`, which is subclassed by both `InMemoryCatalogSuite` and `HiveExternalCatalogSuite`.

Author: Adrian Ionescu <adrian@databricks.com>

Closes #17510 from adrian-ionescu/InMemoryCatalog.
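For illustration, the filtering that `prunePartitionsByFilter()` performs amounts to evaluating the partition-pruning predicates against each partition's column values and keeping only the matches, returning everything when there are no predicates. Below is a minimal, self-contained sketch of that idea; `PrunePartitionsSketch`, `Partition`, and `prune` are hypothetical stand-ins for the real `ExternalCatalogUtils.prunePartitionsByFilter()` and `CatalogTablePartition`, not Spark APIs:

```scala
// Hypothetical sketch of predicate-based partition pruning, independent of
// Spark's internal classes.
object PrunePartitionsSketch {
  // Stand-in for CatalogTablePartition: just the partition spec, e.g. Map("a" -> "1", "b" -> "2").
  case class Partition(spec: Map[String, String])

  // A predicate over a partition's column values.
  type PartitionPredicate = Map[String, String] => Boolean

  // Stand-in for prunePartitionsByFilter: with no predicates, return all
  // partitions; otherwise keep only those satisfying every predicate.
  def prune(
      partitions: Seq[Partition],
      predicates: Seq[PartitionPredicate]): Seq[Partition] = {
    if (predicates.isEmpty) partitions
    else partitions.filter(p => predicates.forall(pred => pred(p.spec)))
  }

  def main(args: Array[String]): Unit = {
    val parts = Seq(
      Partition(Map("a" -> "1", "b" -> "2")),
      Partition(Map("a" -> "3", "b" -> "4")))
    // Keep only partitions with a <= 1, mirroring the new test's
    // checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) case.
    val kept = prune(parts, Seq(spec => spec("a").toInt <= 1))
    println(kept) // List(Partition(Map(a -> 1, b -> 2)))
  }
}
```

The real implementation in the diff below expresses the same shape with Catalyst machinery: it binds the predicates to the table's partition schema via `BoundReference` and evaluates them with `InterpretedPredicate` against each partition's row representation.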
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala  | 33
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala       |  8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala  | 41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala   |  5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala                   | 33
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala                       |  2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala              |  8
7 files changed, 85 insertions(+), 45 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index a8693dcca5..254eedfe77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}
object ExternalCatalogUtils {
// This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -125,6 +126,38 @@ object ExternalCatalogUtils {
}
escapePathName(col) + "=" + partitionString
}
+
+  def prunePartitionsByFilter(
+      catalogTable: CatalogTable,
+      inputPartitions: Seq[CatalogTablePartition],
+      predicates: Seq[Expression],
+      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
+    if (predicates.isEmpty) {
+      inputPartitions
+    } else {
+      val partitionSchema = catalogTable.partitionSchema
+      val partitionColumnNames = catalogTable.partitionColumnNames.toSet
+
+      val nonPartitionPruningPredicates = predicates.filterNot {
+        _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+      }
+      if (nonPartitionPruningPredicates.nonEmpty) {
+        throw new AnalysisException("Expected only partition pruning predicates: " +
+          nonPartitionPruningPredicates)
+      }
+
+      val boundPredicate =
+        InterpretedPredicate.create(predicates.reduce(And).transform {
+          case att: AttributeReference =>
+            val index = partitionSchema.indexWhere(_.name == att.name)
+            BoundReference(index, partitionSchema(index).dataType, nullable = true)
+        })
+
+      inputPartitions.filter { p =>
+        boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
+      }
+    }
+  }
}
object CatalogUtils {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index cdf618aef9..9ca1c71d1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.types.StructType
@@ -556,9 +556,9 @@ class InMemoryCatalog(
table: String,
predicates: Seq[Expression],
defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
- // TODO: Provide an implementation
- throw new UnsupportedOperationException(
- "listPartitionsByFilter is not implemented for InMemoryCatalog")
+ val catalogTable = getTable(db, table)
+ val allPartitions = listPartitions(db, table)
+ prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId)
}
// --------------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index 7820f39d96..42db4398e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.catalog
import java.net.URI
+import java.util.TimeZone
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -28,6 +29,8 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -436,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty)
}
+ test("list partitions by filter") {
+ val tz = TimeZone.getDefault.getID
+ val catalog = newBasicCatalog()
+
+ def checkAnswer(
+ table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition])
+ : Unit = {
+
+ assertResult(expected.map(_.spec)) {
+ catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz)
+ .map(_.spec).toSet
+ }
+ }
+
+ val tbl2 = catalog.getTable("db2", "tbl2")
+
+ checkAnswer(tbl2, Seq.empty, Set(part1, part2))
+ checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
+ checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
+ checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
+ checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
+ checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
+ checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
+ checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
+ checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1))
+
+ intercept[AnalysisException] {
+ try {
+ checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty)
+ } catch {
+ // HiveExternalCatalog may be the first one to notice and throw an exception, which will
+ // then be caught and converted to a RuntimeException with a descriptive message.
+ case ex: RuntimeException if ex.getMessage.contains("MetaException") =>
+ throw new AnalysisException(ex.getMessage)
+ }
+ }
+ }
+
test("drop partitions") {
val catalog = newBasicCatalog()
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index bddf5af23e..c350d8bcba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
val table = r.tableMeta
val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
val cache = sparkSession.sessionState.catalog.tableRelationCache
- val withHiveSupport =
- sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive"
val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
override def call(): LogicalPlan = {
@@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
bucketSpec = table.bucketSpec,
className = table.provider.get,
options = table.storage.properties ++ pathOption,
- // TODO: improve `InMemoryCatalog` and remove this limitation.
- catalogTable = if (withHiveSupport) Some(table) else None)
+ catalogTable = Some(table))
LogicalRelation(
dataSource.resolveRelation(checkFilesExist = false),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 33b21be372..f0e35dff57 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient {
val rawTable = getRawTable(db, table)
val catalogTable = restoreTableMetadata(rawTable)
- val partitionColumnNames = catalogTable.partitionColumnNames.toSet
- val nonPartitionPruningPredicates = predicates.filterNot {
- _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
- }
- if (nonPartitionPruningPredicates.nonEmpty) {
- sys.error("Expected only partition pruning predicates: " +
- predicates.reduceLeft(And))
- }
+ val partColNameMap = buildLowerCasePartColNameMap(catalogTable)
- val partitionSchema = catalogTable.partitionSchema
- val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))
-
- if (predicates.nonEmpty) {
- val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part =>
+ val clientPrunedPartitions =
+ client.getPartitionsByFilter(rawTable, predicates).map { part =>
part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
}
- val boundPredicate =
- InterpretedPredicate.create(predicates.reduce(And).transform {
- case att: AttributeReference =>
- val index = partitionSchema.indexWhere(_.name == att.name)
- BoundReference(index, partitionSchema(index).dataType, nullable = true)
- })
- clientPrunedPartitions.filter { p =>
- boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
- }
- } else {
- client.getPartitions(catalogTable).map { part =>
- part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
- }
- }
+ prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId)
}
// --------------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index d55c41e5c9..2e35f39839 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -584,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
*/
def convertFilters(table: Table, filters: Seq[Expression]): String = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
- val varcharKeys = table.getPartitionKeys.asScala
+ lazy val varcharKeys = table.getPartitionKeys.asScala
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) ||
col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
.map(col => col.getName).toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
index 4349f1aa23..bd54c043c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.types.StructType
@@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {
import utils._
- test("list partitions by filter") {
- val catalog = newBasicCatalog()
- val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT")
- assert(selectedPartitions.length == 1)
- assert(selectedPartitions.head.spec == part1.spec)
- }
-
test("SPARK-18647: do not put provider in table properties for Hive serde table") {
val catalog = newBasicCatalog()
val hiveTable = CatalogTable(