diff options
author | zhichao.li <zhichao.li@intel.com> | 2015-08-04 18:26:05 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-08-04 18:26:05 -0700 |
commit | 6f8f0e265a29e89bd5192a8d5217cba19f0875da (patch) | |
tree | d8c83ca4ac4620ca0e6344e97f870b8cbb15b505 /sql | |
parent | c9a4c36d052456c2dd1f7e0a871c6b764b5064d2 (diff) | |
download | spark-6f8f0e265a29e89bd5192a8d5217cba19f0875da.tar.gz spark-6f8f0e265a29e89bd5192a8d5217cba19f0875da.tar.bz2 spark-6f8f0e265a29e89bd5192a8d5217cba19f0875da.zip |
[SPARK-7119] [SQL] Give script a default serde with the user specific types
This addresses the issue that an incompatible-type exception would be thrown when running this:
`from (from src select transform(key, value) using 'cat' as (thing1 int, thing2 string)) t select thing1 + 2;`
15/04/24 00:58:55 ERROR CliDriver: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ClassCastException: org.apache.spark.sql.types.UTF8String cannot be cast to java.lang.Integer
at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106)
at scala.math.Numeric$IntIsIntegral$.plus(Numeric.scala:57)
at org.apache.spark.sql.catalyst.expressions.Add.eval(arithmetic.scala:127)
at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:118)
at org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection.apply(Projection.scala:68)
at org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection.apply(Projection.scala:52)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$class.foreach(Iterator.scala:727)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
at scala.collection.AbstractIterator.to(Iterator.scala:1157)
at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
at org.apache.spark.rdd.RDD$$anonfun$17.apply(RDD.scala:819)
at org.apache.spark.rdd.RDD$$anonfun$17.apply(RDD.scala:819)
at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1618)
at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1618)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:63)
at org.apache.spark.scheduler.Task.run(Task.scala:64)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:209)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1110)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:603)
at java.lang.Thread.run(Thread.java:722)
chenghao-intel marmbrus
Author: zhichao.li <zhichao.li@intel.com>
Closes #6638 from zhichao-li/transDataType2 and squashes the following commits:
a36cc7c [zhichao.li] style
b9252a8 [zhichao.li] delete cacheRow
f6968a4 [zhichao.li] give script a default serde
Diffstat (limited to 'sql')
3 files changed, 49 insertions, 60 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e2fdfc6163..f43e403ce9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -21,6 +21,7 @@ import java.sql.Date import java.util.Locale import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.ql.{ErrorMsg, Context} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} @@ -907,7 +908,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, None, Nil) + case Nil => (Nil, Option(hiveConf().getVar(ConfVars.HIVESCRIPTSERDE)), Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index fbb86406f4..97e4ea2081 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -27,11 +27,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.io.Writable import org.apache.spark.{TaskContext, Logging} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ 
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ @@ -106,9 +106,15 @@ case class ScriptTransformation( val reader = new BufferedReader(new InputStreamReader(inputStream)) val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { - var cacheRow: InternalRow = null var curLine: String = null - var eof: Boolean = false + val scriptOutputStream = new DataInputStream(inputStream) + var scriptOutputWritable: Writable = null + val reusedWritableObject: Writable = if (null != outputSerde) { + outputSerde.getSerializedClass().newInstance + } else { + null + } + val mutableRow = new SpecificMutableRow(output.map(_.dataType)) override def hasNext: Boolean = { if (outputSerde == null) { @@ -125,45 +131,20 @@ case class ScriptTransformation( } else { true } - } else { - if (eof) { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - false - } else { + } else if (scriptOutputWritable == null) { + scriptOutputWritable = reusedWritableObject + try { + scriptOutputWritable.readFields(scriptOutputStream) true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false } - } - } - - def deserialize(): InternalRow = { - if (cacheRow != null) return cacheRow - - val mutableRow = new SpecificMutableRow(output.map(_.dataType)) - try { - val dataInputStream = new DataInputStream(inputStream) - val writable = outputSerde.getSerializedClass().newInstance - writable.readFields(dataInputStream) - - val raw = outputSerde.deserialize(writable) - val dataList = outputSoi.getStructFieldsDataAsList(raw) - val fieldList = outputSoi.getAllStructFieldRefs() - - var i = 0 - dataList.foreach( element => { - if (element == null) { - mutableRow.setNullAt(i) - } else { - mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector) - } - i += 1 - }) - mutableRow - } catch { - case e: EOFException => - 
eof = true - null + } else { + true } } @@ -171,7 +152,6 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() @@ -185,12 +165,20 @@ case class ScriptTransformation( .map(CatalystTypeConverters.convertToCatalyst)) } } else { - val ret = deserialize() - if (!eof) { - cacheRow = null - cacheRow = deserialize() + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + val fieldList = outputSoi.getAllStructFieldRefs() + var i = 0 + while (i < dataList.size()) { + if (dataList(i) == null) { + mutableRow.setNullAt(i) + } else { + mutableRow(i) = unwrap(dataList(i), fieldList(i).getFieldObjectInspector) + } + i += 1 } - ret + mutableRow } } } @@ -320,18 +308,8 @@ case class HiveScriptIOSchema ( } private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - val columns = attrs.map { - case aref: AttributeReference => aref.name - case e: NamedExpression => e.name - case _ => null - } - - val columnTypes = attrs.map { - case aref: AttributeReference => aref.dataType - case e: NamedExpression => e.dataType - case _ => null - } - + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) (columns, columnTypes) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index fb41451803..ff9a3694d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -751,6 +751,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { .queryExecution.toRdd.count()) } + test("test script transform data type") { + val data = (1 to 5).map { i => (i, i) } + 
data.toDF("key", "value").registerTempTable("test") + checkAnswer( + sql("""FROM + |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t + |SELECT thing1 + 1 + """.stripMargin), (2 to 6).map(i => Row(i))) + } + test("window function: udaf with aggregate expressin") { val data = Seq( WindowData(1, "a", 5), |