Diffstat (limited to 'sql')
9 files changed, 137 insertions, 44 deletions
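In short: this change threads Catalyst partition-pruning predicates from the planner down through the Hive client layer, so scans over partitioned tables can ask the metastore for just the matching partitions instead of materializing metadata for every partition. The standalone sketch below compresses the new layering into toy Scala types; every name in it is illustrative, not Spark's actual API.

// Toy end-to-end trace of the new pipeline. In the real code, HiveTableScan
// passes partitionPruningPred to MetastoreRelation.getHiveQlPartitions, which
// calls HiveTable.getPartitions, which calls ClientInterface.getPartitionsByFilter,
// which lands in a version-specific Shim that talks to the metastore.
object PipelineSketch {
  type Predicate = String                  // stands in for a Catalyst Expression
  type Partition = String

  trait Shim {                             // version-specific bottom layer
    def partitionsByFilter(preds: Seq[Predicate]): Seq[Partition]
  }

  class Relation(shim: Shim) {
    // An empty predicate list keeps the old "fetch everything" behavior.
    def getHiveQlPartitions(preds: Seq[Predicate] = Nil): Seq[Partition] =
      shim.partitionsByFilter(preds)
  }

  def main(args: Array[String]): Unit = {
    val all = Seq("ds=2015-01-01", "ds=2015-01-02")
    val shim = new Shim {
      def partitionsByFilter(preds: Seq[Predicate]) =
        if (preds.isEmpty) all else all.filter(p => preds.forall(f => p.contains(f)))
    }
    val rel = new Relation(shim)
    println(rel.getHiveQlPartitions())                      // both partitions
    println(rel.getHiveQlPartitions(Seq("ds=2015-01-02")))  // only the match
  }
}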
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4b7a782c80..5bdf68c83f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -301,7 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
     val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
       val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
       val partitionColumnDataTypes = partitionSchema.map(_.dataType)
-      val partitions = metastoreRelation.hiveQlPartitions.map { p =>
+      // We're converting the entire table into ParquetRelation, so predicates to Hive metastore
+      // are empty.
+      val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
         val location = p.getLocation
         val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
           case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -644,32 +646,6 @@ private[hive] case class MetastoreRelation
     new Table(tTable)
   }
 
-  @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
-    val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
-    tPartition.setDbName(databaseName)
-    tPartition.setTableName(tableName)
-    tPartition.setValues(p.values)
-
-    val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
-    tPartition.setSd(sd)
-    sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
-
-    sd.setLocation(p.storage.location)
-    sd.setInputFormat(p.storage.inputFormat)
-    sd.setOutputFormat(p.storage.outputFormat)
-
-    val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
-    sd.setSerdeInfo(serdeInfo)
-    serdeInfo.setSerializationLib(p.storage.serde)
-
-    val serdeParameters = new java.util.HashMap[String, String]()
-    serdeInfo.setParameters(serdeParameters)
-    table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
-    p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
-
-    new Partition(hiveQlTable, tPartition)
-  }
-
   @transient override lazy val statistics: Statistics = Statistics(
     sizeInBytes = {
       val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -690,6 +666,34 @@ private[hive] case class MetastoreRelation
     }
   )
 
+  def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
+    table.getPartitions(predicates).map { p =>
+      val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
+      tPartition.setDbName(databaseName)
+      tPartition.setTableName(tableName)
+      tPartition.setValues(p.values)
+
+      val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
+      tPartition.setSd(sd)
+      sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
+
+      sd.setLocation(p.storage.location)
+      sd.setInputFormat(p.storage.inputFormat)
+      sd.setOutputFormat(p.storage.outputFormat)
+
+      val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
+      sd.setSerdeInfo(serdeInfo)
+      serdeInfo.setSerializationLib(p.storage.serde)
+
+      val serdeParameters = new java.util.HashMap[String, String]()
+      serdeInfo.setParameters(serdeParameters)
+      table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+      p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
+
+      new Partition(hiveQlTable, tPartition)
+    }
+  }
+
   /** Only compare database and tablename, not alias. */
   override def sameResult(plan: LogicalPlan): Boolean = {
     plan match {
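Worth noting in the hunk above: the eagerly computed @transient val hiveQlPartitions becomes getHiveQlPartitions(predicates) with Nil as the default, so existing callers can switch to getHiveQlPartitions() unchanged. A minimal sketch of that refactoring pattern, with a toy type in place of the Spark and Hive classes:

// Sketch of the val -> def refactor in MetastoreRelation. A field computed at
// construction time forced a full metastore listing for every relation; a
// method with a default empty predicate list keeps old call sites working
// while letting the scan operator push predicates down.
class RelationSketch(listPartitions: Seq[String] => Seq[String]) {
  // Before: an eager field, always the full listing, e.g.
  //   @transient val hiveQlPartitions: Seq[String] = listPartitions(Nil)
  // After: callers opt in to pushdown; no arguments = old behavior.
  def getHiveQlPartitions(predicates: Seq[String] = Nil): Seq[String] =
    listPartitions(predicates)
}

One trade-off of this shape: the old val computed its result once per relation instance, while the def re-fetches from the metastore on each call.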
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index d08c594151..a357bb39ca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -27,6 +27,7 @@ import scala.reflect.ClassTag
 
 import com.esotericsoftware.kryo.Kryo
 import com.esotericsoftware.kryo.io.{Input, Output}
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index ed359620a5..9638a8201e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -125,7 +125,7 @@ private[hive] trait HiveStrategies {
           InterpretedPredicate.create(castedPredicate)
         }
 
-        val partitions = relation.hiveQlPartitions.filter { part =>
+        val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
           val partitionValues = part.getValues
           var i = 0
           while (i < partitionValues.size()) {
@@ -213,7 +213,7 @@ private[hive] trait HiveStrategies {
           projectList,
           otherPredicates,
           identity[Seq[Expression]],
-          HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
+          HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
       case _ => Nil
     }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 0a1d761a52..1656587d14 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -21,6 +21,7 @@ import java.io.PrintStream
 import java.util.{Map => JMap}
 
 import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
+import org.apache.spark.sql.catalyst.expressions.Expression
 
 private[hive] case class HiveDatabase(
     name: String,
@@ -71,7 +72,12 @@ private[hive] case class HiveTable(
 
   def isPartitioned: Boolean = partitionColumns.nonEmpty
 
-  def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
+  def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
+    predicates match {
+      case Nil => client.getAllPartitions(this)
+      case _ => client.getPartitionsByFilter(this, predicates)
+    }
+  }
 
   // Hive does not support backticks when passing names to the client.
   def qualifiedName: String = s"$database.$name"
@@ -132,6 +138,9 @@ private[hive] trait ClientInterface {
   /** Returns all partitions for the given table. */
   def getAllPartitions(hTable: HiveTable): Seq[HivePartition]
 
+  /** Returns partitions filtered by predicates for the given table. */
+  def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]
+
   /** Loads a static partition into an existing table. */
   def loadPartition(
       loadPath: String,
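The HiveTable.getPartitions dispatch introduced above is the seam between the two code paths: an empty predicate list preserves the old full listing, anything else attempts pushdown. A self-contained sketch of that contract against a stub client (toy types; the real predicates are Catalyst Expressions):

object DispatchSketch {
  type Predicate = String // stands in for Catalyst's Expression

  trait Client {
    def getAllPartitions(table: String): Seq[String]
    def getPartitionsByFilter(table: String, preds: Seq[Predicate]): Seq[String]
  }

  // Mirrors HiveTable.getPartitions: an empty predicate list means a plain
  // full listing; anything else is handed to the filtering code path.
  def getPartitions(c: Client, table: String, preds: Seq[Predicate]): Seq[String] =
    preds match {
      case Nil => c.getAllPartitions(table)
      case ps  => c.getPartitionsByFilter(table, ps)
    }

  def main(args: Array[String]): Unit = {
    val stub = new Client {
      def getAllPartitions(t: String) = Seq("ds=1", "ds=2")
      def getPartitionsByFilter(t: String, ps: Seq[Predicate]) =
        getAllPartitions(t).filter(p => ps.forall(f => p.contains(f)))
    }
    println(getPartitions(stub, "src_part", Nil))         // List(ds=1, ds=2)
    println(getPartitions(stub, "src_part", Seq("ds=2"))) // List(ds=2)
  }
}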
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 9d83ca6c11..1f280c6429 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -17,27 +17,24 @@
 
 package org.apache.spark.sql.hive.client
 
-import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
-import java.net.URI
-import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
+import java.io.{File, PrintStream}
+import java.util.{Map => JMap}
 import javax.annotation.concurrent.GuardedBy
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.util.CircularBuffer
 
 import scala.collection.JavaConversions._
 import scala.language.reflectiveCalls
 
 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.hive.metastore.api.Database
 import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
 import org.apache.hadoop.hive.metastore.{TableType => HTableType}
-import org.apache.hadoop.hive.metastore.api
-import org.apache.hadoop.hive.metastore.api.FieldSchema
-import org.apache.hadoop.hive.ql.metadata
 import org.apache.hadoop.hive.ql.metadata.Hive
-import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.hive.ql.processors._
-import org.apache.hadoop.hive.ql.Driver
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.{Driver, metadata}
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.execution.QueryExecutionException
@@ -316,6 +313,13 @@ private[hive] class ClientWrapper(
     shim.getAllPartitions(client, qlTable).map(toHivePartition)
   }
 
+  override def getPartitionsByFilter(
+      hTable: HiveTable,
+      predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
+    val qlTable = toQlTable(hTable)
+    shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
+  }
+
   override def listTables(dbName: String): Seq[String] = withHiveState {
     client.getAllTables(dbName)
   }
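ClientWrapper delegates to a version-specific Shim that binds Hive methods reflectively, which is how one Spark build can talk to several Hive versions; Hive.getPartitionsByFilter(Table, String) only exists in newer Hive releases. A minimal, runnable illustration of that reflective binding (the target method here is deliberately mundane, not the Hive API):

import java.lang.reflect.Method

object ShimSketch {
  // Resolve a method by name and signature at runtime, as the Shim classes do.
  def findMethod(klass: Class[_], name: String, args: Class[_]*): Method =
    klass.getMethod(name, args: _*)

  def main(args: Array[String]): Unit = {
    // Illustrative target: String.startsWith(String). In the real shim the
    // lookup is Hive.getPartitionsByFilter(Table, String), resolved lazily so
    // it only fails if a 0.12 client were ever to reach that code path.
    val m = findMethod(classOf[String], "startsWith", classOf[String])
    println(m.invoke("spark-sql", "spark")) // true
  }
}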
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 1fa9d278e2..5542a521b1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
 import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
 import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
 import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
+import org.apache.spark.sql.types.{StringType, IntegralType}
 
 /**
  * A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,6 +66,8 @@ private[client] sealed abstract class Shim {
 
   def getAllPartitions(hive: Hive, table: Table): Seq[Partition]
 
+  def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]
+
   def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor
 
   def getDriverResults(driver: Driver): Seq[String]
@@ -109,7 +116,7 @@ private[client] sealed abstract class Shim {
 
 }
 
-private[client] class Shim_v0_12 extends Shim {
+private[client] class Shim_v0_12 extends Shim with Logging {
 
   private lazy val startMethod =
     findStaticMethod(
@@ -196,6 +203,17 @@ private[client] class Shim_v0_12 extends Shim {
   override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
     getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
 
+  override def getPartitionsByFilter(
+      hive: Hive,
+      table: Table,
+      predicates: Seq[Expression]): Seq[Partition] = {
+    // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
+    // See HIVE-4888.
+    logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
+      "Please use Hive 0.13 or higher.")
+    getAllPartitions(hive, table)
+  }
+
   override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
     getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
 
@@ -267,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       classOf[Hive],
       "getAllPartitionsOf",
       classOf[Table])
+  private lazy val getPartitionsByFilterMethod =
+    findMethod(
+      classOf[Hive],
+      "getPartitionsByFilter",
+      classOf[Table],
+      classOf[String])
   private lazy val getCommandProcessorMethod =
     findStaticMethod(
       classOf[CommandProcessorFactory],
@@ -288,6 +312,48 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
     getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq
 
+  override def getPartitionsByFilter(
+      hive: Hive,
+      table: Table,
+      predicates: Seq[Expression]): Seq[Partition] = {
+    // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
+    val varcharKeys = table.getPartitionKeys
+      .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
+      .map(col => col.getName).toSet
+
+    // Hive getPartitionsByFilter() takes a string that represents partition
+    // predicates like "str_key=\"value\" and int_key=1 ..."
+    val filter = predicates.flatMap { expr =>
+      expr match {
+        case op @ BinaryComparison(lhs, rhs) => {
+          lhs match {
+            case AttributeReference(_, _, _, _) => {
+              rhs.dataType match {
+                case _: IntegralType =>
+                  Some(lhs.prettyString + op.symbol + rhs.prettyString)
+                case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
+                  Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
+                case _ => None
+              }
+            }
+            case _ => None
+          }
+        }
+        case _ => None
+      }
+    }.mkString(" and ")
+
+    val partitions =
+      if (filter.isEmpty) {
+        getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
+      } else {
+        logDebug(s"Hive metastore filter is '$filter'.")
+        getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
+      }
+
+    partitions.toSeq
+  }
+
   override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
     getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
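The heart of the 0.13 shim is the conversion from Catalyst predicates into Hive's metastore filter string, e.g. str_key="value" and int_key=1. Only binary comparisons where the left side is a partition key attribute and the right side is an integral value, or a string value on a non-varchar key, are convertible; everything else is dropped from the filter and handled by client-side pruning. A standalone sketch of that conversion with a toy expression type in place of Catalyst's:

object FilterSketch {
  sealed trait Value
  case class IntVal(i: Long) extends Value
  case class StrVal(s: String) extends Value

  // One pushable comparison: partition key, operator symbol, literal value.
  case class Comparison(key: String, symbol: String, value: Value)

  // Mirrors Shim_v0_13: integral literals are rendered bare, strings are
  // double-quoted, and varchar keys (plus anything unsupported) are skipped
  // so they fall back to client-side pruning.
  def toMetastoreFilter(preds: Seq[Comparison], varcharKeys: Set[String]): String =
    preds.flatMap {
      case Comparison(k, sym, IntVal(i)) => Some(k + sym + i)
      case Comparison(k, sym, StrVal(s)) if !varcharKeys.contains(k) =>
        Some(k + sym + "\"" + s + "\"")
      case _ => None
    }.mkString(" and ")

  def main(args: Array[String]): Unit = {
    val filter = toMetastoreFilter(
      Seq(Comparison("hr", ">", IntVal(10)), Comparison("ds", "=", StrVal("2015-06-23"))),
      varcharKeys = Set.empty)
    println(filter) // hr>10 and ds="2015-06-23"
  }
}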
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index d33da8242c..ba7eb15a1c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -44,7 +44,7 @@ private[hive] case class HiveTableScan(
     requestedAttributes: Seq[Attribute],
     relation: MetastoreRelation,
-    partitionPruningPred: Option[Expression])(
+    partitionPruningPred: Seq[Expression])(
     @transient val context: HiveContext)
   extends LeafNode {
 
@@ -56,7 +56,7 @@ case class HiveTableScan(
 
   // Bind all partition key attribute references in the partition pruning predicate for later
   // evaluation.
-  private[this] val boundPruningPred = partitionPruningPred.map { pred =>
+  private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
     require(
       pred.dataType == BooleanType,
       s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -133,7 +133,8 @@ case class HiveTableScan(
   protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
     hadoopReader.makeRDDForTable(relation.hiveQlTable)
   } else {
-    hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
+    hadoopReader.makeRDDForPartitionedTable(
+      prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
   }
 
   override def output: Seq[Attribute] = attributes
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index d52e162acb..3eb127e23d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.hive.client
 import java.io.File
 
 import org.apache.spark.{Logging, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.Utils
 
 /**
@@ -151,6 +153,12 @@ class VersionsSuite extends SparkFunSuite with Logging {
       client.getAllPartitions(client.getTable("default", "src_part"))
    }
 
+    test(s"$version: getPartitionsByFilter") {
+      client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
+        AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
+        Literal(1))))
+    }
+
     test(s"$version: loadPartition") {
       client.loadPartition(
         emptyDir,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index de6a41ce5b..e83a7dc77e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
       case p @ HiveTableScan(columns, relation, _) =>
         val columnNames = columns.map(_.name)
         val partValues = if (relation.table.isPartitioned) {
-          p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
+          p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
         } else {
           Seq.empty
         }
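Finally, note why HiveTableScan still conjoins and evaluates the whole predicate list via partitionPruningPred.reduceLeftOption(And): the metastore filter is best-effort (unsupported predicates are dropped from it), so the partitions it returns may be a superset of the true matches and must be re-checked on the driver. A toy version of that re-check, with a plain function standing in for a bound Catalyst Expression:

object PruneSketch {
  // A predicate over a partition's key -> value map.
  type Pred = Map[String, String] => Boolean

  def and(a: Pred, b: Pred): Pred = pv => a(pv) && b(pv)

  // Mirrors HiveTableScan: conjoin whatever predicates exist; with none, keep
  // every partition (reduceLeftOption returns None on an empty list).
  def prunePartitions(parts: Seq[Map[String, String]], preds: Seq[Pred]): Seq[Map[String, String]] =
    preds.reduceLeftOption(and) match {
      case Some(p) => parts.filter(p)
      case None    => parts
    }

  def main(args: Array[String]): Unit = {
    val parts = Seq(Map("ds" -> "1"), Map("ds" -> "2"))
    println(prunePartitions(parts, Seq(pv => pv("ds") == "2"))) // List(Map(ds -> 2))
  }
}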