aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala58
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala1
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala4
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala11
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala22
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala68
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala7
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala8
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala2
9 files changed, 137 insertions, 44 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4b7a782c80..5bdf68c83f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -301,7 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
- val partitions = metastoreRelation.hiveQlPartitions.map { p =>
+ // We're converting the entire table into ParquetRelation, so predicates to Hive metastore
+ // are empty.
+ val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -644,32 +646,6 @@ private[hive] case class MetastoreRelation
new Table(tTable)
}
- @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
- val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
- tPartition.setDbName(databaseName)
- tPartition.setTableName(tableName)
- tPartition.setValues(p.values)
-
- val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
- tPartition.setSd(sd)
- sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
-
- sd.setLocation(p.storage.location)
- sd.setInputFormat(p.storage.inputFormat)
- sd.setOutputFormat(p.storage.outputFormat)
-
- val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
- sd.setSerdeInfo(serdeInfo)
- serdeInfo.setSerializationLib(p.storage.serde)
-
- val serdeParameters = new java.util.HashMap[String, String]()
- serdeInfo.setParameters(serdeParameters)
- table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
- p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
-
- new Partition(hiveQlTable, tPartition)
- }
-
@transient override lazy val statistics: Statistics = Statistics(
sizeInBytes = {
val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -690,6 +666,34 @@ private[hive] case class MetastoreRelation
}
)
+ def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
+ table.getPartitions(predicates).map { p =>
+ val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
+ tPartition.setDbName(databaseName)
+ tPartition.setTableName(tableName)
+ tPartition.setValues(p.values)
+
+ val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
+ tPartition.setSd(sd)
+ sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
+
+ sd.setLocation(p.storage.location)
+ sd.setInputFormat(p.storage.inputFormat)
+ sd.setOutputFormat(p.storage.outputFormat)
+
+ val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
+ sd.setSerdeInfo(serdeInfo)
+ serdeInfo.setSerializationLib(p.storage.serde)
+
+ val serdeParameters = new java.util.HashMap[String, String]()
+ serdeInfo.setParameters(serdeParameters)
+ table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+ p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+
+ new Partition(hiveQlTable, tPartition)
+ }
+ }
+
/** Only compare database and tablename, not alias. */
override def sameResult(plan: LogicalPlan): Boolean = {
plan match {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index d08c594151..a357bb39ca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -27,6 +27,7 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index ed359620a5..9638a8201e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -125,7 +125,7 @@ private[hive] trait HiveStrategies {
InterpretedPredicate.create(castedPredicate)
}
- val partitions = relation.hiveQlPartitions.filter { part =>
+ val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -213,7 +213,7 @@ private[hive] trait HiveStrategies {
projectList,
otherPredicates,
identity[Seq[Expression]],
- HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
+ HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
case _ =>
Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 0a1d761a52..1656587d14 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -21,6 +21,7 @@ import java.io.PrintStream
import java.util.{Map => JMap}
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
+import org.apache.spark.sql.catalyst.expressions.Expression
private[hive] case class HiveDatabase(
name: String,
@@ -71,7 +72,12 @@ private[hive] case class HiveTable(
def isPartitioned: Boolean = partitionColumns.nonEmpty
- def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
+ def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
+ predicates match {
+ case Nil => client.getAllPartitions(this)
+ case _ => client.getPartitionsByFilter(this, predicates)
+ }
+ }
// Hive does not support backticks when passing names to the client.
def qualifiedName: String = s"$database.$name"
@@ -132,6 +138,9 @@ private[hive] trait ClientInterface {
/** Returns all partitions for the given table. */
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]
+ /** Returns partitions filtered by predicates for the given table. */
+ def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]
+
/** Loads a static partition into an existing table. */
def loadPartition(
loadPath: String,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 9d83ca6c11..1f280c6429 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -17,27 +17,24 @@
package org.apache.spark.sql.hive.client
-import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
-import java.net.URI
-import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
+import java.io.{File, PrintStream}
+import java.util.{Map => JMap}
import javax.annotation.concurrent.GuardedBy
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.util.CircularBuffer
import scala.collection.JavaConversions._
import scala.language.reflectiveCalls
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
import org.apache.hadoop.hive.metastore.{TableType => HTableType}
-import org.apache.hadoop.hive.metastore.api
-import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hadoop.hive.ql.metadata
import org.apache.hadoop.hive.ql.metadata.Hive
-import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.processors._
-import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.{Driver, metadata}
import org.apache.spark.Logging
import org.apache.spark.sql.execution.QueryExecutionException
@@ -316,6 +313,13 @@ private[hive] class ClientWrapper(
shim.getAllPartitions(client, qlTable).map(toHivePartition)
}
+ override def getPartitionsByFilter(
+ hTable: HiveTable,
+ predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
+ val qlTable = toQlTable(hTable)
+ shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
+ }
+
override def listTables(dbName: String): Seq[String] = withHiveState {
client.getAllTables(dbName)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 1fa9d278e2..5542a521b1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.types.{StringType, IntegralType}
/**
* A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,6 +66,8 @@ private[client] sealed abstract class Shim {
def getAllPartitions(hive: Hive, table: Table): Seq[Partition]
+ def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]
+
def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor
def getDriverResults(driver: Driver): Seq[String]
@@ -109,7 +116,7 @@ private[client] sealed abstract class Shim {
}
-private[client] class Shim_v0_12 extends Shim {
+private[client] class Shim_v0_12 extends Shim with Logging {
private lazy val startMethod =
findStaticMethod(
@@ -196,6 +203,17 @@ private[client] class Shim_v0_12 extends Shim {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
+ override def getPartitionsByFilter(
+ hive: Hive,
+ table: Table,
+ predicates: Seq[Expression]): Seq[Partition] = {
+ // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
+ // See HIVE-4888.
+ logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
+ "Please use Hive 0.13 or higher.")
+ getAllPartitions(hive, table)
+ }
+
override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
@@ -267,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
classOf[Hive],
"getAllPartitionsOf",
classOf[Table])
+ private lazy val getPartitionsByFilterMethod =
+ findMethod(
+ classOf[Hive],
+ "getPartitionsByFilter",
+ classOf[Table],
+ classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -288,6 +312,48 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
+ override def getPartitionsByFilter(
+ hive: Hive,
+ table: Table,
+ predicates: Seq[Expression]): Seq[Partition] = {
+ // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
+ val varcharKeys = table.getPartitionKeys
+ .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
+ .map(col => col.getName).toSet
+
+ // Hive getPartitionsByFilter() takes a string that represents partition
+ // predicates like "str_key=\"value\" and int_key=1 ..."
+ val filter = predicates.flatMap { expr =>
+ expr match {
+ case op @ BinaryComparison(lhs, rhs) => {
+ lhs match {
+ case AttributeReference(_, _, _, _) => {
+ rhs.dataType match {
+ case _: IntegralType =>
+ Some(lhs.prettyString + op.symbol + rhs.prettyString)
+ case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
+ Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
+ case _ => None
+ }
+ }
+ case _ => None
+ }
+ }
+ case _ => None
+ }
+ }.mkString(" and ")
+
+ val partitions =
+ if (filter.isEmpty) {
+ getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+ } else {
+ logDebug(s"Hive metastore filter is '$filter'.")
+ getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
+ }
+
+ partitions.toSeq
+ }
+
override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index d33da8242c..ba7eb15a1c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -44,7 +44,7 @@ private[hive]
case class HiveTableScan(
requestedAttributes: Seq[Attribute],
relation: MetastoreRelation,
- partitionPruningPred: Option[Expression])(
+ partitionPruningPred: Seq[Expression])(
@transient val context: HiveContext)
extends LeafNode {
@@ -56,7 +56,7 @@ case class HiveTableScan(
// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
- private[this] val boundPruningPred = partitionPruningPred.map { pred =>
+ private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
require(
pred.dataType == BooleanType,
s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -133,7 +133,8 @@ case class HiveTableScan(
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
- hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
+ hadoopReader.makeRDDForPartitionedTable(
+ prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
}
override def output: Seq[Attribute] = attributes
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index d52e162acb..3eb127e23d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.hive.client
import java.io.File
import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.Utils
/**
@@ -151,6 +153,12 @@ class VersionsSuite extends SparkFunSuite with Logging {
client.getAllPartitions(client.getTable("default", "src_part"))
}
+ test(s"$version: getPartitionsByFilter") {
+ client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
+ AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
+ Literal(1))))
+ }
+
test(s"$version: loadPartition") {
client.loadPartition(
emptyDir,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index de6a41ce5b..e83a7dc77e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
- p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
+ p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
} else {
Seq.empty
}