diff options
Diffstat (limited to 'sql/hive')
6 files changed, 55 insertions, 39 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fc053c56c0..c36b5878cb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -115,23 +117,31 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { case p: LogicalPlan if !p.childrenResolved => p case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => - val childOutputDataTypes = child.output.map(_.dataType) - // Only check attributes, not partitionKeys since they are always strings. - // TODO: Fully support inserting into partitioned tables. - val tableOutputDataTypes = table.attributes.map(_.dataType) - - if (childOutputDataTypes == tableOutputDataTypes) { - p - } else { - // Only do the casting when child output data types differ from table output data types. - val castedChildOutput = child.output.zip(table.output).map { - case (input, output) if input.dataType != output.dataType => - Alias(Cast(input, output.dataType), input.name)() - case (input, _) => input - } - - p.copy(child = logical.Project(castedChildOutput, child)) + castChildOutput(p, table, child) + + case p @ logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan( + _, HiveTableScan(_, table, _))), _, child, _) => + castChildOutput(p, table, child) + } + + def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { + val childOutputDataTypes = child.output.map(_.dataType) + // Only check attributes, not partitionKeys since they are always strings. + // TODO: Fully support inserting into partitioned tables. + val tableOutputDataTypes = table.attributes.map(_.dataType) + + if (childOutputDataTypes == tableOutputDataTypes) { + p + } else { + // Only do the casting when child output data types differ from table output data types. + val castedChildOutput = child.output.zip(table.output).map { + case (input, output) if input.dataType != output.dataType => + Alias(Cast(input, output.dataType), input.name)() + case (input, _) => input } + + p.copy(child = logical.Project(castedChildOutput, child)) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 3ca1d93c11..ac817b21a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. @@ -42,6 +43,9 @@ trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.InsertIntoTable(SparkLogicalPlan(InMemoryColumnarTableScan( + _, HiveTableScan(_, table, _))), partition, child, overwrite) => + InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 2fea970295..465e5f146f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -160,12 +160,6 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { TestTable("src1", "CREATE TABLE src1 (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), - TestTable("dest1", - "CREATE TABLE IF NOT EXISTS dest1 (key INT, value STRING)".cmd), - TestTable("dest2", - "CREATE TABLE IF NOT EXISTS dest2 (key INT, value STRING)".cmd), - TestTable("dest3", - "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd), TestTable("srcpart", () => { runSqlHive( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") @@ -257,6 +251,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { private val loadedTables = new collection.mutable.HashSet[String] + var cacheTables: Boolean = false def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. @@ -265,6 +260,9 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) + + if (cacheTables) + cacheTable(name) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index f9b437d435..55a4363af6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -130,8 +130,7 @@ trait HiveFunctionFactory { } } -abstract class HiveUdf - extends Expression with Logging with HiveFunctionFactory { +abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory { self: Product => type UDFType @@ -146,7 +145,7 @@ abstract class HiveUdf lazy val functionInfo = getFunctionInfo(name) lazy val function = createFunction[UDFType](name) - override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})" + override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})" } case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { @@ -202,10 +201,11 @@ case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUd } } -case class HiveGenericUdf( - name: String, - children: Seq[Expression]) extends HiveUdf with HiveInspectors { +case class HiveGenericUdf(name: String, children: Seq[Expression]) + extends HiveUdf with HiveInspectors { + import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ + type UDFType = GenericUDF @transient @@ -357,7 +357,7 @@ case class HiveGenericUdaf( override def toString = s"$nodeName#$name(${children.mkString(",")})" - def newInstance = new HiveUdafFunction(name, children, this) + def newInstance() = new HiveUdafFunction(name, children, this) } /** @@ -435,7 +435,7 @@ case class HiveGenericUdtf( } } - override def toString() = s"$nodeName#$name(${children.mkString(",")})" + override def toString = s"$nodeName#$name(${children.mkString(",")})" } case class HiveUdafFunction( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 3cc4562a88..6c91f40d0f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -218,10 +218,7 @@ abstract class HiveComparisonTest val quotes = "\"\"\"" queryList.zipWithIndex.map { case (query, i) => - s""" - |val q$i = $quotes$query$quotes.q - |q$i.stringResult() - """.stripMargin + s"""val q$i = hql($quotes$query$quotes); q$i.collect()""" }.mkString("\n== Console version of this test ==\n", "\n", "\n") } @@ -287,7 +284,6 @@ abstract class HiveComparisonTest |Error: ${e.getMessage} |${stackTraceToString(e)} |$queryString - |$consoleTestCase """.stripMargin stringToFile( new File(hiveFailedDirectory, testCaseName), @@ -304,7 +300,7 @@ abstract class HiveComparisonTest val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHive.HiveQLQueryExecution(queryString) try { (query, prepareAnswer(query, query.stringResult())) } catch { - case e: Exception => + case e: Throwable => val errorMessage = s""" |Failed to execute query using catalyst: @@ -313,8 +309,6 @@ abstract class HiveComparisonTest |$query |== HIVE - ${hive.size} row(s) == |${hive.mkString("\n")} - | - |$consoleTestCase """.stripMargin stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase) fail(errorMessage) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f76e16bc1a..c3cfa3d25a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -17,16 +17,26 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.hive.TestHive /** * Runs the test cases that are included in the hive distribution. */ -class HiveCompatibilitySuite extends HiveQueryFileTest { +class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath lazy val hiveQueryDir = TestHive.getHiveFile("ql/src/test/queries/clientpositive") def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + override def beforeAll() { + TestHive.cacheTables = true + } + + override def afterAll() { + TestHive.cacheTables = false + } + /** A list of tests deemed out of scope currently and thus completely disregarded. */ override def blackList = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. |