author     Davies Liu <davies@databricks.com>      2015-07-09 14:43:38 -0700
committer  Davies Liu <davies.liu@gmail.com>       2015-07-09 14:43:38 -0700
commit     c9e2ef52bb54f35a904427389dc492d61f29b018 (patch)
tree       90887ae7055aa78751561119083bd09ac099e0f4 /sql
parent     3ccebf36c5abe04702d4cf223552a94034d980fb (diff)
[SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame
This PR fixes the long-standing serialization issues between Python RDDs and DataFrames by switching to a customized Pickler for InternalRow, which enables customized unpickling (type conversion, especially for UDTs). With this change, UDTs are supported in UDFs. cc mengxr
There is no generated `Row` anymore.
Author: Davies Liu <davies@databricks.com>
Closes #7301 from davies/sql_ser and squashes the following commits:
81bef71 [Davies Liu] address comments
e9217bd [Davies Liu] add regression tests
db34167 [Davies Liu] Refactor of serialization for Python DataFrame
Diffstat (limited to 'sql')
3 files changed, 118 insertions, 21 deletions
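
Before reading the diff: the "customized Pickler" the commit message refers to is pyrolite's hook for registering a per-class pickler that emits a GLOBAL opcode naming a Python callable, pushes the arguments, and then REDUCE, so that unpickling calls that callable on the Python side. Below is a minimal, illustrative sketch of that mechanism; the TaggedValue class and the Python callable my_module.make_typed_value are hypothetical stand-ins, not part of this patch (the patch instead registers picklers for StructType and the new GenericInternalRowWithSchema):

    import java.io.OutputStream

    import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}

    // Hypothetical payload carrying a schema description alongside a value.
    case class TaggedValue(schemaJson: String, value: Any)

    // A custom pickler writes GLOBAL ("module\nname\n") to name a Python-side
    // constructor, saves the arguments, and calls it with REDUCE on unpickling.
    class TaggedValuePickler extends IObjectPickler {

      def register(): Unit = {
        Pickler.registerCustomPickler(classOf[TaggedValue], this)
      }

      def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
        val tv = obj.asInstanceOf[TaggedValue]
        // Push the Python callable onto the unpickler's stack.
        out.write(Opcodes.GLOBAL)
        out.write("my_module\nmake_typed_value\n".getBytes("utf-8"))
        // Push the two constructor arguments, collect them into a tuple, and call.
        pickler.save(tv.schemaJson)
        pickler.save(tv.value)
        out.write(Opcodes.TUPLE2)
        out.write(Opcodes.REDUCE)
      }
    }
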
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 8b472a529e..094904bbf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -132,6 +132,18 @@ class GenericInternalRow(protected[sql] val values: Array[Any])
   override def copy(): InternalRow = this
 }
 
+/**
+ * This is used for serialization of Python DataFrame
+ */
+class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType)
+  extends GenericInternalRow(values) {
+
+  /** No-arg constructor for serialization. */
+  protected def this() = this(null, null)
+
+  override def fieldIndex(name: String): Int = schema.fieldIndex(name)
+}
+
 class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow {
   /** No-arg constructor for serialization. */
   protected def this() = this(null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index d9f987ae02..d7966651b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -30,7 +30,6 @@ import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
@@ -1550,8 +1549,8 @@ class DataFrame private[sql](
    */
   protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val structType = schema  // capture it for closure
-    val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD()
-    SerDeUtil.javaToPython(jrdd)
+    val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType))
+    EvaluatePython.javaToPython(rdd)
   }
 
   ////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 1c8130b07c..6d6e67dace 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark.sql.execution
 
+import java.io.OutputStream
 import java.util.{List => JList, Map => JMap}
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
 
-import net.razorvine.pickle.{Pickler, Unpickler}
+import net.razorvine.pickle._
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
+import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.{Accumulator, Logging => SparkLogging}
@@ -130,8 +130,13 @@ object EvaluatePython {
     case (null, _) => null
 
     case (row: InternalRow, struct: StructType) =>
-      val fields = struct.fields.map(field => field.dataType)
-      rowToArray(row, fields)
+      val values = new Array[Any](row.size)
+      var i = 0
+      while (i < row.size) {
+        values(i) = toJava(row(i), struct.fields(i).dataType)
+        i += 1
+      }
+      new GenericInternalRowWithSchema(values, struct)
 
     case (seq: Seq[Any], array: ArrayType) =>
       seq.map(x => toJava(x, array.elementType)).asJava
@@ -142,9 +147,6 @@ object EvaluatePython {
 
     case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
 
-    case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
-    case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
-
     case (d: Decimal, _) => d.toJavaBigDecimal
 
     case (s: UTF8String, StringType) => s.toString
@@ -153,14 +155,6 @@ object EvaluatePython {
   }
 
   /**
-   * Convert Row into Java Array (for pickled into Python)
-   */
-  def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = {
-    // TODO: this is slow!
-    row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
-  }
-
-  /**
    * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
    * unexpected. Because Python doesn't enforce the type.
    */
@@ -220,6 +214,96 @@ object EvaluatePython {
     // TODO(davies): we could improve this by try to cast the object to expected type
     case (c, _) => null
   }
+
+
+  private val module = "pyspark.sql.types"
+
+  /**
+   * Pickler for StructType
+   */
+  private class StructTypePickler extends IObjectPickler {
+
+    private val cls = classOf[StructType]
+
+    def register(): Unit = {
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      out.write(Opcodes.GLOBAL)
+      out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8"))
+      val schema = obj.asInstanceOf[StructType]
+      pickler.save(schema.json)
+      out.write(Opcodes.TUPLE1)
+      out.write(Opcodes.REDUCE)
+    }
+  }
+
+  /**
+   * Pickler for InternalRow
+   */
+  private class RowPickler extends IObjectPickler {
+
+    private val cls = classOf[GenericInternalRowWithSchema]
+
+    // register this to Pickler and Unpickler
+    def register(): Unit = {
+      Pickler.registerCustomPickler(this.getClass, this)
+      Pickler.registerCustomPickler(cls, this)
+    }
+
+    def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      if (obj == this) {
+        out.write(Opcodes.GLOBAL)
+        out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8"))
+      } else {
+        // it will be memorized by Pickler to save some bytes
+        pickler.save(this)
+        val row = obj.asInstanceOf[GenericInternalRowWithSchema]
+        // schema should always be same object for memoization
+        pickler.save(row.schema)
+        out.write(Opcodes.TUPLE1)
+        out.write(Opcodes.REDUCE)
+
+        out.write(Opcodes.MARK)
+        var i = 0
+        while (i < row.values.size) {
+          pickler.save(row.values(i))
+          i += 1
+        }
+        row.values.foreach(pickler.save)
+        out.write(Opcodes.TUPLE)
+        out.write(Opcodes.REDUCE)
+      }
+    }
+  }
+
+  private[this] var registered = false
+  /**
+   * This should be called before trying to serialize any above classes un cluster mode,
+   * this should be put in the closure
+   */
+  def registerPicklers(): Unit = {
+    synchronized {
+      if (!registered) {
+        SerDeUtil.initialize()
+        new StructTypePickler().register()
+        new RowPickler().register()
+        registered = true
+      }
+    }
+  }
+
+  /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+   * PySpark.
+   */
+  def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
+    rdd.mapPartitions { iter =>
+      registerPicklers()  // let it called in executor
+      new SerDeUtil.AutoBatchedPickler(iter)
+    }
+  }
 }
 
 /**
@@ -254,12 +338,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
     val childResults = child.execute().map(_.copy())
 
     val parent = childResults.mapPartitions { iter =>
+      EvaluatePython.registerPicklers()  // register pickler for Row
       val pickle = new Pickler
       val currentRow = newMutableProjection(udf.children, child.output)()
       val fields = udf.children.map(_.dataType)
-      iter.grouped(1000).map { inputRows =>
+      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      iter.grouped(100).map { inputRows =>
        val toBePickled = inputRows.map { row =>
-          EvaluatePython.rowToArray(currentRow(row), fields)
+          EvaluatePython.toJava(currentRow(row), schema)
         }.toArray
         pickle.dumps(toBePickled)
       }
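
A rough usage sketch of the path added above, assuming EvaluatePython and GenericInternalRowWithSchema remain reachable from user Scala code; the schema, field names, and values are illustrative only, and in Spark this is driven by DataFrame.javaToPython and BatchPythonEvaluation rather than called by hand:

    import net.razorvine.pickle.Pickler

    import org.apache.spark.sql.catalyst.expressions.GenericInternalRowWithSchema
    import org.apache.spark.sql.execution.EvaluatePython
    import org.apache.spark.sql.types._

    // Register the StructType and row picklers once per JVM; the patch calls
    // this inside mapPartitions so registration also happens on executors.
    EvaluatePython.registerPicklers()

    // Illustrative schema and row; EvaluatePython.toJava normally builds the
    // GenericInternalRowWithSchema from an InternalRow and this schema.
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = true)))
    val row = new GenericInternalRowWithSchema(Array("Alice", 30), schema)

    // The resulting bytes are meant to be unpickled on the Python side via the
    // pyspark.sql.types helpers named in the registered picklers.
    val pickled: Array[Byte] = new Pickler().dumps(row)

Because the RowPickler saves the schema object itself, pyrolite's memo table pickles the schema only once per batch; and BatchPythonEvaluation now pickles 100 rows at a time (down from 1000) while building the schema for the UDF inputs up front.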