-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala                   27
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                               10
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala      280
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala  123
4 files changed, 317 insertions(+), 123 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 6a8f394545..f46855edfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -33,11 +33,13 @@ import scala.util.control.NonFatal
*/
class SparkPlanTest extends SparkFunSuite {
+ protected def sqlContext: SQLContext = TestSQLContext
+
/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
- TestSQLContext.implicits.localSeqToDataFrameHolder(data)
+ sqlContext.implicits.localSeqToDataFrameHolder(data)
}
/**
@@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
+ SparkPlanTest.checkAnswer(
+ input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -147,13 +150,14 @@ object SparkPlanTest {
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
- sortAnswers: Boolean): Option[String] = {
+ sortAnswers: Boolean,
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
- executePlan(expectedOutputPlan)
+ executePlan(expectedOutputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -168,7 +172,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
- executePlan(outputPlan)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -207,12 +211,13 @@ object SparkPlanTest {
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
- sortAnswers: Boolean): Option[String] = {
+ sortAnswers: Boolean,
+ sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
- executePlan(outputPlan)
+ executePlan(outputPlan, sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -275,10 +280,10 @@ object SparkPlanTest {
}
}
- private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+ private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+ val resolvedPlan = sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
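Note on the SparkPlanTest changes above: the base class now exposes an overridable `sqlContext`, and every `checkAnswer`/`executePlan` call threads that context through instead of hard-coding `TestSQLContext`. A minimal sketch of how a suite could take advantage of this, assuming the overloads shown in the diff; the suite and column names here are illustrative, not part of the patch:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.hive.test.TestHive

// Hypothetical suite that runs SparkPlanTest assertions against TestHive
// rather than the default TestSQLContext.
class MyHivePlanSuite extends SparkPlanTest {
  override def sqlContext: SQLContext = TestHive

  test("identity plan returns its input") {
    // DataFrame creation and plan execution both go through the overridden context.
    val df = sqlContext.createDataFrame(Seq(Tuple1("a"), Tuple1("b"))).toDF("col")
    checkAnswer(df, (plan: SparkPlan) => plan, df.collect())
  }
}

The ScriptTransformationSuite added later in this commit uses exactly this override to run Hive-specific plans through the shared test harness.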
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2f79b0aad0..e6df64d264 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}
def matchSerDe(clause: Seq[ASTNode])
- : (Seq[(String, String)], String, Seq[(String, String)]) = clause match {
+ : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match {
case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
val rowFormat = propsClause.map {
case Token(name, Token(value, Nil) :: Nil) => (name, value)
}
- (rowFormat, "", Nil)
+ (rowFormat, None, Nil)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
- (Nil, serdeClass, Nil)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
Token("TOK_TABLEPROPERTIES",
@@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) =>
(name, value)
}
- (Nil, serdeClass, serdeProps)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
- case Nil => (Nil, "", Nil)
+ case Nil => (Nil, None, Nil)
}
val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
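The HiveQl change above replaces the empty-string sentinel for "no SerDe" with `Option[String]` and unescapes the quoted class name from the AST token. A standalone sketch (plain Scala, no Hive classes; the quote-stripping below merely stands in for `BaseSemanticAnalyzer.unescapeSQLString`) of why the Option is preferable to the sentinel:

// Illustrative only: contrasts the old empty-string sentinel with Option[String].
object SerdeClassSketch {
  // Old style: "" meant "no SerDe"; every caller had to remember the sentinel.
  def hasSerdeOld(serdeClass: String): Boolean = serdeClass != ""

  // New style: None makes the absence explicit in the type.
  def hasSerdeNew(serdeClass: Option[String]): Boolean = serdeClass.isDefined

  def main(args: Array[String]): Unit = {
    // A raw parser token is single-quoted, e.g. 'org.example.MySerDe'.
    val rawToken = "'org.example.MySerDe'"
    val unescaped = rawToken.stripPrefix("'").stripSuffix("'")
    println(hasSerdeNew(Some(unescaped))) // true
    println(hasSerdeNew(None))            // false
  }
}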
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 205e622195..741c705e2a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -17,15 +17,18 @@
package org.apache.spark.sql.hive.execution
-import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
+import java.io._
import java.util.Properties
+import javax.annotation.Nullable
import scala.collection.JavaConversions._
+import scala.util.control.NonFatal
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.spark.{TaskContext, Logging}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -56,21 +59,53 @@ case class ScriptTransformation(
override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
protected override def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitions { iter =>
+ def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val cmd = List("/bin/bash", "-c", script)
val builder = new ProcessBuilder(cmd)
- // We need to start threads connected to the process pipeline:
- // 1) The error msg generated by the script process would be hidden.
- // 2) If the error msg is too big to chock up the buffer, the input logic would be hung
+
val proc = builder.start()
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream
- val reader = new BufferedReader(new InputStreamReader(inputStream))
- val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output)
+ // In order to avoid deadlocks, we need to consume the error output of the child process.
+ // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+ // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+ // that motivates this.
+ val stderrBuffer = new CircularBuffer(2048)
+ new RedirectThread(
+ errorStream,
+ stderrBuffer,
+ "Thread-ScriptTransformation-STDERR-Consumer").start()
+
+ val outputProjection = new InterpretedProjection(input, child.output)
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))
+
+ // This new thread will consume the ScriptTransformation's input rows and write them to the
+ // external process. That process's output will be read by this current thread.
+ val writerThread = new ScriptTransformationWriterThread(
+ inputIterator,
+ outputProjection,
+ inputSerde,
+ inputSoi,
+ ioschema,
+ outputStream,
+ proc,
+ stderrBuffer,
+ TaskContext.get()
+ )
+
+ // This nullability is a performance optimization in order to avoid an Option.foreach() call
+ // inside of a loop
+ @Nullable val (outputSerde, outputSoi) = {
+ ioschema.initOutputSerDe(output).getOrElse((null, null))
+ }
- val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
+ val reader = new BufferedReader(new InputStreamReader(inputStream))
+ val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
var cacheRow: InternalRow = null
var curLine: String = null
var eof: Boolean = false
@@ -79,12 +114,26 @@ case class ScriptTransformation(
if (outputSerde == null) {
if (curLine == null) {
curLine = reader.readLine()
- curLine != null
+ if (curLine == null) {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ } else {
+ true
+ }
} else {
true
}
} else {
- !eof
+ if (eof) {
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ } else {
+ true
+ }
}
}
@@ -110,11 +159,11 @@ case class ScriptTransformation(
}
i += 1
})
- return mutableRow
+ mutableRow
} catch {
case e: EOFException =>
eof = true
- return null
+ null
}
}
@@ -146,49 +195,83 @@ case class ScriptTransformation(
}
}
- val (inputSerde, inputSoi) = ioschema.initInputSerDe(input)
- val dataOutputStream = new DataOutputStream(outputStream)
- val outputProjection = new InterpretedProjection(input, child.output)
+ writerThread.start()
- // TODO make the 2048 configurable?
- val stderrBuffer = new CircularBuffer(2048)
- // Consume the error stream from the pipeline, otherwise it will be blocked if
- // the pipeline is full.
- new RedirectThread(errorStream, // input stream from the pipeline
- stderrBuffer, // output to a circular buffer
- "Thread-ScriptTransformation-STDERR-Consumer").start()
+ outputIterator
+ }
- // Put the write(output to the pipeline) into a single thread
- // and keep the collector as remain in the main thread.
- // otherwise it will causes deadlock if the data size greater than
- // the pipeline / buffer capacity.
- new Thread(new Runnable() {
- override def run(): Unit = {
- Utils.tryWithSafeFinally {
- iter
- .map(outputProjection)
- .foreach { row =>
- if (inputSerde == null) {
- val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
- ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
-
- outputStream.write(data)
- } else {
- val writable = inputSerde.serialize(
- row.asInstanceOf[GenericInternalRow].values, inputSoi)
- prepareWritable(writable).write(dataOutputStream)
- }
- }
- outputStream.close()
- } {
- if (proc.waitFor() != 0) {
- logError(stderrBuffer.toString) // log the stderr circular buffer
- }
- }
- }
- }, "Thread-ScriptTransformation-Feed").start()
+ child.execute().mapPartitions { iter =>
+ if (iter.hasNext) {
+ processIterator(iter)
+ } else {
+ // If the input iterator has no rows then do not launch the external script.
+ Iterator.empty
+ }
+ }
+ }
+}
- iterator
+private class ScriptTransformationWriterThread(
+ iter: Iterator[InternalRow],
+ outputProjection: Projection,
+ @Nullable inputSerde: AbstractSerDe,
+ @Nullable inputSoi: ObjectInspector,
+ ioschema: HiveScriptIOSchema,
+ outputStream: OutputStream,
+ proc: Process,
+ stderrBuffer: CircularBuffer,
+ taskContext: TaskContext
+ ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
+
+ setDaemon(true)
+
+ @volatile private var _exception: Throwable = null
+
+ /** Contains the exception thrown while writing the parent iterator to the external process. */
+ def exception: Option[Throwable] = Option(_exception)
+
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ TaskContext.setTaskContext(taskContext)
+
+ val dataOutputStream = new DataOutputStream(outputStream)
+
+ // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
+ // let's use a variable to record whether the `finally` block was hit due to an exception
+ var threwException: Boolean = true
+ try {
+ iter.map(outputProjection).foreach { row =>
+ if (inputSerde == null) {
+ val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
+ ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
+ outputStream.write(data)
+ } else {
+ val writable = inputSerde.serialize(
+ row.asInstanceOf[GenericInternalRow].values, inputSoi)
+ prepareWritable(writable).write(dataOutputStream)
+ }
+ }
+ outputStream.close()
+ threwException = false
+ } catch {
+ case NonFatal(e) =>
+ // An error occurred while writing input, so kill the child process. According to the
+ // Javadoc this call will not throw an exception:
+ _exception = e
+ proc.destroy()
+ throw e
+ } finally {
+ try {
+ if (proc.waitFor() != 0) {
+ logError(stderrBuffer.toString) // log the stderr circular buffer
+ }
+ } catch {
+ case NonFatal(exceptionFromFinallyBlock) =>
+ if (!threwException) {
+ throw exceptionFromFinallyBlock
+ } else {
+ log.error("Exception in finally block", exceptionFromFinallyBlock)
+ }
+ }
}
}
}
@@ -200,33 +283,43 @@ private[hive]
case class HiveScriptIOSchema (
inputRowFormat: Seq[(String, String)],
outputRowFormat: Seq[(String, String)],
- inputSerdeClass: String,
- outputSerdeClass: String,
+ inputSerdeClass: Option[String],
+ outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
- val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"),
- ("TOK_TABLEROWFORMATLINES", "\n"))
+ private val defaultFormat = Map(
+ ("TOK_TABLEROWFORMATFIELD", "\t"),
+ ("TOK_TABLEROWFORMATLINES", "\n")
+ )
val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
- def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = {
- val (columns, columnTypes) = parseAttrs(input)
- val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps)
- (serde, initInputSoi(serde, columns, columnTypes))
+ def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
+ inputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(input)
+ val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
+ val fieldObjectInspectors = columnTypes.map(toInspector)
+ val objectInspector = ObjectInspectorFactory
+ .getStandardStructObjectInspector(columns, fieldObjectInspectors)
+ .asInstanceOf[ObjectInspector]
+ (serde, objectInspector)
+ }
}
- def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = {
- val (columns, columnTypes) = parseAttrs(output)
- val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps)
- (serde, initOutputputSoi(serde))
+ def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
+ outputSerdeClass.map { serdeClass =>
+ val (columns, columnTypes) = parseAttrs(output)
+ val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
+ val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
+ (serde, structObjectInspector)
+ }
}
- def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-
+ private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
val columns = attrs.map {
case aref: AttributeReference => aref.name
case e: NamedExpression => e.name
@@ -242,52 +335,25 @@ case class HiveScriptIOSchema (
(columns, columnTypes)
}
- def initSerDe(serdeClassName: String, columns: Seq[String],
- columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = {
+ private def initSerDe(
+ serdeClassName: String,
+ columns: Seq[String],
+ columnTypes: Seq[DataType],
+ serdeProps: Seq[(String, String)]): AbstractSerDe = {
- val serde: AbstractSerDe = if (serdeClassName != "") {
- val trimed_class = serdeClassName.split("'")(1)
- Utils.classForName(trimed_class)
- .newInstance.asInstanceOf[AbstractSerDe]
- } else {
- null
- }
+ val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
- if (serde != null) {
- val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
+ val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
- var propsMap = serdeProps.map(kv => {
- (kv._1.split("'")(1), kv._2.split("'")(1))
- }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
- propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
+ var propsMap = serdeProps.map(kv => {
+ (kv._1.split("'")(1), kv._2.split("'")(1))
+ }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
+ propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
- val properties = new Properties()
- properties.putAll(propsMap)
- serde.initialize(null, properties)
- }
+ val properties = new Properties()
+ properties.putAll(propsMap)
+ serde.initialize(null, properties)
serde
}
-
- def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType])
- : ObjectInspector = {
-
- if (inputSerde != null) {
- val fieldObjectInspectors = columnTypes.map(toInspector(_))
- ObjectInspectorFactory
- .getStandardStructObjectInspector(columns, fieldObjectInspectors)
- .asInstanceOf[ObjectInspector]
- } else {
- null
- }
- }
-
- def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = {
- if (outputSerde != null) {
- outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector]
- } else {
- null
- }
- }
}
-
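The comments in the rewritten ScriptTransformation above describe the pipe layout that avoids the SPARK-7862 deadlock: a dedicated writer thread feeds the child process's stdin, a redirect thread drains stderr into a circular buffer, and the calling thread consumes stdout. A self-contained sketch of that layout on the plain JVM (no Spark or Hive classes; `/bin/cat` stands in for the user script) under the assumption that a Unix shell is available:

import java.io.{BufferedReader, InputStreamReader, PrintWriter}

object PipeSketch {
  def main(args: Array[String]): Unit = {
    val proc = new ProcessBuilder("/bin/cat").start()

    // Writer thread: pushes input lines into the child's stdin, then closes it.
    val writer = new Thread("sketch-stdin-feeder") {
      override def run(): Unit = {
        val out = new PrintWriter(proc.getOutputStream)
        (1 to 10000).foreach(i => out.println(s"row-$i"))
        out.close()
      }
    }
    writer.setDaemon(true)

    // Stderr drainer: keeps the child's stderr pipe from filling up and blocking it.
    val stderrDrainer = new Thread("sketch-stderr-drainer") {
      override def run(): Unit = {
        val err = new BufferedReader(new InputStreamReader(proc.getErrorStream))
        Iterator.continually(err.readLine()).takeWhile(_ != null).foreach(_ => ())
      }
    }
    stderrDrainer.setDaemon(true)

    writer.start()
    stderrDrainer.start()

    // The current thread consumes stdout, mirroring the output iterator in the patch.
    val stdout = new BufferedReader(new InputStreamReader(proc.getInputStream))
    val lines = Iterator.continually(stdout.readLine()).takeWhile(_ != null).toList
    println(s"read ${lines.size} lines, exit code ${proc.waitFor()}")
  }
}

If either helper thread were missing, a sufficiently large input or stderr volume would fill an OS pipe buffer and wedge both sides of the exchange, which is the hang the patch's comments refer to.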
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
new file mode 100644
index 0000000000..0875232aed
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.types.StringType
+
+class ScriptTransformationSuite extends SparkPlanTest {
+
+ override def sqlContext: SQLContext = TestHive
+
+ private val noSerdeIOSchema = HiveScriptIOSchema(
+ inputRowFormat = Seq.empty,
+ outputRowFormat = Seq.empty,
+ inputSerdeClass = None,
+ outputSerdeClass = None,
+ inputSerdeProps = Seq.empty,
+ outputSerdeProps = Seq.empty,
+ schemaLess = false
+ )
+
+ private val serdeIOSchema = noSerdeIOSchema.copy(
+ inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
+ outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
+ )
+
+ test("cat without SerDe") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = child,
+ ioschema = noSerdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+
+ test("cat with LazySimpleSerDe") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = child,
+ ioschema = serdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+
+ test("script transformation should not swallow errors from upstream operators (no serde)") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ val e = intercept[TestFailedException] {
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = ExceptionInjectingOperator(child),
+ ioschema = noSerdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+ assert(e.getMessage().contains("intentional exception"))
+ }
+
+ test("script transformation should not swallow errors from upstream operators (with serde)") {
+ val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+ val e = intercept[TestFailedException] {
+ checkAnswer(
+ rowsDf,
+ (child: SparkPlan) => new ScriptTransformation(
+ input = Seq(rowsDf.col("a").expr),
+ script = "cat",
+ output = Seq(AttributeReference("a", StringType)()),
+ child = ExceptionInjectingOperator(child),
+ ioschema = serdeIOSchema
+ )(TestHive),
+ rowsDf.collect())
+ }
+ assert(e.getMessage().contains("intentional exception"))
+ }
+}
+
+private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode {
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().map { x =>
+ assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
+ Thread.sleep(1000) // This sleep gives the external process time to start.
+ throw new IllegalArgumentException("intentional exception")
+ }
+ }
+ override def output: Seq[Attribute] = child.output
+}
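The two "should not swallow errors" tests above exercise the hand-off that ScriptTransformationWriterThread performs: a failure while feeding the child process is recorded in a @volatile field and rethrown by the reader when the output stream ends. A minimal standalone sketch of that pattern (no Spark classes; names are illustrative):

object ExceptionHandoffSketch {
  class Producer extends Thread("sketch-producer") {
    @volatile private var _exception: Throwable = null
    def exception: Option[Throwable] = Option(_exception)

    override def run(): Unit = {
      try {
        throw new IllegalArgumentException("intentional exception")
      } catch {
        // Record the failure instead of letting it die with this thread.
        case e: Throwable => _exception = e
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val producer = new Producer
    producer.start()
    producer.join()
    // The consumer checks the recorded exception before declaring end-of-stream,
    // so an upstream failure surfaces instead of producing a silently empty result.
    producer.exception.foreach(e => println(s"propagated: ${e.getMessage}"))
  }
}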