author     Davies Liu <davies@databricks.com>    2015-07-09 14:43:38 -0700
committer  Davies Liu <davies.liu@gmail.com>     2015-07-09 14:43:38 -0700
commit     c9e2ef52bb54f35a904427389dc492d61f29b018 (patch)
tree       90887ae7055aa78751561119083bd09ac099e0f4 /sql
parent     3ccebf36c5abe04702d4cf223552a94034d980fb (diff)
[SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame
This PR fixes the long-standing serialization issue between Python RDDs and DataFrames. It switches to a customized Pickler for InternalRow, which enables customized unpickling (type conversion, especially for UDTs), so UDTs are now supported in UDFs. cc mengxr. There is no generated `Row` anymore.

Author: Davies Liu <davies@databricks.com>

Closes #7301 from davies/sql_ser and squashes the following commits:

81bef71 [Davies Liu] address comments
e9217bd [Davies Liu] add regression tests
db34167 [Davies Liu] Refactor of serialization for Python DataFrame
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala | 122
3 files changed, 118 insertions, 21 deletions
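
For readers unfamiliar with Pyrolite, the picklers added in pythonUDFs.scala below work by registering an IObjectPickler for a class and writing raw pickle opcodes (GLOBAL, TUPLE1/TUPLE, REDUCE) so that unpickling on the Python side calls a converter function with the saved arguments. The following standalone sketch illustrates that mechanism only; the Point and PointPickler names and the builtins.complex target are illustrative, not part of this patch, and it assumes the net.razorvine pyrolite library is on the classpath.

import java.io.OutputStream
import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}

// Illustrative type to ship to Python; not part of the patch.
class Point(val x: Double, val y: Double)

class PointPickler extends IObjectPickler {
  def register(): Unit = Pickler.registerCustomPickler(classOf[Point], this)

  def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
    val p = obj.asInstanceOf[Point]
    // GLOBAL pushes a Python callable (builtins.complex as a stand-in target);
    // REDUCE then calls it with the 2-tuple of saved arguments, so the Python
    // side reconstructs complex(x, y) during unpickling.
    out.write(Opcodes.GLOBAL)
    out.write("builtins\ncomplex\n".getBytes("utf-8"))
    pickler.save(p.x)
    pickler.save(p.y)
    out.write(Opcodes.TUPLE2)
    out.write(Opcodes.REDUCE)
  }
}

object PicklerSketch {
  def main(args: Array[String]): Unit = {
    new PointPickler().register()
    val bytes = new Pickler().dumps(new Point(1.0, 2.0))
    println(s"pickled payload: ${bytes.length} bytes")
  }
}

The patch uses the same opcode pattern, but the GLOBAL target is a PySpark converter function that rebuilds Row objects with the attached schema.
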
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 8b472a529e..094904bbf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -132,6 +132,18 @@ class GenericInternalRow(protected[sql] val values: Array[Any])
override def copy(): InternalRow = this
}
+/**
+ * This is used for serialization of Python DataFrame
+ */
+class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType)
+ extends GenericInternalRow(values) {
+
+ /** No-arg constructor for serialization. */
+ protected def this() = this(null, null)
+
+ override def fieldIndex(name: String): Int = schema.fieldIndex(name)
+}
+
class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow {
/** No-arg constructor for serialization. */
protected def this() = this(null)
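
The new GenericInternalRowWithSchema above simply carries a StructType next to the values so the Python-side converter can rebuild named rows; fieldIndex is resolved through that schema. A hypothetical usage sketch follows (this is internal Catalyst API and the field values are illustrative; it assumes the Spark catalyst classes are on the classpath):

import org.apache.spark.sql.catalyst.expressions.GenericInternalRowWithSchema
import org.apache.spark.sql.types._

object RowWithSchemaSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    // Values here are already Java-side objects, as produced by EvaluatePython.toJava.
    val row = new GenericInternalRowWithSchema(Array[Any](1, "spark"), schema)
    // Field lookup by name is delegated to the attached schema.
    assert(row.fieldIndex("name") == 1)
  }
}

The values are plain Java objects because EvaluatePython.toJava converts each field before wrapping it, as shown in the pythonUDFs.scala hunk below.
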
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index d9f987ae02..d7966651b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -30,7 +30,6 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
@@ -1550,8 +1549,8 @@ class DataFrame private[sql](
*/
protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val structType = schema // capture it for closure
- val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
+ val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType))
+ EvaluatePython.javaToPython(rdd)
}
////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 1c8130b07c..6d6e67dace 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -17,15 +17,16 @@
package org.apache.spark.sql.execution
+import java.io.OutputStream
import java.util.{List => JList, Map => JMap}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
-import net.razorvine.pickle.{Pickler, Unpickler}
+import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
+import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.{Accumulator, Logging => SparkLogging}
@@ -130,8 +130,13 @@ object EvaluatePython {
case (null, _) => null
case (row: InternalRow, struct: StructType) =>
- val fields = struct.fields.map(field => field.dataType)
- rowToArray(row, fields)
+ val values = new Array[Any](row.size)
+ var i = 0
+ while (i < row.size) {
+ values(i) = toJava(row(i), struct.fields(i).dataType)
+ i += 1
+ }
+ new GenericInternalRowWithSchema(values, struct)
case (seq: Seq[Any], array: ArrayType) =>
seq.map(x => toJava(x, array.elementType)).asJava
@@ -142,9 +147,6 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
- case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
- case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
-
case (d: Decimal, _) => d.toJavaBigDecimal
case (s: UTF8String, StringType) => s.toString
@@ -153,14 +155,6 @@ object EvaluatePython {
}
/**
- * Convert Row into Java Array (for pickled into Python)
- */
- def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = {
- // TODO: this is slow!
- row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
- }
-
- /**
* Converts `obj` to the type specified by the data type, or returns null if the type of obj is
* unexpected. Because Python doesn't enforce the type.
*/
@@ -220,6 +214,96 @@ object EvaluatePython {
// TODO(davies): we could improve this by try to cast the object to expected type
case (c, _) => null
}
+
+
+ private val module = "pyspark.sql.types"
+
+ /**
+ * Pickler for StructType
+ */
+ private class StructTypePickler extends IObjectPickler {
+
+ private val cls = classOf[StructType]
+
+ def register(): Unit = {
+ Pickler.registerCustomPickler(cls, this)
+ }
+
+ def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ out.write(Opcodes.GLOBAL)
+ out.write((module + "\n" + "_parse_datatype_json_string" + "\n").getBytes("utf-8"))
+ val schema = obj.asInstanceOf[StructType]
+ pickler.save(schema.json)
+ out.write(Opcodes.TUPLE1)
+ out.write(Opcodes.REDUCE)
+ }
+ }
+
+ /**
+ * Pickler for InternalRow
+ */
+ private class RowPickler extends IObjectPickler {
+
+ private val cls = classOf[GenericInternalRowWithSchema]
+
+  // register this pickler for rows and also for its own class, so that
+  // pickler.save(this) below emits the Python-side converter function
+ def register(): Unit = {
+ Pickler.registerCustomPickler(this.getClass, this)
+ Pickler.registerCustomPickler(cls, this)
+ }
+
+ def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ if (obj == this) {
+ out.write(Opcodes.GLOBAL)
+ out.write((module + "\n" + "_create_row_inbound_converter" + "\n").getBytes("utf-8"))
+ } else {
+      // it will be memoized by the Pickler to save some bytes
+ pickler.save(this)
+ val row = obj.asInstanceOf[GenericInternalRowWithSchema]
+ // schema should always be same object for memoization
+ pickler.save(row.schema)
+ out.write(Opcodes.TUPLE1)
+ out.write(Opcodes.REDUCE)
+
+ out.write(Opcodes.MARK)
+ var i = 0
+ while (i < row.values.size) {
+ pickler.save(row.values(i))
+ i += 1
+ }
+ out.write(Opcodes.TUPLE)
+ out.write(Opcodes.REDUCE)
+ }
+ }
+ }
+
+ private[this] var registered = false
+ /**
+ * This should be called before trying to serialize any of the above classes in cluster mode;
+ * it should be put in the closure
+ */
+ def registerPicklers(): Unit = {
+ synchronized {
+ if (!registered) {
+ SerDeUtil.initialize()
+ new StructTypePickler().register()
+ new RowPickler().register()
+ registered = true
+ }
+ }
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects that is usable by
+ * PySpark.
+ */
+ def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = {
+ rdd.mapPartitions { iter =>
+      registerPicklers() // make sure this runs on the executors
+ new SerDeUtil.AutoBatchedPickler(iter)
+ }
+ }
}
/**
@@ -254,12 +338,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val childResults = child.execute().map(_.copy())
val parent = childResults.mapPartitions { iter =>
+ EvaluatePython.registerPicklers() // register pickler for Row
val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
val fields = udf.children.map(_.dataType)
- iter.grouped(1000).map { inputRows =>
+ val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+ iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
- EvaluatePython.rowToArray(currentRow(row), fields)
+ EvaluatePython.toJava(currentRow(row), schema)
}.toArray
pickle.dumps(toBePickled)
}
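
In the BatchPythonEvaluation change above, each partition now builds an ad-hoc schema from the UDF's input types, converts every projected row with EvaluatePython.toJava, and pickles the rows in groups of 100 so each group becomes one byte array sent to the Python worker. A simplified, Spark-free sketch of that batching step (row contents and batch size are illustrative; it assumes pyrolite on the classpath):

import net.razorvine.pickle.Pickler

object BatchPickleSketch {
  def main(args: Array[String]): Unit = {
    val pickle = new Pickler()
    // Stand-in for the projected UDF inputs: each "row" is an Array of field values.
    val rows: Iterator[Array[Any]] = Iterator.tabulate(250)(i => Array[Any](i, s"row-$i"))
    // Group rows and pickle each group as a single payload, mirroring
    // iter.grouped(100).map { ... pickle.dumps(...) } in BatchPythonEvaluation.
    val payloads = rows.grouped(100).map(batch => pickle.dumps(batch.toArray))
    payloads.zipWithIndex.foreach { case (bytes, i) =>
      println(s"batch $i: ${bytes.length} bytes")
    }
  }
}

Batching keeps the number of Pickler calls and Python-side unpickle calls small while bounding the size of each payload.
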