authorJosh Rosen <joshrosen@databricks.com>2015-07-28 16:04:48 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-07-28 16:04:48 -0700
commit59b92add7cc9cca1eaf0c558edb7c4add66c284f (patch)
treef10a95c6c554c4f674c9940e3baddf47deef32b1 /sql
parent21825529eae66293ec5d8638911303fa54944dd5 (diff)
[SPARK-9393] [SQL] Fix several error-handling bugs in ScriptTransform operator
SparkSQL's ScriptTransform operator has several serious bugs which make debugging fairly difficult:

- If exceptions are thrown in the writing thread then the child process will not be killed, leading to a deadlock because the reader thread will block while waiting for input that will never arrive.
- TaskContext is not propagated to the writer thread, which may cause errors in upstream pipelined operators.
- Exceptions which occur in the writer thread are not propagated to the main reader thread, which may cause upstream errors to be silently ignored instead of killing the job. This can lead to silently incorrect query results.
- The writer thread is not a daemon thread, but it should be.

In addition, the code in this file is extremely messy:

- Lots of fields are nullable but the nullability isn't clearly explained.
- Many confusing variable names: for instance, there are variables named `ite` and `iterator` that are defined in the same scope.
- Some code was misindented.
- The `*serdeClass` variables are actually expected to be single-quoted strings, which is really confusing: I feel that this parsing / extraction should be performed in the analyzer, not in the operator itself.
- There were no unit tests for the operator itself, only end-to-end tests.

This pull request addresses these issues, borrowing some error-handling techniques from PySpark's PythonRDD.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7710 from JoshRosen/script-transform and squashes the following commits:

16c44e2 [Josh Rosen] Update some comments
983f200 [Josh Rosen] Use unescapeSQLString instead of stripQuotes
6a06a8c [Josh Rosen] Clean up handling of quotes in serde class name
494cde0 [Josh Rosen] Propagate TaskContext to writer thread
323bb2b [Josh Rosen] Fix error-swallowing bug
b31258d [Josh Rosen] Rename iterator variables to disambiguate.
88278de [Josh Rosen] Split ScriptTransformation writer thread into own class.
8b162b6 [Josh Rosen] Add failing test which demonstrates exception masking issue
4ee36a2 [Josh Rosen] Kill script transform subprocess when error occurs in input writer.
bd4c948 [Josh Rosen] Skip launching of external command for empty partitions.
b43e4ec [Josh Rosen] Clean up nullability in ScriptTransformation
fa18d26 [Josh Rosen] Add basic unit test for script transform with 'cat' command.
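For context, the core of the fix (visible in the ScriptTransformation.scala diff below) is to move the feed loop into a dedicated daemon thread that records failures and kills the child process, and to make the reader thread rethrow those failures instead of treating end-of-output as success. The following is only a minimal sketch of that pattern in plain Scala, with hypothetical names (SimpleWriterThread, checkWriterException); it is not the patch itself:

import java.io.OutputStream
import scala.util.control.NonFatal

object ScriptTransformSketch {

  // Illustrative sketch only: a daemon writer thread that feeds rows to a
  // child process, records any failure in a volatile field, and destroys the
  // process so the reader thread is not left blocked forever.
  class SimpleWriterThread(
      rows: Iterator[String],
      out: OutputStream,
      proc: Process) extends Thread("Feed-Thread") {

    setDaemon(true)  // do not keep the JVM alive if the task dies

    @volatile private var _exception: Throwable = null

    /** The exception thrown while writing to the child process, if any. */
    def exception: Option[Throwable] = Option(_exception)

    override def run(): Unit = {
      try {
        rows.foreach(line => out.write((line + "\n").getBytes("utf-8")))
        out.close()
      } catch {
        case NonFatal(e) =>
          _exception = e  // make the failure visible to the reader thread
          proc.destroy()  // unblock the reader instead of deadlocking
          throw e
      }
    }
  }

  // Reader side: before treating end-of-output as normal termination, surface
  // any failure recorded by the writer so upstream errors are not swallowed.
  def checkWriterException(writer: SimpleWriterThread): Unit = {
    writer.exception.foreach(e => throw e)
  }
}

In the actual patch this same idea appears as ScriptTransformationWriterThread and the writerThread.exception checks inside the output iterator's hasNext.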
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala                  |  27
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                              |  10
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala      | 280
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala | 123
4 files changed, 317 insertions, 123 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 6a8f394545..f46855edfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -33,11 +33,13 @@ import scala.util.control.NonFatal
*/
class SparkPlanTest extends SparkFunSuite {
+ protected def sqlContext: SQLContext = TestSQLContext
+
/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
- TestSQLContext.implicits.localSeqToDataFrameHolder(data)
+ sqlContext.implicits.localSeqToDataFrameHolder(data)
}
/**
@@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
+ SparkPlanTest.checkAnswer(
+ input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -147,13 +150,14 @@ object SparkPlanTest {
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
- sortAnswers: Boolean): Option[String] = {
+ sortAnswers: Boolean,
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
- executePlan(expectedOutputPlan)
+ executePlan(expectedOutputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -168,7 +172,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
- executePlan(outputPlan)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -207,12 +211,13 @@ object SparkPlanTest {
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
- sortAnswers: Boolean): Option[String] = {
+ sortAnswers: Boolean,
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
- executePlan(outputPlan)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -275,10 +280,10 @@ object SparkPlanTest {
}
}
- private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+ private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+ val resolvedPlan = sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2f79b0aad0..e6df64d264 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}
def matchSerDe(clause: Seq[ASTNode])
- : (Seq[(String, String)], String, Seq[(String, String)]) = clause match {
+ : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match {
case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
val rowFormat = propsClause.map {
case Token(name, Token(value, Nil) :: Nil) => (name, value)
}
- (rowFormat, "", Nil)
+ (rowFormat, None, Nil)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
- (Nil, serdeClass, Nil)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
Token("TOK_TABLEPROPERTIES",
@@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) =>
(name, value)
}
- (Nil, serdeClass, serdeProps)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
- case Nil => (Nil, "", Nil)
+ case Nil => (Nil, None, Nil)
}
val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 205e622195..741c705e2a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -17,15 +17,18 @@
package org.apache.spark.sql.hive.execution
-import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
+import java.io._
import java.util.Properties
+import javax.annotation.Nullable
import scala.collection.JavaConversions._
+import scala.util.control.NonFatal
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.spark.{TaskContext, Logging}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -56,21 +59,53 @@ case class ScriptTransformation(
override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
protected override def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitions { iter =>
+ def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val cmd = List("/bin/bash", "-c", script)
val builder = new ProcessBuilder(cmd)
- // We need to start threads connected to the process pipeline:
- // 1) The error msg generated by the script process would be hidden.
- // 2) If the error msg is too big to chock up the buffer, the input logic would be hung
+
val proc = builder.start()
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream
- val reader = new BufferedReader(new InputStreamReader(inputStream))
- val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output)
+ // In order to avoid deadlocks, we need to consume the error output of the child process.
+ // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+ // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+ // that motivates this.
+ val stderrBuffer = new CircularBuffer(2048)
+ new RedirectThread(
+ errorStream,
+ stderrBuffer,
+ "Thread-ScriptTransformation-STDERR-Consumer").start()
+
+ val outputProjection = new InterpretedProjection(input, child.output)
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))
+
+ // This new thread will consume the ScriptTransformation's input rows and write them to the
+ // external process. That process's output will be read by this current thread.
+ val writerThread = new ScriptTransformationWriterThread(
+ inputIterator,
+ outputProjection,
+ inputSerde,
+ inputSoi,
+ ioschema,
+ outputStream,
+ proc,
+ stderrBuffer,
+ TaskContext.get()
+ )
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (outputSerde, outputSoi) = {
+ ioschema.initOutputSerDe(output).getOrElse((null, null))
+ }
- val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
+ val reader = new BufferedReader(new InputStreamReader(inputStream))
+ val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
var cacheRow: InternalRow = null
var curLine: String = null
var eof: Boolean = false
@@ -79,12 +114,26 @@ case class ScriptTransformation(
if (outputSerde == null) {
if (curLine == null) {
curLine = reader.readLine()
- curLine != null
+ if (curLine == null) {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ } else {
+ true
+ }
} else {
true
}
} else {
- !eof
+ if (eof) {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ } else {
+ true
+ }
}
}
@@ -110,11 +159,11 @@ case class ScriptTransformation(
}
i += 1
})
- return mutableRow
+ mutableRow
} catch {
case e: EOFException =>
eof = true
- return null
+ null
}
}
@@ -146,49 +195,83 @@ case class ScriptTransformation(
}
}
- val (inputSerde, inputSoi) = ioschema.initInputSerDe(input)
- val dataOutputStream = new DataOutputStream(outputStream)
- val outputProjection = new InterpretedProjection(input, child.output)
+ writerThread.start()
- // TODO make the 2048 configurable?
- val stderrBuffer = new CircularBuffer(2048)
- // Consume the error stream from the pipeline, otherwise it will be blocked if
- // the pipeline is full.
- new RedirectThread(errorStream, // input stream from the pipeline
- stderrBuffer, // output to a circular buffer
- "Thread-ScriptTransformation-STDERR-Consumer").start()
+ outputIterator
+ }
- // Put the write(output to the pipeline) into a single thread
- // and keep the collector as remain in the main thread.
- // otherwise it will causes deadlock if the data size greater than
- // the pipeline / buffer capacity.
- new Thread(new Runnable() {
- override def run(): Unit = {
- Utils.tryWithSafeFinally {
- iter
- .map(outputProjection)
- .foreach { row =>
- if (inputSerde == null) {
- val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
- ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
-
- outputStream.write(data)
- } else {
- val writable = inputSerde.serialize(
- row.asInstanceOf[GenericInternalRow].values, inputSoi)
- prepareWritable(writable).write(dataOutputStream)
- }
- }
- outputStream.close()
- } {
- if (proc.waitFor() != 0) {
- logError(stderrBuffer.toString) // log the stderr circular buffer
- }
- }
- }
- }, "Thread-ScriptTransformation-Feed").start()
+ child.execute().mapPartitions { iter =>
+ if (iter.hasNext) {
+ processIterator(iter)
+ } else {
+ // If the input iterator has no rows then do not launch the external script.
+ Iterator.empty
+ }
+ }
+ }
+}
- iterator
+private class ScriptTransformationWriterThread(
+ iter: Iterator[InternalRow],
+ outputProjection: Projection,
+ @Nullable inputSerde: AbstractSerDe,
+ @Nullable inputSoi: ObjectInspector,
+ ioschema: HiveScriptIOSchema,
+ outputStream: OutputStream,
+ proc: Process,
+ stderrBuffer: CircularBuffer,
+ taskContext: TaskContext
+ ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
+
+ setDaemon(true)
+
+ @volatile private var _exception: Throwable = null
+
+ /** Contains the exception thrown while writing the parent iterator to the external process. */
+ def exception: Option[Throwable] = Option(_exception)
+
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ TaskContext.setTaskContext(taskContext)
+
+ val dataOutputStream = new DataOutputStream(outputStream)
+
+ // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
+ // let's use a variable to record whether the `finally` block was hit due to an exception
+ var threwException: Boolean = true
+ try {
+ iter.map(outputProjection).foreach { row =>
+ if (inputSerde == null) {
+ val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
+ ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
+ outputStream.write(data)
+ } else {
+ val writable = inputSerde.serialize(
+ row.asInstanceOf[GenericInternalRow].values, inputSoi)
+ prepareWritable(writable).write(dataOutputStream)
+ }
+ }
+ outputStream.close()
+ threwException = false
+ } catch {
+ case NonFatal(e) =>
+ // An error occurred while writing input, so kill the child process. According to the
+ // Javadoc this call will not throw an exception:
+ _exception = e
+ proc.destroy()
+ throw e
+ } finally {
+ try {
+ if (proc.waitFor() != 0) {
+ logError(stderrBuffer.toString) // log the stderr circular buffer
+ }
+ } catch {
+ case NonFatal(exceptionFromFinallyBlock) =>
+ if (!threwException) {
+ throw exceptionFromFinallyBlock
+ } else {
+ log.error("Exception in finally block", exceptionFromFinallyBlock)
+ }
+ }
}
}
}
@@ -200,33 +283,43 @@ private[hive]
case class HiveScriptIOSchema (
inputRowFormat: Seq[(String, String)],
outputRowFormat: Seq[(String, String)],
- inputSerdeClass: String,
- outputSerdeClass: String,
+ inputSerdeClass: Option[String],
+ outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
- val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"),
- ("TOK_TABLEROWFORMATLINES", "\n"))
+ private val defaultFormat = Map(
+ ("TOK_TABLEROWFORMATFIELD", "\t"),
+ ("TOK_TABLEROWFORMATLINES", "\n")
+ )
val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
- def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = {
- val (columns, columnTypes) = parseAttrs(input)
- val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps)
- (serde, initInputSoi(serde, columns, columnTypes))
+ def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
+ inputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(input)
+ val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
+ val fieldObjectInspectors = columnTypes.map(toInspector)
+ val objectInspector = ObjectInspectorFactory
+ .getStandardStructObjectInspector(columns, fieldObjectInspectors)
+ .asInstanceOf[ObjectInspector]
+ (serde, objectInspector)
+ }
}
- def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = {
- val (columns, columnTypes) = parseAttrs(output)
- val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps)
- (serde, initOutputputSoi(serde))
+ def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
+ outputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(output)
+ val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
+ val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
+ (serde, structObjectInspector)
+ }
}
- def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-
+ private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
val columns = attrs.map {
case aref: AttributeReference => aref.name
case e: NamedExpression => e.name
@@ -242,52 +335,25 @@ case class HiveScriptIOSchema (
(columns, columnTypes)
}
- def initSerDe(serdeClassName: String, columns: Seq[String],
- columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = {
+ private def initSerDe(
+ serdeClassName: String,
+ columns: Seq[String],
+ columnTypes: Seq[DataType],
+ serdeProps: Seq[(String, String)]): AbstractSerDe = {
- val serde: AbstractSerDe = if (serdeClassName != "") {
- val trimed_class = serdeClassName.split("'")(1)
- Utils.classForName(trimed_class)
- .newInstance.asInstanceOf[AbstractSerDe]
- } else {
- null
- }
+ val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
- if (serde != null) {
- val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
+ val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
- var propsMap = serdeProps.map(kv => {
- (kv._1.split("'")(1), kv._2.split("'")(1))
- }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
- propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
+ var propsMap = serdeProps.map(kv => {
+ (kv._1.split("'")(1), kv._2.split("'")(1))
+ }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
+ propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
- val properties = new Properties()
- properties.putAll(propsMap)
- serde.initialize(null, properties)
- }
+ val properties = new Properties()
+ properties.putAll(propsMap)
+ serde.initialize(null, properties)
serde
}
-
- def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType])
- : ObjectInspector = {
-
- if (inputSerde != null) {
- val fieldObjectInspectors = columnTypes.map(toInspector(_))
- ObjectInspectorFactory
- .getStandardStructObjectInspector(columns, fieldObjectInspectors)
- .asInstanceOf[ObjectInspector]
- } else {
- null
- }
- }
-
- def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = {
- if (outputSerde != null) {
- outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector]
- } else {
- null
- }
- }
}
-
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
new file mode 100644
index 0000000000..0875232aed
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.types.StringType
+
+class ScriptTransformationSuite extends SparkPlanTest {
+
+ override def sqlContext: SQLContext = TestHive
+
+ private val noSerdeIOSchema = HiveScriptIOSchema(
+ inputRowFormat = Seq.empty,
+ outputRowFormat = Seq.empty,
+ inputSerdeClass = None,
+ outputSerdeClass = None,
+ inputSerdeProps = Seq.empty,
+ outputSerdeProps = Seq.empty,
+ schemaLess = false
+ )
+
+ private val serdeIOSchema = noSerdeIOSchema.copy(
+ inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
+ outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
+ )
+
+ test("cat without SerDe") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = child,
+ ioschema = noSerdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+
+ test("cat with LazySimpleSerDe") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = child,
+ ioschema = serdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+
+ test("script transformation should not swallow errors from upstream operators (no serde)") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ val e = intercept[TestFailedException] {
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = ExceptionInjectingOperator(child),
+ ioschema = noSerdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+ assert(e.getMessage().contains("intentional exception"))
+ }
+
+ test("script transformation should not swallow errors from upstream operators (with serde)") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ val e = intercept[TestFailedException] {
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = ExceptionInjectingOperator(child),
+ ioschema = serdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+ assert(e.getMessage().contains("intentional exception"))
+ }
+}
+
+private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().map { x =>
+ assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
+ Thread.sleep(1000) // This sleep gives the external process time to start.
+ throw new IllegalArgumentException("intentional exception")
+ }
+ }
+ override def output: Seq[Attribute] = child.output
+}