diff options
author | Cheng Hao <hao.cheng@intel.com> | 2015-05-14 00:14:59 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2015-05-14 00:14:59 +0800 |
commit | 0da254fb2903c01e059fa7d0dc81df5740312b35 (patch) | |
tree | c99831581e4987626b950f0df96ce36a26e0ee53 /sql | |
parent | aa6ba3f2166edcc8bcda3abc70482fa8605e83b7 (diff) | |
download | spark-0da254fb2903c01e059fa7d0dc81df5740312b35.tar.gz spark-0da254fb2903c01e059fa7d0dc81df5740312b35.tar.bz2 spark-0da254fb2903c01e059fa7d0dc81df5740312b35.zip |
[SPARK-6734] [SQL] Add UDTF.close support in Generate
Some third-party UDTF extensions generate additional rows in the "GenericUDTF.close()" method, which is supported / documented by Hive.
https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
However, Spark SQL ignores the "GenericUDTF.close()", and it causes bug while porting job from Hive to Spark SQL.
Author: Cheng Hao <hao.cheng@intel.com>
Closes #5383 from chenghao-intel/udtf_close and squashes the following commits:
98b4e4b [Cheng Hao] Support UDTF.close
Diffstat (limited to 'sql')
7 files changed, 74 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9a6cb048af..747a47bdde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -56,6 +56,12 @@ abstract class Generator extends Expression { /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] + + /** + * Notifies that there are no more rows to process, clean up code, and additional + * rows can be made here. + */ + def terminate(): TraversableOnce[Row] = Nil } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 08d9079335..dd02c1f457 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -22,6 +22,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ /** + * For lazy computing, be sure the generator.terminate() called in the very last + * TODO reusing the CompletionIterator? + */ +private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row]) + extends Iterator[Row] { + + lazy val results = func().toIterator + override def hasNext: Boolean = results.hasNext + override def next(): Row = results.next() +} + +/** * :: DeveloperApi :: * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional @@ -47,27 +59,33 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[Row] = { + // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) - // Used to produce rows with no matches when outer = true. - val outerProjection = - newProjection(child.output ++ nullValues, child.output) - - val joinProjection = newProjection(output, output) + val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow - iter.flatMap {row => + iter.flatMap { row => + // we should always set the left (child output) + joinedRow.withLeft(row) val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { - outerProjection(row) :: Nil + joinedRow.withRight(generatorNullRow) :: Nil } else { - outputRows.map(or => joinProjection(joinedRow(row, or))) + outputRows.map(or => joinedRow.withRight(or)) } + } ++ LazyIterator(() => boundGenerator.terminate()).map { row => + // we leave the left side as the last element of its child output + // keep it the same as Hive does + joinedRow.withRight(row) } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) + child.execute().mapPartitions { iter => + iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate()) + } } } } + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fd0b6f0585..bc6b3a2d58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -483,7 +483,11 @@ private[hive] case class HiveGenericUdtf( extends Generator with HiveInspectors { @transient - protected lazy val function: GenericUDTF = funcWrapper.createFunction() + protected lazy val function: GenericUDTF = { + val fun: GenericUDTF = funcWrapper.createFunction() + fun.setCollector(collector) + fun + } @transient protected lazy val inputInspectors = children.map(toInspector) @@ -494,6 +498,9 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) + @transient + protected lazy val collector = new UDTFCollector + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } @@ -502,8 +509,7 @@ private[hive] case class HiveGenericUdtf( outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) - val collector = new UDTFCollector - function.setCollector(collector) + function.process(wrap(inputProjection(input), inputInspectors, udtInput)) collector.collectRows() } @@ -525,6 +531,12 @@ private[hive] case class HiveGenericUdtf( } } + override def terminate(): TraversableOnce[Row] = { + outputInspector // Make sure initialized. + function.close() + collector.collectRows() + } + override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } diff --git a/sql/hive/src/test/resources/TestUDTF.jar b/sql/hive/src/test/resources/TestUDTF.jar Binary files differnew file mode 100644 index 0000000000..514f2d5d26 --- /dev/null +++ b/sql/hive/src/test/resources/TestUDTF.jar diff --git a/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 b/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 new file mode 100644 index 0000000000..946e72fc87 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 @@ -0,0 +1,2 @@ +97 500 +97 500 diff --git a/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 b/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 new file mode 100644 index 0000000000..a5c8806279 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 @@ -0,0 +1,2 @@ +3 +3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2c9c08a9f3..089a57e25c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, StructObjectInspector, ObjectInspector} import org.scalatest.BeforeAndAfter import scala.util.Try @@ -51,14 +54,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) + sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) + sql("DROP TEMPORARY FUNCTION udtf_count2") } + createQueryTest("Test UDTF.close in Lateral Views", + """ + |SELECT key, cc + |FROM src LATERAL VIEW udtf_count2(value) dd AS cc + """.stripMargin, false) // false mean we have to keep the temp function in registry + + createQueryTest("Test UDTF.close in SELECT", + "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false) + test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") |