-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala                2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala                97
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala  25
3 files changed, 96 insertions(+), 28 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 36b3b956da..604914e547 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -116,7 +116,7 @@ case class Aggregate(
*/
@transient
private[this] lazy val resultMap =
- (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap
+ (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap
/**
* Substituted version of aggregateExpressions expressions which are used to compute final
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
index 96faebc5a8..f141139ef4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -18,15 +18,18 @@
package org.apache.spark.sql.hive.execution
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.MetaStoreUtils
import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive}
import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc}
-import org.apache.hadoop.hive.serde2.Serializer
+import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
+import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Serializer}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred._
@@ -37,6 +40,7 @@ import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive._
import org.apache.spark.{TaskContext, SparkException}
+import org.apache.spark.util.MutablePair
/* Implicits */
import scala.collection.JavaConversions._
@@ -94,7 +98,7 @@ case class HiveTableScan(
(_: Any, partitionKeys: Array[String]) => {
val value = partitionKeys(ordinal)
val dataType = relation.partitionKeys(ordinal).dataType
- castFromString(value, dataType)
+ unwrapHiveData(castFromString(value, dataType))
}
} else {
val ref = objectInspector.getAllStructFieldRefs
@@ -102,16 +106,55 @@ case class HiveTableScan(
.getOrElse(sys.error(s"Can't find attribute $a"))
(row: Any, _: Array[String]) => {
val data = objectInspector.getStructFieldData(row, ref)
- unwrapData(data, ref.getFieldObjectInspector)
+ unwrapHiveData(unwrapData(data, ref.getFieldObjectInspector))
}
}
}
}
+ private def unwrapHiveData(value: Any) = value match {
+ case maybeNull: String if maybeNull.toLowerCase == "null" => null
+ case varchar: HiveVarchar => varchar.getValue
+ case decimal: HiveDecimal => BigDecimal(decimal.bigDecimalValue)
+ case other => other
+ }
+
private def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
}
+ private def addColumnMetadataToConf(hiveConf: HiveConf) {
+ // Specifies IDs and internal names of columns to be scanned.
+ val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer)
+ val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",")
+
+ if (attributes.size == relation.output.size) {
+ ColumnProjectionUtils.setFullyReadColumns(hiveConf)
+ } else {
+ ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
+ }
+
+ ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name))
+
+ // Specifies types and object inspectors of columns to be scanned.
+ val structOI = ObjectInspectorUtils
+ .getStandardObjectInspector(
+ relation.tableDesc.getDeserializer.getObjectInspector,
+ ObjectInspectorCopyOption.JAVA)
+ .asInstanceOf[StructObjectInspector]
+
+ val columnTypeNames = structOI
+ .getAllStructFieldRefs
+ .map(_.getFieldObjectInspector)
+ .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName)
+ .mkString(",")
+
+ hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames)
+ hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames)
+ }
+
+ addColumnMetadataToConf(sc.hiveconf)
+
@transient
def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
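
The addColumnMetadataToConf hunk above publishes the scanned columns' ids, internal names, and types into the job's HiveConf, so columnar SerDes (RCFile and friends) deserialize only the fields the query actually references. A minimal standalone sketch of the same configuration, assuming the Hive 0.12-era APIs used in this patch and a hypothetical three-column table of which only `key` (int) and `value` (string) are read:

    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.hadoop.hive.serde.serdeConstants
    import org.apache.hadoop.hive.serde2.ColumnProjectionUtils

    import scala.collection.JavaConversions._

    val hiveConf = new HiveConf()

    // Only columns 0 (`key`) and 1 (`value`) of the table are needed by the query.
    val neededColumnIDs = Seq(0, 1).map(i => i: Integer)

    // Columnar SerDes consult these ids/names and skip deserializing everything else;
    // ColumnProjectionUtils.setFullyReadColumns(hiveConf) is used instead when every
    // column is referenced, as in the branch above.
    ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
    ColumnProjectionUtils.appendReadColumnNames(hiveConf, Seq("key", "value"))

    // Internal `_colN` names and type names the deserializer should assume.
    hiveConf.set(serdeConstants.LIST_COLUMNS,
      neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(","))
    hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, "int,string")
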
@@ -143,20 +186,42 @@ case class HiveTableScan(
}
def execute() = {
- inputRdd.map { row =>
- val values = row match {
- case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) =>
- attributeFunctions.map(_(deserializedRow, partitionKeys))
- case deserializedRow: AnyRef =>
- attributeFunctions.map(_(deserializedRow, Array.empty))
+ inputRdd.mapPartitions { iterator =>
+ if (iterator.isEmpty) {
+ Iterator.empty
+ } else {
+ val mutableRow = new GenericMutableRow(attributes.length)
+ val mutablePair = new MutablePair[Any, Array[String]]()
+ val buffered = iterator.buffered
+
+ // NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern
+ // matching are avoided intentionally.
+ val rowsAndPartitionKeys = buffered.head match {
+ // With partition keys
+ case _: Array[Any] =>
+ buffered.map { case array: Array[Any] =>
+ val deserializedRow = array(0)
+ val partitionKeys = array(1).asInstanceOf[Array[String]]
+ mutablePair.update(deserializedRow, partitionKeys)
+ }
+
+ // Without partition keys
+ case _ =>
+ val emptyPartitionKeys = Array.empty[String]
+ buffered.map { deserializedRow =>
+ mutablePair.update(deserializedRow, emptyPartitionKeys)
+ }
+ }
+
+ rowsAndPartitionKeys.map { pair =>
+ var i = 0
+ while (i < attributes.length) {
+ mutableRow(i) = attributeFunctions(i)(pair._1, pair._2)
+ i += 1
+ }
+ mutableRow: Row
+ }
}
- buildRow(values.map {
- case n: String if n.toLowerCase == "null" => null
- case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue
- case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal =>
- BigDecimal(decimal.bigDecimalValue)
- case other => other
- })
}
}
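
The rewritten execute() above trades per-row allocation for reuse: one GenericMutableRow and one MutablePair are created per partition, the input shape (with or without partition keys) is decided once by peeking at the head of a buffered iterator, and fields are filled in with a plain while loop. A stripped-down sketch of the same row-reuse pattern on an ordinary RDD, assuming the Spark 1.0-era catalyst Row API and a hypothetical fieldExtractors array standing in for attributeFunctions:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}

    def toRows(rawRdd: RDD[Array[Any]], fieldExtractors: Array[Any => Any]): RDD[Row] =
      rawRdd.mapPartitions { iterator =>
        // One mutable row per partition: map() keeps handing out the same instance,
        // which is safe as long as downstream operators copy any row they retain.
        val mutableRow = new GenericMutableRow(fieldExtractors.length)
        iterator.map { record =>
          var i = 0
          while (i < fieldExtractors.length) {
            mutableRow(i) = fieldExtractors(i)(record(i))
            i += 1
          }
          mutableRow: Row
        }
      }

The buffered-head inspection in the real operator exists only because partitioned and non-partitioned tables feed differently shaped records into the same code path; deciding the shape once keeps pattern matching off the per-row critical path.
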
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index edff38b901..1b5a132f96 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.hive.execution
import java.io._
+import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
+
import org.apache.spark.sql.Logging
-import org.apache.spark.sql.catalyst.plans.logical.{ExplainCommand, NativeCommand}
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.Sort
-import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen}
import org.apache.spark.sql.hive.test.TestHive
/**
@@ -128,17 +129,19 @@ abstract class HiveComparisonTest
protected def prepareAnswer(
hiveQuery: TestHive.type#HiveQLQueryExecution,
answer: Seq[String]): Seq[String] = {
+
+ def isSorted(plan: LogicalPlan): Boolean = plan match {
+ case _: Join | _: Aggregate | _: BaseRelation | _: Generate | _: Sample | _: Distinct => false
+ case PhysicalOperation(_, _, Sort(_, _)) => true
+ case _ => plan.children.iterator.map(isSorted).exists(_ == true)
+ }
+
val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
case _: ExplainCommand => answer
- case _ =>
- // TODO: Really we only care about the final total ordering here...
- val isOrdered = hiveQuery.executedPlan.collect {
- case s @ Sort(_, global, _) if global => s
- }.nonEmpty
- // If the query results aren't sorted, then sort them to ensure deterministic answers.
- if (!isOrdered) answer.sorted else answer
+ case plan if isSorted(plan) => answer
+ case _ => answer.sorted
}
orderedAnswer.map(cleanPaths)
}
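
The new isSorted check only trusts the row order when a Sort sits at the root of the logical plan, possibly beneath projections and filters (hence the PhysicalOperation extractor); Join, Aggregate, Distinct and the other listed operators stop the search because they do not preserve their children's ordering. When the order is not guaranteed, answers are sorted before comparison. A small hypothetical illustration of that normalization:

    // Hypothetical answers for `SELECT key FROM src` (no ORDER BY): the rows are the
    // same but may arrive in different orders from Hive and from Spark SQL.
    val hiveAnswer  = Seq("238", "86", "311")
    val sparkAnswer = Seq("311", "238", "86")

    // prepareAnswer falls back to `answer.sorted`, so both sides are compared in a
    // canonical lexicographic order; a query with a top-level ORDER BY is compared as-is.
    assert(hiveAnswer.sorted == sparkAnswer.sorted)
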
@@ -161,7 +164,7 @@ abstract class HiveComparisonTest
"minFileSize"
)
protected def nonDeterministicLine(line: String) =
- nonDeterministicLineIndicators.map(line contains _).reduceLeft(_||_)
+ nonDeterministicLineIndicators.exists(line contains _)
/**
* Removes non-deterministic paths from `str` so cached answers will compare correctly.