6 files changed, 116 insertions, 35 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..0f579b4ef5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(func, bufferSize, reuse_worker)
+    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
  * A helper class to run Python UDFs in Spark.
  */
 private[spark] class PythonRunner(
-    func: PythonFunction,
+    funcs: Seq[PythonFunction],
     bufferSize: Int,
-    reuse_worker: Boolean)
+    reuse_worker: Boolean,
+    rowBased: Boolean)
   extends Logging {
 
-  private val envVars = func.envVars
-  private val pythonExec = func.pythonExec
-  private val accumulator = func.accumulator
+  // All the Python functions should have the same exec, version and envvars.
+  private val envVars = funcs.head.envVars
+  private val pythonExec = funcs.head.pythonExec
+  private val pythonVer = funcs.head.pythonVer
+
+  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
 
   def compute(
       inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonVer = func.pythonVer
-    private val pythonIncludes = func.pythonIncludes
-    private val broadcastVars = func.broadcastVars
-    private val command = func.command
+    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
+    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
 
     setDaemon(true)
 
@@ -256,13 +258,13 @@ private[spark] class PythonRunner(
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size())
-        for (include <- pythonIncludes.asScala) {
+        dataOut.writeInt(pythonIncludes.size)
+        for (include <- pythonIncludes) {
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
         val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.asScala.map(_.id).toSet
+        val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
         val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@ private[spark] class PythonRunner(
           dataOut.writeLong(- bid - 1)  // bid >= 0
           oldBids.remove(bid)
         }
-        for (broadcast <- broadcastVars.asScala) {
+        for (broadcast <- broadcastVars) {
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(command.length)
-        dataOut.write(command)
+        dataOut.writeInt(if (rowBased) 1 else 0)
+        dataOut.writeInt(funcs.length)
+        funcs.foreach { f =>
+          dataOut.writeInt(f.command.length)
+          dataOut.write(f.command)
+        }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
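The command section of the worker protocol changes here: instead of one length-prefixed pickled command, PythonRunner now writes a row-based flag, the number of chained functions, and then each function's pickled command. Below is a minimal Python sketch of that framing, for illustration only (the write_commands helper and its argument names are made up, not part of the patch); it mirrors the writeInt/write calls in the hunk above, using 4-byte big-endian ints as Java's DataOutputStream does.

import struct

def write_commands(out, pickled_commands, row_based):
    # Illustrative sketch of the command section written by the modified PythonRunner.
    out.write(struct.pack(">i", 1 if row_based else 0))  # rowBased flag: 1 for SQL UDFs, 0 for plain RDD funcs
    out.write(struct.pack(">i", len(pickled_commands)))  # number of chained functions
    for command in pickled_commands:
        out.write(struct.pack(">i", len(command)))       # each command is length-prefixed
        out.write(command)                               # pickled (func, returnType, serializer) tuple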
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5d959ef98..3211834226 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _wrap_function, ignore_unicode_prefix
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):
 
 # ---------------------------- User Defined Function ----------------------------------
 
+def _wrap_function(sc, func, returnType):
+    ser = AutoBatchedSerializer(PickleSerializer())
+    command = (func, returnType, ser)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
@@ -1662,14 +1670,12 @@ class UserDefinedFunction(object):
 
     def _create_judf(self, name):
         from pyspark.sql import SQLContext
-        f, returnType = self.func, self.returnType  # put them in closure `func`
-        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
         sc = SparkContext.getOrCreate()
-        wrapped_func = _wrap_function(sc, func, ser, ser)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
+            f = self.func
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             name, wrapped_func, jdt)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1a5d422af9..84947560e7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
+    def test_chained_python_udf(self):
+        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.assertEqual(row[0], 2)
+        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        self.assertEqual(row[0], 4)
+        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        self.assertEqual(row[0], 6)
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
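From the user's point of view nothing changes in the API; the following is a minimal usage sketch (assumes a live SparkContext; the DataFrame and column names are illustrative, not from the patch) of the case this enables, where nested UDF calls are now pickled as a list of functions and composed inside a single Python worker instead of being evaluated in separate batches.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

sc = SparkContext.getOrCreate()
sqlCtx = SQLContext.getOrCreate(sc)

double = udf(lambda x: x + x, IntegerType())
df = sqlCtx.range(3)
# Both UDF calls below are serialized together and chained in one Python worker.
df.select(double(double(df.id)).alias("quadrupled")).show()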
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b759..0f05fe31aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
     sys.path.insert(1, path)
 
 
+def read_command(serializer, file):
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def chain(f, g):
+    """chain two function together """
+    return lambda x: g(f(x))
+
+
 def main(infile, outfile):
     try:
         boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        command = pickleSer._read_with_length(infile)
-        if isinstance(command, Broadcast):
-            command = pickleSer.loads(command.value)
-        func, profiler, deserializer, serializer = command
+        row_based = read_int(infile)
+        num_commands = read_int(infile)
+        if row_based:
+            profiler = None  # profiling is not supported for UDF
+            row_func = None
+            for i in range(num_commands):
+                f, returnType, deserializer = read_command(pickleSer, infile)
+                if row_func is None:
+                    row_func = f
+                else:
+                    row_func = chain(row_func, f)
+            serializer = deserializer
+            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+        else:
+            assert num_commands == 1
+            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+
         init_time = time.time()
 
         def process():
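The chain helper simply composes two functions, so reading N chained commands builds one composed row function. A standalone illustration in plain Python, independent of Spark:

def chain(f, g):
    """chain two functions together"""
    return lambda x: g(f(x))

def double(x):
    return x + x

# Reading two "double" commands yields row_func = chain(double, double),
# i.e. x -> double(double(x)).
quadruple = chain(double, double)
assert quadruple(3) == 12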
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 79e4491026..a76009e7df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.api.python.{PythonFunction, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (fs, children) = collectFunctions(u)
+        (fs ++ Seq(udf.func), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (Seq(udf.func), udf.children)
+    }
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute().map(_.copy())
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
+      val (pyFuncs, children) = collectFunctions(udf)
+
       val pickle = new Pickler
-      val currentRow = newMutableProjection(udf.children, child.output)()
-      val fields = udf.children.map(_.dataType)
+      val currentRow = newMutableProjection(children, child.output)()
+      val fields = children.map(_.dataType)
       val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
 
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       val context = TaskContext.get()
 
       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(
-        udf.func,
-        bufferSize,
-        reuseWorker
-      ).compute(inputIterator, context.partitionId(), context)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+        .compute(inputIterator, context.partitionId(), context)
 
       val unpickle = new Unpickler
       val row = new GenericMutableRow(1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 6e76e9569f..c486ce18e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
  *
+ * Only extracts the PythonUDFs that could be evaluated in Python (the single child is PythonUDFs
+ * or all the children could be evaluated in JVM).
+ *
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
 * multiple child operators.
  */
 private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+
+  private def hasPythonUDF(e: Expression): Boolean = {
+    e.find(_.isInstanceOf[PythonUDF]).isDefined
+  }
+
+  private def canEvaluateInPython(e: PythonUDF): Boolean = {
+    e.children match {
+      // single PythonUDF child could be chained and evaluated in Python
+      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      // Python UDF can't be evaluated directly in JVM
+      case children => !children.exists(hasPythonUDF)
+    }
+  }
+
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
+    expr.collect {
+      case udf: PythonUDF if canEvaluateInPython(udf) => udf
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // Skip EvaluatePython nodes.
     case plan: EvaluatePython => plan
 
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan
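Putting the pieces together, a sketch assuming a live SQLContext and mirroring the new test_chained_python_udf cases: a UDF whose only child is another UDF is chained and evaluated in a single Python worker, whereas a UDF whose child mixes a UDF with other expressions has the inner UDF extracted and evaluated in its own batch first.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import IntegerType

sqlCtx = SQLContext.getOrCreate(SparkContext.getOrCreate())
sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())

# Chained: the outer UDF's single child is another PythonUDF.
assert sqlCtx.sql("SELECT double(double(1))").collect()[0][0] == 4
# Not directly chainable: the outer UDF's child is double(1) + 1, so the inner
# UDF is extracted and evaluated in a separate batch before the outer one runs.
assert sqlCtx.sql("SELECT double(double(1) + 1)").collect()[0][0] == 6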