author    zhichao.li <zhichao.li@intel.com>    2015-08-04 18:26:05 -0700
committer Michael Armbrust <michael@databricks.com>    2015-08-04 18:26:05 -0700
commit    6f8f0e265a29e89bd5192a8d5217cba19f0875da (patch)
tree      d8c83ca4ac4620ca0e6344e97f870b8cbb15b505 /sql
parent    c9a4c36d052456c2dd1f7e0a871c6b764b5064d2 (diff)
[SPARK-7119] [SQL] Give script a default serde with the user-specified types
This addresses the incompatible-type exception thrown when running:

    from (from src select transform(key, value) using 'cat' as (thing1 int, thing2 string)) t select thing1 + 2;

15/04/24 00:58:55 ERROR CliDriver: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ClassCastException: org.apache.spark.sql.types.UTF8String cannot be cast to java.lang.Integer
    at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106)
    at scala.math.Numeric$IntIsIntegral$.plus(Numeric.scala:57)
    at org.apache.spark.sql.catalyst.expressions.Add.eval(arithmetic.scala:127)
    at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:118)
    at org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection.apply(Projection.scala:68)
    at org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection.apply(Projection.scala:52)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
    at scala.collection.AbstractIterator.to(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
    at org.apache.spark.rdd.RDD$$anonfun$17.apply(RDD.scala:819)
    at org.apache.spark.rdd.RDD$$anonfun$17.apply(RDD.scala:819)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1618)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1618)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:63)
    at org.apache.spark.scheduler.Task.run(Task.scala:64)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:209)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1110)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:603)
    at java.lang.Thread.run(Thread.java:722)

cc chenghao-intel marmbrus

Author: zhichao.li <zhichao.li@intel.com>

Closes #6638 from zhichao-li/transDataType2 and squashes the following commits:

a36cc7c [zhichao.li] style
b9252a8 [zhichao.li] delete cacheRow
f6968a4 [zhichao.li] give script a default serde
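For context, a minimal reproduction sketch (not part of the patch): it assumes a HiveContext named hiveCtx and Hive's standard test table src(key int, value string).

    // hedged sketch: `hiveCtx` and `src` are assumptions for illustration
    val result = hiveCtx.sql(
      """FROM
        |(FROM src SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
        |SELECT thing1 + 2
      """.stripMargin)
    // before this patch: ClassCastException, because without a serde the
    // script's string output was handed back unconverted despite the declared
    // int type; after it: thing1 is a real int and the addition succeeds
    result.collect()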
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                         |  3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 96
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala        | 10
3 files changed, 49 insertions, 60 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index e2fdfc6163..f43e403ce9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -21,6 +21,7 @@ import java.sql.Date
import java.util.Locale
import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
@@ -907,7 +908,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
}
(Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
- case Nil => (Nil, None, Nil)
+ case Nil => (Nil, Option(hiveConf().getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
}
val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
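The effect of this HiveQl change: a TRANSFORM clause with no explicit SERDE now falls back to Hive's configured default instead of None. A minimal sketch of that lookup, assuming a plain HiveConf; the default value noted in the comment is Hive's stock setting for hive.script.serde, not something this patch introduces.

    import org.apache.hadoop.hive.conf.HiveConf
    import org.apache.hadoop.hive.conf.HiveConf.ConfVars

    val conf = new HiveConf()
    // reads hive.script.serde, which defaults to
    // org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
    val defaultSerde: Option[String] = Option(conf.getVar(ConfVars.HIVESCRIPTSERDE))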
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index fbb86406f4..97e4ea2081 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -27,11 +27,11 @@ import scala.util.control.NonFatal
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.io.Writable
import org.apache.spark.{TaskContext, Logging}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.execution._
@@ -106,9 +106,15 @@ case class ScriptTransformation(
val reader = new BufferedReader(new InputStreamReader(inputStream))
val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
- var cacheRow: InternalRow = null
var curLine: String = null
- var eof: Boolean = false
+ val scriptOutputStream = new DataInputStream(inputStream)
+ var scriptOutputWritable: Writable = null
+ val reusedWritableObject: Writable = if (null != outputSerde) {
+ outputSerde.getSerializedClass().newInstance
+ } else {
+ null
+ }
+ val mutableRow = new SpecificMutableRow(output.map(_.dataType))
override def hasNext: Boolean = {
if (outputSerde == null) {
@@ -125,45 +131,20 @@ case class ScriptTransformation(
} else {
true
}
- } else {
- if (eof) {
- if (writerThread.exception.isDefined) {
- throw writerThread.exception.get
- }
- false
- } else {
+ } else if (scriptOutputWritable == null) {
+ scriptOutputWritable = reusedWritableObject
+ try {
+ scriptOutputWritable.readFields(scriptOutputStream)
true
+ } catch {
+ case _: EOFException =>
+ if (writerThread.exception.isDefined) {
+ throw writerThread.exception.get
+ }
+ false
}
- }
- }
-
- def deserialize(): InternalRow = {
- if (cacheRow != null) return cacheRow
-
- val mutableRow = new SpecificMutableRow(output.map(_.dataType))
- try {
- val dataInputStream = new DataInputStream(inputStream)
- val writable = outputSerde.getSerializedClass().newInstance
- writable.readFields(dataInputStream)
-
- val raw = outputSerde.deserialize(writable)
- val dataList = outputSoi.getStructFieldsDataAsList(raw)
- val fieldList = outputSoi.getAllStructFieldRefs()
-
- var i = 0
- dataList.foreach( element => {
- if (element == null) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow(i) = unwrap(element, fieldList(i).getFieldObjectInspector)
- }
- i += 1
- })
- mutableRow
- } catch {
- case e: EOFException =>
- eof = true
- null
+ } else {
+ true
}
}
@@ -171,7 +152,6 @@ case class ScriptTransformation(
if (!hasNext) {
throw new NoSuchElementException
}
-
if (outputSerde == null) {
val prevLine = curLine
curLine = reader.readLine()
@@ -185,12 +165,20 @@ case class ScriptTransformation(
.map(CatalystTypeConverters.convertToCatalyst))
}
} else {
- val ret = deserialize()
- if (!eof) {
- cacheRow = null
- cacheRow = deserialize()
+ val raw = outputSerde.deserialize(scriptOutputWritable)
+ scriptOutputWritable = null
+ val dataList = outputSoi.getStructFieldsDataAsList(raw)
+ val fieldList = outputSoi.getAllStructFieldRefs()
+ var i = 0
+ while (i < dataList.size()) {
+ if (dataList(i) == null) {
+ mutableRow.setNullAt(i)
+ } else {
+ mutableRow(i) = unwrap(dataList(i), fieldList(i).getFieldObjectInspector)
+ }
+ i += 1
}
- ret
+ mutableRow
}
}
}
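The rewritten iterator drops the old cacheRow/deserialize() look-ahead: a single Writable (reusedWritableObject) and a single SpecificMutableRow are allocated once per partition and refilled for every row. A standalone sketch of the same read-and-reuse pattern, with hypothetical identifiers; Text stands in for whatever class the serde actually reports.

    import java.io.{DataInputStream, EOFException}
    import org.apache.hadoop.io.Text

    // One Writable is allocated up front and refilled from the stream by each
    // readFields call; next() must copy the value out before the next refill.
    def rows(in: DataInputStream): Iterator[String] = new Iterator[String] {
      private val reused = new Text()
      private var filled = false
      override def hasNext: Boolean = filled || {
        try { reused.readFields(in); filled = true; true }
        catch { case _: EOFException => false }
      }
      override def next(): String = {
        if (!hasNext) throw new NoSuchElementException
        filled = false
        reused.toString // materialize before the buffer is reused
      }
    }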
@@ -320,18 +308,8 @@ case class HiveScriptIOSchema (
}
private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
- val columns = attrs.map {
- case aref: AttributeReference => aref.name
- case e: NamedExpression => e.name
- case _ => null
- }
-
- val columnTypes = attrs.map {
- case aref: AttributeReference => aref.dataType
- case e: NamedExpression => e.dataType
- case _ => null
- }
-
+ val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}")
+ val columnTypes = attrs.map(_.dataType)
(columns, columnTypes)
}
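The parseAttrs change guarantees every output column gets a non-null name, which the default serde needs in order to build its schema; previously any expression that was not a NamedExpression produced null. A hypothetical illustration of the positional naming, with plain strings standing in for the real Expression objects:

    val attrs = Seq("key", "value") // stand-ins for the actual expressions
    val columns = attrs.zipWithIndex.map(e => s"${e._1}_${e._2}")
    // columns == Seq("key_0", "value_1"): names are unique and never null;
    // in the real code, attrs.map(_.dataType) supplies the matching types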
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index fb41451803..ff9a3694d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -751,6 +751,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
.queryExecution.toRdd.count())
}
+ test("test script transform data type") {
+ val data = (1 to 5).map { i => (i, i) }
+ data.toDF("key", "value").registerTempTable("test")
+ checkAnswer(
+ sql("""FROM
+ |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
+ |SELECT thing1 + 1
+ """.stripMargin), (2 to 6).map(i => Row(i)))
+ }
+
test("window function: udaf with aggregate expressin") {
val data = Seq(
WindowData(1, "a", 5),
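For reference, a hedged spark-shell sketch mirroring the new test; it assumes a HiveContext named sqlContext with its implicits imported, as a Hive-enabled spark-shell of this era provides.

    import sqlContext.implicits._

    val data = (1 to 5).map(i => (i, i))
    data.toDF("key", "value").registerTempTable("test")
    sqlContext.sql(
      """FROM
        |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
        |SELECT thing1 + 1
      """.stripMargin).show()
    // with the default serde in place, thing1 is a genuine int and the
    // query returns the rows 2 through 6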