author    Davies Liu <davies@databricks.com>    2016-03-29 15:06:29 -0700
committer Davies Liu <davies.liu@gmail.com>     2016-03-29 15:06:29 -0700
commit    a7a93a116dd9813853ba6f112beb7763931d2006
tree      9818f89be4fe960ccfa7585335bbebbff3666810
parent    e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a
[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?

This PR adds support for chained Python UDFs, for example:

```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```

Directly chained unary Python UDFs are also put in a single batch of Python UDFs; other combinations may require multiple batches. For example:

```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#10 AS double(double(1))#9]
:     +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
   +- Scan OneRowRelation[]

>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
:     +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
   +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
      +- !BatchPythonEvaluation double(1), [pythonUDF#17]
         +- Scan OneRowRelation[]
```

TODO: support for multiple unrelated Python UDFs in one batch will come in another PR.

## How was this patch tested?

Added new unit tests for chained UDFs.

Author: Davies Liu <davies@databricks.com>

Closes #12014 from davies/py_udfs.
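For reference, a minimal usage sketch mirroring the new test added in `python/pyspark/sql/tests.py` below (it assumes a running SparkContext/SQLContext from this Spark version):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import IntegerType

sc = SparkContext.getOrCreate()
sqlContext = SQLContext.getOrCreate(sc)

# Register a Python UDF and chain it in SQL. Directly nested calls are
# evaluated in one Python batch; mixed expressions fall back to several.
sqlContext.registerFunction("double", lambda x: x + x, IntegerType())
assert sqlContext.sql("SELECT double(double(1))").collect()[0][0] == 4
assert sqlContext.sql("SELECT double(double(1) + 1)").collect()[0][0] == 6
```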
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala                            | 38
-rw-r--r--  python/pyspark/sql/functions.py                                                             | 16
-rw-r--r--  python/pyspark/sql/tests.py                                                                 |  9
-rw-r--r--  python/pyspark/worker.py                                                                    | 33
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala   | 29
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala       | 26
6 files changed, 116 insertions(+), 35 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..0f579b4ef5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val runner = new PythonRunner(func, bufferSize, reuse_worker)
+ val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
* A helper class to run Python UDFs in Spark.
*/
private[spark] class PythonRunner(
- func: PythonFunction,
+ funcs: Seq[PythonFunction],
bufferSize: Int,
- reuse_worker: Boolean)
+ reuse_worker: Boolean,
+ rowBased: Boolean)
extends Logging {
- private val envVars = func.envVars
- private val pythonExec = func.pythonExec
- private val accumulator = func.accumulator
+ // All the Python functions should have the same exec, version and envvars.
+ private val envVars = funcs.head.envVars
+ private val pythonExec = funcs.head.pythonExec
+ private val pythonVer = funcs.head.pythonVer
+
+ private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
def compute(
inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(
@volatile private var _exception: Exception = null
- private val pythonVer = func.pythonVer
- private val pythonIncludes = func.pythonIncludes
- private val broadcastVars = func.broadcastVars
- private val command = func.command
+ private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
+ private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
setDaemon(true)
@@ -256,13 +258,13 @@ private[spark] class PythonRunner(
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
- dataOut.writeInt(pythonIncludes.size())
- for (include <- pythonIncludes.asScala) {
+ dataOut.writeInt(pythonIncludes.size)
+ for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
- val newBids = broadcastVars.asScala.map(_.id).toSet
+ val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@ private[spark] class PythonRunner(
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
- for (broadcast <- broadcastVars.asScala) {
+ for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
- dataOut.writeInt(command.length)
- dataOut.write(command)
+ dataOut.writeInt(if (rowBased) 1 else 0)
+ dataOut.writeInt(funcs.length)
+ funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command)
+ }
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
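As a rough illustration of the new command framing that `PythonRunner` writes above, here is a Python sketch of the assumed layout (`write_commands` and the byte strings are hypothetical helpers, not Spark API; the big-endian ints mirror Java's `DataOutputStream.writeInt`):

```python
import io
import struct

def write_commands(out, commands, row_based):
    # Mirrors the Scala side: writeInt(rowBased), writeInt(funcs.length),
    # then each pickled command prefixed by its length.
    out.write(struct.pack(">i", 1 if row_based else 0))
    out.write(struct.pack(">i", len(commands)))
    for cmd in commands:
        out.write(struct.pack(">i", len(cmd)))
        out.write(cmd)

buf = io.BytesIO()
write_commands(buf, [b"pickled-udf-1", b"pickled-udf-2"], row_based=True)
```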
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5d959ef98..3211834226 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
from itertools import imap as map
from pyspark import since, SparkContext
-from pyspark.rdd import _wrap_function, ignore_unicode_prefix
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------
+def _wrap_function(sc, func, returnType):
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, returnType, ser)
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+ sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
class UserDefinedFunction(object):
"""
User defined function in Python
@@ -1662,14 +1670,12 @@ class UserDefinedFunction(object):
def _create_judf(self, name):
from pyspark.sql import SQLContext
- f, returnType = self.func, self.returnType # put them in closure `func`
- func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
- ser = AutoBatchedSerializer(PickleSerializer())
sc = SparkContext.getOrCreate()
- wrapped_func = _wrap_function(sc, func, ser, ser)
+ wrapped_func = _wrap_function(sc, self.func, self.returnType)
ctx = SQLContext.getOrCreate(sc)
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
+ f = self.func
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
name, wrapped_func, jdt)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1a5d422af9..84947560e7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
+ def test_chained_python_udf(self):
+ self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+ self.assertEqual(row[0], 2)
+ [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+ self.assertEqual(row[0], 4)
+ [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+ self.assertEqual(row[0], 6)
+
def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b759..0f05fe31aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
sys.path.insert(1, path)
+def read_command(serializer, file):
+ command = serializer._read_with_length(file)
+ if isinstance(command, Broadcast):
+ command = serializer.loads(command.value)
+ return command
+
+
+def chain(f, g):
+ """chain two function together """
+ return lambda x: g(f(x))
+
+
def main(infile, outfile):
try:
boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
- command = pickleSer._read_with_length(infile)
- if isinstance(command, Broadcast):
- command = pickleSer.loads(command.value)
- func, profiler, deserializer, serializer = command
+ row_based = read_int(infile)
+ num_commands = read_int(infile)
+ if row_based:
+ profiler = None # profiling is not supported for UDF
+ row_func = None
+ for i in range(num_commands):
+ f, returnType, deserializer = read_command(pickleSer, infile)
+ if row_func is None:
+ row_func = f
+ else:
+ row_func = chain(row_func, f)
+ serializer = deserializer
+ func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+ else:
+ assert num_commands == 1
+ func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+
init_time = time.time()
def process():
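Outside of Spark, the folding done in `main()` above can be pictured with a small sketch (the two lambdas are hypothetical stand-ins for deserialized UDF commands):

```python
def chain(f, g):
    """Chain two functions together, applying f first and then g."""
    return lambda x: g(f(x))

double = lambda x: x + x      # stand-in for the first deserialized UDF
plus_one = lambda x: x + 1    # stand-in for the second deserialized UDF

row_func = None
for f in [double, plus_one]:  # same folding loop as in worker.py's main()
    row_func = f if row_func is None else chain(row_func, f)

assert row_func(3) == 7       # plus_one(double(3))
```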
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 79e4491026..a76009e7df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.spark.TaskContext
-import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}
@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
def children: Seq[SparkPlan] = child :: Nil
+ private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+ udf.children match {
+ case Seq(u: PythonUDF) =>
+ val (fs, children) = collectFunctions(u)
+ (fs ++ Seq(udf.func), children)
+ case children =>
+ // There should not be any other UDFs, or the children can't be evaluated directly.
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+ (Seq(udf.func), udf.children)
+ }
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
+ val (pyFuncs, children) = collectFunctions(udf)
+
val pickle = new Pickler
- val currentRow = newMutableProjection(udf.children, child.output)()
- val fields = udf.children.map(_.dataType)
+ val currentRow = newMutableProjection(children, child.output)()
+ val fields = children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()
// Output iterator for results from Python.
- val outputIterator = new PythonRunner(
- udf.func,
- bufferSize,
- reuseWorker
- ).compute(inputIterator, context.partitionId(), context)
+ val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+ .compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
val row = new GenericMutableRow(1)
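A rough Python model of `collectFunctions` above (not Spark code; the toy `PythonUDF` class and string "functions" are illustrative only): it flattens a chain of directly nested UDF calls into the functions in evaluation order, innermost first, plus the non-UDF inputs of the innermost call.

```python
class PythonUDF(object):
    """Toy stand-in for the Catalyst PythonUDF expression."""
    def __init__(self, func, children):
        self.func, self.children = func, children

def collect_functions(udf):
    # A single PythonUDF child means the calls are directly chained.
    if len(udf.children) == 1 and isinstance(udf.children[0], PythonUDF):
        funcs, children = collect_functions(udf.children[0])
        return funcs + [udf.func], children
    # Otherwise the children must be evaluable directly in the JVM.
    return [udf.func], udf.children

inner = PythonUDF("double", ["a"])      # double(a)
outer = PythonUDF("str_len", [inner])   # str_len(double(a))
assert collect_functions(outer) == (["double", "str_len"], ["a"])
```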
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 6e76e9569f..c486ce18e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.python
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
* alone in a batch.
*
+ * Only extracts the PythonUDFs that could be evaluated in Python (the single child is PythonUDFs
+ * or all the children could be evaluated in JVM).
+ *
* This has the limitation that the input to the Python UDF is not allowed include attributes from
* multiple child operators.
*/
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+
+ private def hasPythonUDF(e: Expression): Boolean = {
+ e.find(_.isInstanceOf[PythonUDF]).isDefined
+ }
+
+ private def canEvaluateInPython(e: PythonUDF): Boolean = {
+ e.children match {
+ // single PythonUDF child could be chained and evaluated in Python
+ case Seq(u: PythonUDF) => canEvaluateInPython(u)
+ // Python UDF can't be evaluated directly in JVM
+ case children => !children.exists(hasPythonUDF)
+ }
+ }
+
+ private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
+ expr.collect {
+ case udf: PythonUDF if canEvaluateInPython(udf) => udf
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+ val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
if (udfs.isEmpty) {
// If there aren't any, we are done.
plan