author     Davies Liu <davies@databricks.com>        2016-03-31 16:40:20 -0700
committer  Davies Liu <davies.liu@gmail.com>         2016-03-31 16:40:20 -0700
commit     f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d (patch)
tree       b374ad4a7c98e11f8f85fbd44618422bd4fe6a1b /sql
parent     8de201baedc8e839e06098c536ba31b3dafd54b5 (diff)
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?
This PR adds support for executing multiple Python UDFs within a single batch, and also improves performance.
```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$
== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
+- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- OneRowRelation$
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
: +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
+- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
+- Scan OneRowRelation[]
```
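Note that nested UDF calls such as `double(add(1, 2))` are fused via `ChainedPythonFunctions`, so a chain of UDFs is evaluated in a single pass through one Python worker rather than one round trip per UDF. A minimal Python sketch of the chaining idea (illustrative only, not the actual implementation):
```python
# Illustrative sketch: chained UDFs f(g(...)) are applied in sequence inside
# one Python worker, instead of shipping intermediate results back to the JVM.
def chain(funcs):
    def chained(*args):
        # the first function receives the original arguments...
        result = funcs[0](*args)
        # ...and each later function consumes the previous result
        for f in funcs[1:]:
            result = f(result)
        return result
    return chained

add = lambda x, y: x + y
double = lambda x: x * 2
fused = chain([add, double])  # behaves like double(add(x, y))
assert fused(1, 2) == 6
```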
## How was this patch tested?
Added new tests.
The following script was used to benchmark 1, 2, and 3 UDFs:
```python
from pyspark.sql import functions as F
from pyspark.sql.types import LongType

df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print(df.select(double(df.id)).count())
print(df.select(double(df.id), double(df.id + 1)).count())
print(df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count())
```
Here are the results:
| N | Before | After | Speedup |
|---|--------|-------|---------|
| 1 | 22 s   | 7 s   | 3.1X    |
| 2 | 38 s   | 13 s  | 2.9X    |
| 3 | 58 s   | 16 s  | 3.6X    |
This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch and only 4 processes after it. With this patch, evaluating multiple UDFs also uses less memory than before (less buffering).
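The patch also de-duplicates UDF arguments before they are projected and pickled: each UDF receives offsets into one shared input row, so an expression used by several UDFs is sent to Python only once. A rough Python model of the `argOffsets` bookkeeping in `BatchPythonEvaluation` (hypothetical helper for illustration; it uses plain equality where the Scala code uses `semanticEquals`):
```python
def flatten_inputs(udf_args):
    """For each UDF's argument list, compute offsets into a shared,
    de-duplicated list of input expressions."""
    all_inputs = []   # distinct expressions, in first-seen order
    arg_offsets = []  # per UDF: offsets of its arguments in all_inputs
    for args in udf_args:
        offsets = []
        for expr in args:
            if expr in all_inputs:
                offsets.append(all_inputs.index(expr))
            else:
                all_inputs.append(expr)
                offsets.append(len(all_inputs) - 1)
        arg_offsets.append(offsets)
    return all_inputs, arg_offsets

# two UDFs sharing the column "id": it is projected and pickled only once
inputs, offsets = flatten_inputs([["id"], ["id", "id + 1"]])
assert inputs == ["id", "id + 1"]
assert offsets == [[0], [0, 1]]
```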
Author: Davies Liu <davies@databricks.com>
Closes #12057 from davies/multi_udfs.
Diffstat (limited to 'sql')
4 files changed, 120 insertions, 67 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7841ff01f9..7a2e2b7382 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udf, child, _) =>
-        python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+      case e @ python.EvaluatePython(udfs, child, _) =>
+        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index a76009e7df..c9ab40a0a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 /**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * we drain the queue to find the original input row. Note that if the Python process is way too
  * slow, this could lead to the queue growing unbounded and eventually run out of memory.
  */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {
 
   def children: Seq[SparkPlan] = child :: Nil
 
-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
@@ -69,19 +70,47 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
-      val (pyFuncs, children) = collectFunctions(udf)
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(children, child.output)()
-      val fields = children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+      val pickle = new Pickler(needConversion)
 
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
+        val toBePickled = inputRows.map { inputRow =>
+          queue.add(inputRow)
+          val row = projection(inputRow)
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for these types that does not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
+          }
         }.toArray
         pickle.dumps(toBePickled)
       }
@@ -89,19 +118,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       val context = TaskContext.get()
 
       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
         .compute(inputIterator, context.partitionId(), context)
 
       val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
+      val mutableRow = new GenericMutableRow(1)
       val joined = new JoinedRow
+      val resultType = if (udfs.length == 1) {
+        udfs.head.dataType
+      } else {
+        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+      }
       val resultProj = UnsafeProjection.create(output, output)
 
       outputIterator.flatMap { pickedResult =>
         val unpickledBatch = unpickle.loads(pickedResult)
         unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
       }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        val row = if (udfs.length == 1) {
+          // fast path for single UDF
+          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+          mutableRow
+        } else {
+          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        }
         resultProj(joined(queue.poll(), row))
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index da28ec4f53..f3d1c44b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -36,24 +36,28 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
  */
 case class EvaluatePython(
-    udf: PythonUDF,
+    udfs: Seq[PythonUDF],
     child: LogicalPlan,
-    resultAttribute: AttributeReference)
+    resultAttribute: Seq[AttributeReference])
   extends logical.UnaryNode {
 
-  def output: Seq[Attribute] = child.output :+ resultAttribute
+  def output: Seq[Attribute] = child.output ++ resultAttribute
 
   // References should not include the produced attribute.
-  override def references: AttributeSet = udf.references
+  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
 }
 
 object EvaluatePython {
-  def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
-    new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+  def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
+    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+      AttributeReference(s"pythonUDF$i", u.dataType)()
+    }
+    new EvaluatePython(udfs, child, resultAttrs)
+  }
 
   def takeAndServe(df: DataFrame, n: Int): Int = {
     registerPicklers()
@@ -66,6 +70,16 @@ object EvaluatePython {
     }
   }
 
+  def needConversionInPython(dt: DataType): Boolean = dt match {
+    case DateType | TimestampType => true
+    case _: StructType => true
+    case _: UserDefinedType[_] => true
+    case ArrayType(elementType, _) => needConversionInPython(elementType)
+    case MapType(keyType, valueType, _) =>
+      needConversionInPython(keyType) || needConversionInPython(valueType)
+    case _ => false
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index c486ce18e8..0934cd135d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.catalyst.expressions.Expression
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -47,10 +49,9 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     }
   }
 
-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
-    expr.collect {
-      case udf: PythonUDF if canEvaluateInPython(udf) => udf
-    }
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
+    case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
+    case e => e.children.flatMap(collectEvaluatableUDF)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -59,45 +60,43 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan
       } else {
-        // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
-        // If there is more than one, we will add another evaluation operator in a subsequent pass.
-        udfs.find(_.resolved) match {
-          case Some(udf) =>
-            var evaluation: EvaluatePython = null
-
-            // Rewrite the child that has the input required for the UDF
-            val newChildren = plan.children.map { child =>
-              // Check to make sure that the UDF can be evaluated with only the input of this child.
-              // Other cases are disallowed as they are ambiguous or would require a cartesian
-              // product.
-              if (udf.references.subsetOf(child.outputSet)) {
-                evaluation = EvaluatePython(udf, child)
-                evaluation
-              } else if (udf.references.intersect(child.outputSet).nonEmpty) {
-                sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-              } else {
-                child
-              }
-            }
-
-            assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
-
-            // Trim away the new UDF value if it was only used for filtering or something.
-            logical.Project(
-              plan.output,
-              plan.transformExpressions {
-                case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
-              }.withNewChildren(newChildren))
-
-          case None =>
-            // If there is no Python UDF that is resolved, skip this round.
-            plan
+        val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+        // Rewrite the child that has the input required for the UDF
+        val newChildren = plan.children.map { child =>
+          // Pick the UDF we are going to evaluate
+          val validUdfs = udfs.filter { case udf =>
+            // Check to make sure that the UDF can be evaluated with only the input of this child.
+            udf.references.subsetOf(child.outputSet)
+          }
+          if (validUdfs.nonEmpty) {
+            val evaluation = EvaluatePython(validUdfs, child)
+            attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
+            evaluation
+          } else {
+            child
+          }
         }
+
+        // Other cases are disallowed as they are ambiguous or would require a cartesian
+        // product.
+        udfs.filterNot(attributeMap.contains).foreach { udf =>
+          if (udf.references.subsetOf(plan.inputSet)) {
+            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+          } else {
+            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+          }
+        }
+
+        // Trim away the new UDF value if it was only used for filtering or something.
+        logical.Project(
+          plan.output,
+          plan.transformExpressions {
+            case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
+          }.withNewChildren(newChildren))
       }
   }
 }
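For reference, the new fast path in `BatchPythonEvaluation` is gated by `needConversionInPython`: only dates/timestamps, structs, user-defined types, and containers of those need an extra conversion step inside the Python worker. A rough Python analogue of that check (an illustrative sketch written against `pyspark.sql.types`, not code from this patch):
```python
from pyspark.sql.types import (ArrayType, DateType, MapType, StructType,
                               TimestampType, UserDefinedType)

def need_conversion_in_python(dt):
    """Mirror of EvaluatePython.needConversionInPython from the diff above."""
    if isinstance(dt, (DateType, TimestampType, StructType, UserDefinedType)):
        return True
    if isinstance(dt, ArrayType):
        return need_conversion_in_python(dt.elementType)
    if isinstance(dt, MapType):
        return (need_conversion_in_python(dt.keyType) or
                need_conversion_in_python(dt.valueType))
    return False
```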