diff options
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala | 38 |
1 files changed, 29 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index cec97de2cd..9552f41115 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -50,10 +50,10 @@ private[sql] class Serializer2SerializationStream( extends SerializationStream with Logging { val rowOut = new DataOutputStream(out) - val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) - val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) - def writeObject[T: ClassTag](t: T): SerializationStream = { + override def writeObject[T: ClassTag](t: T): SerializationStream = { val kv = t.asInstanceOf[Product2[Row, Row]] writeKey(kv._1) writeValue(kv._2) @@ -61,6 +61,16 @@ private[sql] class Serializer2SerializationStream( this } + override def writeKey[T: ClassTag](t: T): SerializationStream = { + writeKeyFunc(t.asInstanceOf[Row]) + this + } + + override def writeValue[T: ClassTag](t: T): SerializationStream = { + writeValueFunc(t.asInstanceOf[Row]) + this + } + def flush(): Unit = { rowOut.flush() } @@ -83,17 +93,27 @@ private[sql] class Serializer2DeserializationStream( val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null - val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) - val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) - def readObject[T: ClassTag](): T = { - readKey() - readValue() + override def readObject[T: ClassTag](): T = { + readKeyFunc() + readValueFunc() (key, value).asInstanceOf[T] } - def close(): Unit = { + override def readKey[T: ClassTag](): T = { + readKeyFunc() + key.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + readValueFunc() + value.asInstanceOf[T] + } + + override def close(): Unit = { rowIn.close() } } |