author     Michael Armbrust <michael@databricks.com>  2015-07-14 11:22:09 -0700
committer  Michael Armbrust <michael@databricks.com>  2015-07-14 11:22:09 -0700
commit     37f2d9635ff874fb8ad9d246e49faf6098d501c3 (patch)
tree       17309014cd6e0a5af5e60522607a7d6b95017231 /sql
parent     59d820aa8dec08b744971237860b4c6bef577ddf (diff)
[SPARK-9027] [SQL] Generalize metastore predicate pushdown
Add support for pushing down metastore filters that are in different orders, and add some unit tests.

Author: Michael Armbrust <michael@databricks.com>

Closes #7386 from marmbrus/metastoreFilters and squashes the following commits:

05a4524 [Michael Armbrust] [SPARK-9027][SQL] Generalize metastore predicate pushdown
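In essence, the patch replaces a nested pattern match with a single collect whose cases accept a binary comparison with the attribute on either side. A minimal sketch of the intended behavior, reusing the catalyst DSL that the new test suite imports (the shim and the Hive table are assumed to be set up as in FiltersSuite below):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Attribute references standing in for partition columns.
val intcol = AttributeReference("intcol", IntegerType)()
val strcol = AttributeReference("strcol", StringType)()

// Attribute-on-the-left and literal-on-the-left now both convert.
val filters = (intcol === Literal(1)) :: (Literal("a") === strcol) :: Nil

// Expected conversion (see convertFilters in the diff below):
//   shim.convertFilters(table, filters) == "intcol = 1 and \"a\" = strcol"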
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala     | 54
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala | 78
2 files changed, 107 insertions(+), 25 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 5542a521b1..d12778c758 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StringType, IntegralType}
/**
@@ -312,37 +312,41 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
- override def getPartitionsByFilter(
- hive: Hive,
- table: Table,
- predicates: Seq[Expression]): Seq[Partition] = {
+ /**
+ * Converts catalyst expressions to the format that Hive's getPartitionsByFilter() expects, i.e.
+ * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...".
+ *
+ * Unsupported predicates are skipped.
+ */
+ def convertFilters(table: Table, filters: Seq[Expression]): String = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = table.getPartitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet
- // Hive getPartitionsByFilter() takes a string that represents partition
- // predicates like "str_key=\"value\" and int_key=1 ..."
- val filter = predicates.flatMap { expr =>
- expr match {
- case op @ BinaryComparison(lhs, rhs) => {
- lhs match {
- case AttributeReference(_, _, _, _) => {
- rhs.dataType match {
- case _: IntegralType =>
- Some(lhs.prettyString + op.symbol + rhs.prettyString)
- case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
- Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
- case _ => None
- }
- }
- case _ => None
- }
- }
- case _ => None
- }
+ filters.collect {
+ case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) =>
+ s"${a.name} ${op.symbol} $v"
+ case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) =>
+ s"$v ${op.symbol} ${a.name}"
+
+ case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType))
+ if !varcharKeys.contains(a.name) =>
+ s"""${a.name} ${op.symbol} "$v""""
+ case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute)
+ if !varcharKeys.contains(a.name) =>
+ s""""$v" ${op.symbol} ${a.name}"""
}.mkString(" and ")
+ }
+
+ override def getPartitionsByFilter(
+ hive: Hive,
+ table: Table,
+ predicates: Seq[Expression]): Seq[Partition] = {
+ // Hive getPartitionsByFilter() takes a string that represents partition
+ // predicates like "str_key=\"value\" and int_key=1 ..."
+ val filter = convertFilters(table, predicates)
val partitions =
if (filter.isEmpty) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
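Anything that does not fit one of the four collect cases is skipped rather than rejected, so an entirely unsupported predicate list produces an empty filter string, which routes the call through getAllPartitionsMethod instead of getPartitionsByFilter. A short sketch of predicate shapes that fall through, under the same assumed setup as above:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val intcol = AttributeReference("intcol", IntegerType)()

// None of these match a supported pattern, so each is skipped:
val unsupported: Seq[Expression] =
  IsNotNull(intcol) ::          // not a BinaryComparison
  EqualTo(intcol, intcol) ::    // attribute on both sides, no literal
  Nil

// convertFilters(table, unsupported) == "", so the shim falls back to
// fetching all partitions.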
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
new file mode 100644
index 0000000000..0efcf80bd4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.client
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.metastore.api.FieldSchema
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+/**
+ * A set of tests for the filter conversion logic used when pushing partition pruning into the
+ * metastore.
+ */
+class FiltersSuite extends SparkFunSuite with Logging {
+ private val shim = new Shim_v0_13
+
+ private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
+ private val varCharCol = new FieldSchema()
+ varCharCol.setName("varchar")
+ varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
+ testTable.setPartCols(varCharCol :: Nil)
+
+ filterTest("string filter",
+ (a("stringcol", StringType) > Literal("test")) :: Nil,
+ "stringcol > \"test\"")
+
+ filterTest("string filter backwards",
+ (Literal("test") > a("stringcol", StringType)) :: Nil,
+ "\"test\" > stringcol")
+
+ filterTest("int filter",
+ (a("intcol", IntegerType) === Literal(1)) :: Nil,
+ "intcol = 1")
+
+ filterTest("int filter backwards",
+ (Literal(1) === a("intcol", IntegerType)) :: Nil,
+ "1 = intcol")
+
+ filterTest("int and string filter",
+ (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", StringType)) :: Nil,
+ "1 = intcol and \"a\" = strcol")
+
+ filterTest("skip varchar",
+ (Literal("") === a("varchar", StringType)) :: Nil,
+ "")
+
+ private def filterTest(name: String, filters: Seq[Expression], result: String) = {
+ test(name) {
+ val converted = shim.convertFilters(testTable, filters)
+ if (converted != result) {
+ fail(
+ s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'")
+ }
+ }
+ }
+
+ private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
+}
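The filterTest helper makes additional cases cheap to add. A hypothetical extra case (not part of the patch) exercising a non-equality comparison with the literal on the left would follow the same pattern, placed alongside the other filterTest calls:

filterTest("int filter, literal on left, non-equality",
  (Literal(5) > a("intcol", IntegerType)) :: Nil,
  "5 > intcol")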