| field | value | date |
|---|---|---|
| author | Davies Liu <davies@databricks.com> | 2016-03-29 15:06:29 -0700 |
| committer | Davies Liu <davies.liu@gmail.com> | 2016-03-29 15:06:29 -0700 |
| commit | a7a93a116dd9813853ba6f112beb7763931d2006 (patch) | |
| tree | 9818f89be4fe960ccfa7585335bbebbff3666810 /python/pyspark/worker.py | |
| parent | e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a (diff) | |
# [SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?
This PR adds support for chained Python UDFs, for example:
```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
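For reference, a chained call like the ones above can be exercised end to end from PySpark. This is a minimal sketch, not part of the patch; the `double` UDF and the temp table `t` are hypothetical names chosen for illustration:

```python
# Minimal sketch (not from this patch): register a Python UDF, then chain
# it in SQL. `double` and the temp table `t` are hypothetical names.
from pyspark.sql.types import IntegerType

sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
sqlContext.range(3).registerTempTable("t")
sqlContext.sql("select double(double(id)) from t").show()
```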
Directly chained unary Python UDFs are also put into a single batch of Python UDFs; other combinations may require multiple batches.
For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#10 AS double(double(1))#9]
: +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
+- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
: +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
: +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
+- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
+- !BatchPythonEvaluation double(1), [pythonUDF#17]
+- Scan OneRowRelation[]
```
TODO: support multiple unrelated Python UDFs in one batch (to be done in another PR).
## How was this patch tested?
Added new unit tests for chained UDFs.
Author: Davies Liu <davies@databricks.com>
Closes #12014 from davies/py_udfs.
Diffstat (limited to 'python/pyspark/worker.py')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | python/pyspark/worker.py | 33 |

1 file changed, 29 insertions(+), 4 deletions(-)
```diff
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b759..0f05fe31aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
         sys.path.insert(1, path)
 
 
+def read_command(serializer, file):
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def chain(f, g):
+    """chain two function together """
+    return lambda x: g(f(x))
+
+
 def main(infile, outfile):
     try:
         boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
             _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        command = pickleSer._read_with_length(infile)
-        if isinstance(command, Broadcast):
-            command = pickleSer.loads(command.value)
-        func, profiler, deserializer, serializer = command
+        row_based = read_int(infile)
+        num_commands = read_int(infile)
+        if row_based:
+            profiler = None  # profiling is not supported for UDF
+            row_func = None
+            for i in range(num_commands):
+                f, returnType, deserializer = read_command(pickleSer, infile)
+                if row_func is None:
+                    row_func = f
+                else:
+                    row_func = chain(row_func, f)
+            serializer = deserializer
+            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+        else:
+            assert num_commands == 1
+            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+
         init_time = time.time()
 
         def process():
```
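The heart of the change is the `chain` helper plus the folding loop in `main`. The following self-contained sketch mirrors that logic outside the worker; the sample lambdas and the `rows` list are hypothetical stand-ins for the UDFs and row iterator that the real worker deserializes from `infile`:

```python
def chain(f, g):
    """Compose two unary functions so that g consumes f's output."""
    return lambda x: g(f(x))

# Hypothetical stand-ins for the row functions read via read_command().
commands = [lambda x: x * 2, lambda x: x + 1]

# Fold the commands into a single composed function, as the patched main() does.
row_func = None
for f in commands:
    row_func = f if row_func is None else chain(row_func, f)

assert row_func(3) == 7  # (3 * 2) + 1

# The worker then maps the composed function over each input row; here a
# plain list stands in for the row iterator (toInternal conversion omitted).
rows = [(1,), (2,), (3,)]
print(list(map(lambda x: row_func(*x), rows)))  # [3, 5, 7]
```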