author     Zhichao Li <zhichao.li@intel.com>    2015-09-22 19:41:57 -0700
committer  Yin Huai <yhuai@databricks.com>      2015-09-22 19:41:57 -0700
commit     84f81e035e1dab1b42c36563041df6ba16e7b287 (patch)
tree       36d06cf10253cc10a201bec6d2e26d7b44862e5e
parent     61d4c07f4becb42f054e588be56ed13239644410 (diff)
[SPARK-10310] [SQL] Fixes script transformation field/line delimiters
**Please attribute this PR to `Zhichao Li <zhichao.li@intel.com>`.**

This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding the field delimiter SerDe property to the default `LazySimpleSerDe`, and by enabling the default record reader/writer classes.

Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and `TextRecordWriter`, and don't support customizing the record reader/writer via the `RECORDREADER`/`RECORDWRITER` clauses. This should be addressed in separate PR(s).

Author: Cheng Lian <lian@databricks.com>

Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters.
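For reference, a minimal usage sketch of the scenario this fix repairs (mirroring the new test cases added below, not part of the commit itself): a `TRANSFORM` query run through `LazySimpleSerDe` with an explicit `field.delim`. Here `sqlContext` is assumed to be a Hive-enabled `SQLContext`, and `scripts/echo_fields.py` is a hypothetical stand-in for the delimiter-echoing test script added by this patch.

```scala
// Sketch only: register a small two-column table to feed the script.
sqlContext
  .range(3)
  .selectExpr("id AS a", "id AS b")
  .registerTempTable("src")

// Script transformation with LazySimpleSerDe on both sides and a custom
// field delimiter; before SPARK-10310 the delimiter properties and the
// default record reader/writer were not wired through correctly.
val transformed = sqlContext.sql(
  """FROM src
    |SELECT TRANSFORM(a, b)
    |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |WITH SERDEPROPERTIES('field.delim' = '|')
    |USING 'python scripts/echo_fields.py'
    |AS (c STRING, d STRING)
    |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |WITH SERDEPROPERTIES('field.delim' = '|')
  """.stripMargin)

transformed.show()
```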
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                              52
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala      75
-rwxr-xr-x  sql/hive/src/test/resources/data/scripts/test_transform.py                                   6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala              39
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala   2
5 files changed, 152 insertions, 22 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index d5cd7e98b5..256440a9a2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
@@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
AttributeReference("value", StringType)()), true)
}
- def matchSerDe(clause: Seq[ASTNode])
- : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match {
+ type SerDeInfo = (
+ Seq[(String, String)], // Input row format information
+ Option[String], // Optional input SerDe class
+ Seq[(String, String)], // Input SerDe properties
+ Boolean // Whether to use default record reader/writer
+ )
+
+ def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match {
case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
val rowFormat = propsClause.map {
case Token(name, Token(value, Nil) :: Nil) => (name, value)
}
- (rowFormat, None, Nil)
+ (rowFormat, None, Nil, false)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
- (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
+ (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false)
case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
Token("TOK_TABLEPROPERTIES",
@@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
(BaseSemanticAnalyzer.unescapeSQLString(name),
BaseSemanticAnalyzer.unescapeSQLString(value))
}
- (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
- case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
+ // SPARK-10310: Special cases LazySimpleSerDe
+ // TODO Fully supports user-defined record reader/writer classes
+ val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass)
+ val useDefaultRecordReaderWriter =
+ unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName
+ (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter)
+
+ case Nil =>
+ // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here
+ val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t")
+ (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true)
}
- val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
- val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause)
+ val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) =
+ matchSerDe(inputSerdeClause)
+
+ val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) =
+ matchSerDe(outputSerdeClause)
val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)
+ // TODO Adds support for user-defined record reader/writer classes
+ val recordReaderClass = if (useDefaultRecordReader) {
+ Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER))
+ } else {
+ None
+ }
+
+ val recordWriterClass = if (useDefaultRecordWriter) {
+ Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER))
+ } else {
+ None
+ }
+
val schema = HiveScriptIOSchema(
inRowFormat, outRowFormat,
inSerdeClass, outSerdeClass,
- inSerdeProps, outSerdeProps, schemaLess)
+ inSerdeProps, outSerdeProps,
+ recordReaderClass, recordWriterClass,
+ schemaLess)
Some(
logical.ScriptTransformation(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 32bddbaeae..b30117f0de 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -24,20 +24,22 @@ import javax.annotation.Nullable
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.io.Writable
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
import org.apache.spark.{Logging, TaskContext}
/**
@@ -58,6 +60,8 @@ case class ScriptTransformation(
override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+ private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
+
protected override def doExecute(): RDD[InternalRow] = {
def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val cmd = List("/bin/bash", "-c", script)
@@ -67,6 +71,7 @@ case class ScriptTransformation(
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream
+ val localHiveConf = serializedHiveConf.value
// In order to avoid deadlocks, we need to consume the error output of the child process.
// To avoid issues caused by large error output, we use a circular buffer to limit the amount
@@ -96,7 +101,8 @@ case class ScriptTransformation(
outputStream,
proc,
stderrBuffer,
- TaskContext.get()
+ TaskContext.get(),
+ localHiveConf
)
// This nullability is a performance optimization in order to avoid an Option.foreach() call
@@ -109,6 +115,10 @@ case class ScriptTransformation(
val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
var curLine: String = null
val scriptOutputStream = new DataInputStream(inputStream)
+
+ @Nullable val scriptOutputReader =
+ ioschema.recordReader(scriptOutputStream, localHiveConf).orNull
+
var scriptOutputWritable: Writable = null
val reusedWritableObject: Writable = if (null != outputSerde) {
outputSerde.getSerializedClass().newInstance
@@ -134,15 +144,25 @@ case class ScriptTransformation(
}
} else if (scriptOutputWritable == null) {
scriptOutputWritable = reusedWritableObject
- try {
- scriptOutputWritable.readFields(scriptOutputStream)
- true
- } catch {
- case _: EOFException =>
- if (writerThread.exception.isDefined) {
- throw writerThread.exception.get
- }
+
+ if (scriptOutputReader != null) {
+ if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
+ writerThread.exception.foreach(throw _)
false
+ } else {
+ true
+ }
+ } else {
+ try {
+ scriptOutputWritable.readFields(scriptOutputStream)
+ true
+ } catch {
+ case _: EOFException =>
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
+ }
}
} else {
true
@@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread(
outputStream: OutputStream,
proc: Process,
stderrBuffer: CircularBuffer,
- taskContext: TaskContext
+ taskContext: TaskContext,
+ conf: Configuration
) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
setDaemon(true)
@@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread(
TaskContext.setTaskContext(taskContext)
val dataOutputStream = new DataOutputStream(outputStream)
+ @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull
// We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
// let's use a variable to record whether the `finally` block was hit due to an exception
@@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread(
} else {
val writable = inputSerde.serialize(
row.asInstanceOf[GenericInternalRow].values, inputSoi)
- prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
+
+ if (scriptInputWriter != null) {
+ scriptInputWriter.write(writable)
+ } else {
+ prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
+ }
}
}
outputStream.close()
@@ -290,6 +317,8 @@ case class HiveScriptIOSchema (
outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
+ recordReaderClass: Option[String],
+ recordWriterClass: Option[String],
schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
private val defaultFormat = Map(
@@ -347,4 +376,24 @@ case class HiveScriptIOSchema (
serde
}
+
+ def recordReader(
+ inputStream: InputStream,
+ conf: Configuration): Option[RecordReader] = {
+ recordReaderClass.map { klass =>
+ val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
+ val props = new Properties()
+ props.putAll(outputSerdeProps.toMap.asJava)
+ instance.initialize(inputStream, conf, props)
+ instance
+ }
+ }
+
+ def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = {
+ recordWriterClass.map { klass =>
+ val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter]
+ instance.initialize(outputStream, conf)
+ instance
+ }
+ }
}
diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py
new file mode 100755
index 0000000000..ac6d11d8b9
--- /dev/null
+++ b/sql/hive/src/test/resources/data/scripts/test_transform.py
@@ -0,0 +1,6 @@
+import sys
+
+delim = sys.argv[1]
+
+for row in sys.stdin:
+ print(delim.join([w + '#' for w in row[:-1].split(delim)]))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index bb02473dd1..71823e32ad 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(df, Row("text inside layer 2") :: Nil)
}
+
+ test("SPARK-10310: " +
+ "script transformation using default input/output SerDe and record reader/writer") {
+ sqlContext
+ .range(5)
+ .selectExpr("id AS a", "id AS b")
+ .registerTempTable("test")
+
+ checkAnswer(
+ sql(
+ """FROM(
+ | FROM test SELECT TRANSFORM(a, b)
+ | USING 'python src/test/resources/data/scripts/test_transform.py "\t"'
+ | AS (c STRING, d STRING)
+ |) t
+ |SELECT c
+ """.stripMargin),
+ (0 until 5).map(i => Row(i + "#")))
+ }
+
+ test("SPARK-10310: script transformation using LazySimpleSerDe") {
+ sqlContext
+ .range(5)
+ .selectExpr("id AS a", "id AS b")
+ .registerTempTable("test")
+
+ val df = sql(
+ """FROM test
+ |SELECT TRANSFORM(a, b)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+ |WITH SERDEPROPERTIES('field.delim' = '|')
+ |USING 'python src/test/resources/data/scripts/test_transform.py "|"'
+ |AS (c STRING, d STRING)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+ |WITH SERDEPROPERTIES('field.delim' = '|')
+ """.stripMargin)
+
+ checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#")))
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index cb8d0fca8e..7cfdb886b5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
outputSerdeClass = None,
inputSerdeProps = Seq.empty,
outputSerdeProps = Seq.empty,
+ recordReaderClass = None,
+ recordWriterClass = None,
schemaLess = false
)