aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-03-29 15:06:29 -0700
committerDavies Liu <davies.liu@gmail.com>2016-03-29 15:06:29 -0700
commita7a93a116dd9813853ba6f112beb7763931d2006 (patch)
tree9818f89be4fe960ccfa7585335bbebbff3666810 /core/src/main/scala/org/apache
parente58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff)
downloadspark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.gz
spark-a7a93a116dd9813853ba6f112beb7763931d2006.tar.bz2
spark-a7a93a116dd9813853ba6f112beb7763931d2006.zip
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request? This PR brings the support for chained Python UDFs, for example ```sql select udf1(udf2(a)) select udf1(udf2(a) + 3) select udf1(udf2(a) + udf3(b)) ``` Also directly chained unary Python UDFs are put in single batch of Python UDFs, others may require multiple batches. For example, ```python >>> sqlContext.sql("select double(double(1))").explain() == Physical Plan == WholeStageCodegen : +- Project [pythonUDF#10 AS double(double(1))#9] : +- INPUT +- !BatchPythonEvaluation double(double(1)), [pythonUDF#10] +- Scan OneRowRelation[] >>> sqlContext.sql("select double(double(1) + double(2))").explain() == Physical Plan == WholeStageCodegen : +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16] : +- INPUT +- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19] +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18] +- !BatchPythonEvaluation double(1), [pythonUDF#17] +- Scan OneRowRelation[] ``` TODO: will support multiple unrelated Python UDFs in one batch (another PR). ## How was this patch tested? Added new unit tests for chained UDFs. Author: Davies Liu <davies@databricks.com> Closes #12014 from davies/py_udfs.
Diffstat (limited to 'core/src/main/scala/org/apache')
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala38
1 files changed, 22 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..0f579b4ef5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val runner = new PythonRunner(func, bufferSize, reuse_worker)
+ val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
* A helper class to run Python UDFs in Spark.
*/
private[spark] class PythonRunner(
- func: PythonFunction,
+ funcs: Seq[PythonFunction],
bufferSize: Int,
- reuse_worker: Boolean)
+ reuse_worker: Boolean,
+ rowBased: Boolean)
extends Logging {
- private val envVars = func.envVars
- private val pythonExec = func.pythonExec
- private val accumulator = func.accumulator
+ // All the Python functions should have the same exec, version and envvars.
+ private val envVars = funcs.head.envVars
+ private val pythonExec = funcs.head.pythonExec
+ private val pythonVer = funcs.head.pythonVer
+
+ private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
def compute(
inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(
@volatile private var _exception: Exception = null
- private val pythonVer = func.pythonVer
- private val pythonIncludes = func.pythonIncludes
- private val broadcastVars = func.broadcastVars
- private val command = func.command
+ private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
+ private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
setDaemon(true)
@@ -256,13 +258,13 @@ private[spark] class PythonRunner(
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
- dataOut.writeInt(pythonIncludes.size())
- for (include <- pythonIncludes.asScala) {
+ dataOut.writeInt(pythonIncludes.size)
+ for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
- val newBids = broadcastVars.asScala.map(_.id).toSet
+ val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@ private[spark] class PythonRunner(
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
- for (broadcast <- broadcastVars.asScala) {
+ for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
- dataOut.writeInt(command.length)
- dataOut.write(command)
+ dataOut.writeInt(if (rowBased) 1 else 0)
+ dataOut.writeInt(funcs.length)
+ funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command)
+ }
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)