| author | Davies Liu <davies@databricks.com> | 2016-03-29 15:06:29 -0700 |
|---|---|---|
| committer | Davies Liu <davies.liu@gmail.com> | 2016-03-29 15:06:29 -0700 |
| commit | a7a93a116dd9813853ba6f112beb7763931d2006 (patch) | |
| tree | 9818f89be4fe960ccfa7585335bbebbff3666810 /core/src/main/scala | |
| parent | e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff) | |
# [SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?
This PR adds support for chained Python UDFs, for example:
```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
Directly chained unary Python UDFs are also put into a single batch of Python UDFs; other expressions may require multiple batches.
For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#10 AS double(double(1))#9]
: +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
+- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
: +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
+- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
+- !BatchPythonEvaluation double(1), [pythonUDF#17]
+- Scan OneRowRelation[]
```
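The `double` UDF in the plans above comes from the PySpark test setup and is not part of this diff. A minimal, assumed way to reproduce the plans in a PySpark shell (where `sqlContext` is predefined), using the existing `SQLContext.registerFunction` API:
```python
from pyspark.sql.types import IntegerType

# Assumed definition of the `double` UDF used in the plans above.
sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())

sqlContext.sql("select double(double(1))").explain()
sqlContext.sql("select double(double(1) + double(2))").explain()
```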
TODO: support multiple unrelated Python UDFs in one batch (to be done in another PR).
## How was this patch tested?
Added new unit tests for chained UDFs.
Author: Davies Liu <davies@databricks.com>
Closes #12014 from davies/py_udfs.
Diffstat (limited to 'core/src/main/scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 38 |
1 file changed, 22 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..0f579b4ef5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(func, bufferSize, reuse_worker)
+    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
  * A helper class to run Python UDFs in Spark.
  */
 private[spark] class PythonRunner(
-    func: PythonFunction,
+    funcs: Seq[PythonFunction],
     bufferSize: Int,
-    reuse_worker: Boolean)
+    reuse_worker: Boolean,
+    rowBased: Boolean)
   extends Logging {
 
-  private val envVars = func.envVars
-  private val pythonExec = func.pythonExec
-  private val accumulator = func.accumulator
+  // All the Python functions should have the same exec, version and envvars.
+  private val envVars = funcs.head.envVars
+  private val pythonExec = funcs.head.pythonExec
+  private val pythonVer = funcs.head.pythonVer
+
+  private val accumulator = funcs.head.accumulator  // TODO: support accumulator in multiple UDF
 
   def compute(
       inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonVer = func.pythonVer
-    private val pythonIncludes = func.pythonIncludes
-    private val broadcastVars = func.broadcastVars
-    private val command = func.command
+    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
+    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
 
     setDaemon(true)
 
@@ -256,13 +258,13 @@ private[spark] class PythonRunner(
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size())
-        for (include <- pythonIncludes.asScala) {
+        dataOut.writeInt(pythonIncludes.size)
+        for (include <- pythonIncludes) {
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
         val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.asScala.map(_.id).toSet
+        val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
         val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@ private[spark] class PythonRunner(
           dataOut.writeLong(- bid - 1)  // bid >= 0
           oldBids.remove(bid)
         }
-        for (broadcast <- broadcastVars.asScala) {
+        for (broadcast <- broadcastVars) {
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(command.length)
-        dataOut.write(command)
+        dataOut.writeInt(if (rowBased) 1 else 0)
+        dataOut.writeInt(funcs.length)
+        funcs.foreach { f =>
+          dataOut.writeInt(f.command.length)
+          dataOut.write(f.command)
+        }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
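With this change, the command section written by `PythonRunner` becomes: a row-based flag, the number of functions, then each serialized command prefixed by its length. A minimal sketch of how a worker could consume that framing is below; the `read_int` helper, the use of `pickle`, and the composition order are illustrative assumptions, while the real consumer is python/pyspark/worker.py using PySpark's own serializers:
```python
import pickle
import struct


def read_int(stream):
    # Big-endian 4-byte int, matching Java's DataOutputStream.writeInt.
    return struct.unpack(">i", stream.read(4))[0]


def read_udfs(stream):
    """Read the command section written by PythonRunner and compose the UDFs."""
    row_based = read_int(stream)      # 1 if the commands are row-based UDFs
    num_commands = read_int(stream)   # number of serialized functions in this batch
    funcs = []
    for _ in range(num_commands):
        length = read_int(stream)     # each command is length-prefixed
        funcs.append(pickle.loads(stream.read(length)))

    # Directly chained unary UDFs arrive in one batch; compose them here
    # (assumed order: innermost function first).
    def chained(*args):
        result = funcs[0](*args)
        for f in funcs[1:]:
            result = f(result)
        return result

    return row_based, chained
```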