Diffstat (limited to 'sql/core/src')
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala | 38
 1 file changed, 29 insertions(+), 9 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index cec97de2cd..9552f41115 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -50,10 +50,10 @@ private[sql] class Serializer2SerializationStream(
   extends SerializationStream with Logging {

   val rowOut = new DataOutputStream(out)
-  val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

-  def writeObject[T: ClassTag](t: T): SerializationStream = {
+  override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
     writeKey(kv._1)
     writeValue(kv._2)
@@ -61,6 +61,16 @@ private[sql] class Serializer2SerializationStream(
     this
   }

+  override def writeKey[T: ClassTag](t: T): SerializationStream = {
+    writeKeyFunc(t.asInstanceOf[Row])
+    this
+  }
+
+  override def writeValue[T: ClassTag](t: T): SerializationStream = {
+    writeValueFunc(t.asInstanceOf[Row])
+    this
+  }
+
   def flush(): Unit = {
     rowOut.flush()
   }
@@ -83,17 +93,27 @@ private[sql] class Serializer2DeserializationStream(
   val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
   val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
+  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)

-  def readObject[T: ClassTag](): T = {
-    readKey()
-    readValue()
+  override def readObject[T: ClassTag](): T = {
+    readKeyFunc()
+    readValueFunc()
     (key, value).asInstanceOf[T]
   }

-  def close(): Unit = {
+  override def readKey[T: ClassTag](): T = {
+    readKeyFunc()
+    key.asInstanceOf[T]
+  }
+
+  override def readValue[T: ClassTag](): T = {
+    readValueFunc()
+    value.asInstanceOf[T]
+  }
+
+  override def close(): Unit = {
     rowIn.close()
   }
 }
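
The writeKey/writeValue and readKey/readValue overrides added above let a caller serialize and deserialize the key and the value of a record independently, instead of always packing them through writeObject/readObject as a Product2. Below is a minimal, self-contained sketch of that pattern; WriteStream, ReadStream, and the (Int, String) record shape are hypothetical stand-ins for illustration, not the real SerializationStream/DeserializationStream API or Spark's Row-based serialization functions.

import java.io._

object KeyValueStreamSketch {

  // Write side: like the patched Serializer2SerializationStream, the key and
  // the value each have their own entry point onto the same DataOutputStream,
  // and writeObject is just the two calls composed.
  class WriteStream(out: OutputStream) {
    private val dataOut = new DataOutputStream(out)
    def writeKey(k: Int): this.type = { dataOut.writeInt(k); this }
    def writeValue(v: String): this.type = { dataOut.writeUTF(v); this }
    def writeObject(kv: (Int, String)): this.type = { writeKey(kv._1); writeValue(kv._2) }
    def flush(): Unit = dataOut.flush()
    def close(): Unit = dataOut.close()
  }

  // Read side: mirrors readKey/readValue, handing back just the key or just
  // the value of the next record from the same DataInputStream.
  class ReadStream(in: InputStream) {
    private val dataIn = new DataInputStream(in)
    def readKey(): Int = dataIn.readInt()
    def readValue(): String = dataIn.readUTF()
    def readObject(): (Int, String) = (readKey(), readValue())
    def close(): Unit = dataIn.close()
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val ws = new WriteStream(buf)
    ws.writeObject((1, "a"))        // whole record, as writeObject does
    ws.writeKey(2).writeValue("b")  // key and value written separately
    ws.flush()

    val rs = new ReadStream(new ByteArrayInputStream(buf.toByteArray))
    assert(rs.readObject() == (1, "a"))
    assert(rs.readKey() == 2 && rs.readValue() == "b")
    rs.close()
  }
}

One design point the sketch glosses over: the patched deserialization stream reuses a single mutable target per stream (the SpecificMutableRow fields above), so readKey/readValue refill and return the same object for every record rather than allocating a fresh one per call; callers must copy the row if they need to hold onto it.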