author    Davies Liu <davies@databricks.com>    2016-03-31 16:40:20 -0700
committer Davies Liu <davies.liu@gmail.com>     2016-03-31 16:40:20 -0700
commit    f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d (patch)
tree      b374ad4a7c98e11f8f85fbd44618422bd4fe6a1b
parent    8de201baedc8e839e06098c536ba31b3dafd54b5 (diff)
download  spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.tar.gz
          spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.tar.bz2
          spark-f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d.zip
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch
## What changes were proposed in this pull request?

This PR supports executing multiple Python UDFs within a single batch, and also improves performance.

```python
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
>>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
>>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)]
+- OneRowRelation$

== Analyzed Logical Plan ==
double(add(1, 2)): int, add(double(2), 1): int
Project [double(add(1, 2))#14,add(double(2), 1)#15]
+- Project [double(add(1, 2))#14,add(double(2), 1)#15]
   +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
      +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
         +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
            +- OneRowRelation$

== Optimized Logical Plan ==
Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
+- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18]
   +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- OneRowRelation$

== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15]
:     +- INPUT
+- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18]
   +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17]
      +- Scan OneRowRelation[]
```

## How was this patch tested?

Added new tests.

Used the following script to benchmark 1, 2 and 3 UDFs:

```
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())
print df.select(double(df.id)).count()
print df.select(double(df.id), double(df.id + 1)).count()
print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count()
```

Here are the results:

N | Before | After | Speedup
---- | ------ | ----- | -------
1 | 22 s | 7 s | 3.1X
2 | 38 s | 13 s | 2.9X
3 | 58 s | 16 s | 3.6X

The benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch and 4 processes after it. After this patch, evaluating multiple UDFs also uses less memory than before (less buffering).

Author: Davies Liu <davies@databricks.com>

Closes #12057 from davies/multi_udfs.
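For reference, a self-contained version of the benchmark script above (a sketch assuming a Spark 1.6/2.0-era PySpark environment; the `SparkContext`/`SQLContext` setup and imports are added here and are not part of the original script):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql.types import LongType

sc = SparkContext(appName="udf-benchmark")
sqlContext = SQLContext(sc)

# 2^23 rows split across 4 partitions, matching the benchmark in the commit message.
df = sqlContext.range(1, 1 << 23, 1, 4)
double = F.udf(lambda x: x * 2, LongType())

# Each count() forces evaluation of 1, 2, and 3 Python UDFs respectively.
print(df.select(double(df.id)).count())
print(df.select(double(df.id), double(df.id + 1)).count())
print(df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count())
```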
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala                              | 64
-rw-r--r--  python/pyspark/sql/functions.py                                                              |  3
-rw-r--r--  python/pyspark/sql/tests.py                                                                  | 12
-rw-r--r--  python/pyspark/worker.py                                                                     | 68
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala    | 78
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala           | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala        | 77
8 files changed, 233 insertions(+), 101 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0f579b4ef5..6faa03c12b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+ val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
accumulator: Accumulator[JList[Array[Byte]]])
/**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+ def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+ new PythonRunner(
+ Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+ }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
*/
private[spark] class PythonRunner(
- funcs: Seq[PythonFunction],
+ funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
- rowBased: Boolean)
+ isUDF: Boolean,
+ argOffsets: Array[Array[Int]])
extends Logging {
+ require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
// All the Python functions should have the same exec, version and envvars.
- private val envVars = funcs.head.envVars
- private val pythonExec = funcs.head.pythonExec
- private val pythonVer = funcs.head.pythonVer
+ private val envVars = funcs.head.funcs.head.envVars
+ private val pythonExec = funcs.head.funcs.head.pythonExec
+ private val pythonVer = funcs.head.funcs.head.pythonVer
- private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+ // TODO: support accumulator in multiple UDF
+ private val accumulator = funcs.head.funcs.head.accumulator
def compute(
inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(
@volatile private var _exception: Exception = null
- private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
- private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+ private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+ private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
setDaemon(true)
@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
- dataOut.writeInt(if (rowBased) 1 else 0)
- dataOut.writeInt(funcs.length)
- funcs.foreach { f =>
- dataOut.writeInt(f.command.length)
- dataOut.write(f.command)
+ if (isUDF) {
+ dataOut.writeInt(1)
+ dataOut.writeInt(funcs.length)
+ funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+ dataOut.writeInt(offsets.length)
+ offsets.foreach { offset =>
+ dataOut.writeInt(offset)
+ }
+ dataOut.writeInt(chained.funcs.length)
+ chained.funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command)
+ }
+ }
+ } else {
+ dataOut.writeInt(0)
+ val command = funcs.head.funcs.head.command
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
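The hunk above changes the command framing that `PythonRunner` writes to the worker: a UDF flag, the number of independent UDFs, and for each one its argument offsets followed by the chain of pickled commands. As a rough illustration of that framing (not part of the patch; `write_int` and `write_udf_header` are hypothetical helpers, and `struct.pack("!i", ...)` mirrors the big-endian 32-bit ints written by `DataOutputStream.writeInt`):

```python
import struct

def write_int(out, value):
    # DataOutputStream.writeInt on the JVM side is a big-endian 32-bit int.
    out.write(struct.pack("!i", value))

def write_udf_header(out, udfs):
    """udfs is a list of (arg_offsets, chained_commands) pairs, mirroring the
    Seq[ChainedPythonFunctions] plus argOffsets on the Scala side."""
    write_int(out, 1)              # isUDF flag
    write_int(out, len(udfs))      # number of independent UDFs
    for arg_offsets, chained_commands in udfs:
        write_int(out, len(arg_offsets))
        for offset in arg_offsets:
            write_int(out, offset)
        write_int(out, len(chained_commands))
        for command in chained_commands:   # each command is pickled bytes
            write_int(out, len(command))
            out.write(command)
```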
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3211834226..3b20ba5177 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------
def _wrap_function(sc, func, returnType):
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, returnType, ser)
+ command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 84947560e7..536ef55251 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
- def test_chained_python_udf(self):
+ def test_chained_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ class SQLTests(ReusedPySparkTestCase):
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)
+ def test_multiple_udfs(self):
+ self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+ self.assertEqual(tuple(row), (2, 4))
+ [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+ self.assertEqual(tuple(row), (4, 12))
+ self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+ self.assertEqual(tuple(row), (6, 5))
+
def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0f05fe31aa..cf47ab8f96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):
def chain(f, g):
"""chain two function together """
- return lambda x: g(f(x))
+ return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+ if return_type.needConversion():
+ toInternal = return_type.toInternal
+ return lambda *a: toInternal(f(*a))
+ else:
+ return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+ num_arg = read_int(infile)
+ arg_offsets = [read_int(infile) for i in range(num_arg)]
+ row_func = None
+ for i in range(read_int(infile)):
+ f, return_type = read_command(pickleSer, infile)
+ if row_func is None:
+ row_func = f
+ else:
+ row_func = chain(row_func, f)
+ # the last returnType will be the return type of UDF
+ return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+ num_udfs = read_int(infile)
+ if num_udfs == 1:
+ # fast path for single UDF
+ _, udf = read_single_udf(pickleSer, infile)
+ mapper = lambda a: udf(*a)
+ else:
+ udfs = {}
+ call_udf = []
+ for i in range(num_udfs):
+ arg_offsets, udf = read_single_udf(pickleSer, infile)
+ udfs['f%d' % i] = udf
+ args = ["a[%d]" % o for o in arg_offsets]
+ call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+ # Create function like this:
+ # lambda a: (f0(a0), f1(a1, a2), f2(a3))
+ mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+ mapper = eval(mapper_str, udfs)
+
+ func = lambda _, it: map(mapper, it)
+ ser = BatchedSerializer(PickleSerializer(), 100)
+ # profiling is not supported for UDF
+ return func, None, ser, ser
def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
- row_based = read_int(infile)
- num_commands = read_int(infile)
- if row_based:
- profiler = None # profiling is not supported for UDF
- row_func = None
- for i in range(num_commands):
- f, returnType, deserializer = read_command(pickleSer, infile)
- if row_func is None:
- row_func = f
- else:
- row_func = chain(row_func, f)
- serializer = deserializer
- func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+ is_sql_udf = read_int(infile)
+ if is_sql_udf:
+ func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
- assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
init_time = time.time()
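The multi-UDF path in `read_udfs` above builds the per-row mapper as a source string and `eval()`s it once, so per-row dispatch is a single tuple expression rather than a Python loop. A standalone sketch of that construction, with made-up UDFs and argument offsets (the real ones are read off the stream):

```python
# Two hypothetical UDFs: f0 takes the column at offset 0,
# f1 takes the columns at offsets 1 and 2.
udfs = {
    "f0": lambda x: x * 2,
    "f1": lambda x, y: x + y,
}
call_udf = ["f0(a[0])", "f1(a[1], a[2])"]

# read_udfs builds exactly this kind of source string and evaluates it once.
mapper_str = "lambda a: (%s)" % ", ".join(call_udf)
mapper = eval(mapper_str, udfs)

print(mapper((3, 4, 5)))   # (6, 9)
```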
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7841ff01f9..7a2e2b7382 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
- case e @ python.EvaluatePython(udf, child, _) =>
- python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
+ case e @ python.EvaluatePython(udfs, child, _) =>
+ python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index a76009e7df..c9ab40a0a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually run out of memory.
*/
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {
def children: Seq[SparkPlan] = child :: Nil
- private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+ private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
- val (fs, children) = collectFunctions(u)
- (fs ++ Seq(udf.func), children)
+ val (chained, children) = collectFunctions(u)
+ (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
- (Seq(udf.func), udf.children)
+ (ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}
@@ -69,19 +70,47 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
- val (pyFuncs, children) = collectFunctions(udf)
-
- val pickle = new Pickler
- val currentRow = newMutableProjection(children, child.output)()
- val fields = children.map(_.dataType)
- val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+ val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+ // flatten all the arguments
+ val allInputs = new ArrayBuffer[Expression]
+ val dataTypes = new ArrayBuffer[DataType]
+ val argOffsets = inputs.map { input =>
+ input.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
+ }.toArray
+ }.toArray
+ val projection = newMutableProjection(allInputs, child.output)()
+ val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+ val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+ // enable memo iff we serialize the row with schema (schema and class should be memorized)
+ val pickle = new Pickler(needConversion)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
- val toBePickled = inputRows.map { row =>
- queue.add(row)
- EvaluatePython.toJava(currentRow(row), schema)
+ val toBePickled = inputRows.map { inputRow =>
+ queue.add(inputRow)
+ val row = projection(inputRow)
+ if (needConversion) {
+ EvaluatePython.toJava(row, schema)
+ } else {
+ // fast path for these types that does not need conversion in Python
+ val fields = new Array[Any](row.numFields)
+ var i = 0
+ while (i < row.numFields) {
+ val dt = dataTypes(i)
+ fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+ i += 1
+ }
+ fields
+ }
}.toArray
pickle.dumps(toBePickled)
}
@@ -89,19 +118,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()
// Output iterator for results from Python.
- val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+ val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
.compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
- val row = new GenericMutableRow(1)
+ val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
+ val resultType = if (udfs.length == 1) {
+ udfs.head.dataType
+ } else {
+ StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+ }
val resultProj = UnsafeProjection.create(output, output)
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
- row(0) = EvaluatePython.fromJava(result, udf.dataType)
+ val row = if (udfs.length == 1) {
+ // fast path for single UDF
+ mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+ mutableRow
+ } else {
+ EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ }
resultProj(joined(queue.poll(), row))
}
}
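The `argOffsets` computation above deduplicates input expressions shared by several UDFs so each distinct expression is projected and pickled only once. A minimal Python sketch of the same bookkeeping (plain strings stand in for Catalyst expressions and `semanticEquals`; `compute_arg_offsets` is a hypothetical helper):

```python
def compute_arg_offsets(inputs_per_udf):
    """inputs_per_udf: for each UDF, the list of its input expressions
    (represented here as strings). Returns the deduplicated input list
    and, for each UDF, the offsets of its arguments in that list."""
    all_inputs = []
    arg_offsets = []
    for inputs in inputs_per_udf:
        offsets = []
        for expr in inputs:
            if expr in all_inputs:            # the Scala code uses semanticEquals
                offsets.append(all_inputs.index(expr))
            else:
                all_inputs.append(expr)
                offsets.append(len(all_inputs) - 1)
        arg_offsets.append(offsets)
    return all_inputs, arg_offsets

# e.g. two UDFs sharing the column "id":
# compute_arg_offsets([["id"], ["id", "id + 1"]])
# -> (["id", "id + 1"], [[0], [0, 1]])
```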
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index da28ec4f53..f3d1c44b25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -36,24 +36,28 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
+ * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
*/
case class EvaluatePython(
- udf: PythonUDF,
+ udfs: Seq[PythonUDF],
child: LogicalPlan,
- resultAttribute: AttributeReference)
+ resultAttribute: Seq[AttributeReference])
extends logical.UnaryNode {
- def output: Seq[Attribute] = child.output :+ resultAttribute
+ def output: Seq[Attribute] = child.output ++ resultAttribute
// References should not include the produced attribute.
- override def references: AttributeSet = udf.references
+ override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
}
object EvaluatePython {
- def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
- new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
+ def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
+ val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+ AttributeReference(s"pythonUDF$i", u.dataType)()
+ }
+ new EvaluatePython(udfs, child, resultAttrs)
+ }
def takeAndServe(df: DataFrame, n: Int): Int = {
registerPicklers()
@@ -66,6 +70,16 @@ object EvaluatePython {
}
}
+ def needConversionInPython(dt: DataType): Boolean = dt match {
+ case DateType | TimestampType => true
+ case _: StructType => true
+ case _: UserDefinedType[_] => true
+ case ArrayType(elementType, _) => needConversionInPython(elementType)
+ case MapType(keyType, valueType, _) =>
+ needConversionInPython(keyType) || needConversionInPython(valueType)
+ case _ => false
+ }
+
/**
* Helper for converting from Catalyst type to java type suitable for Pyrolite.
*/
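For intuition, a rough Python counterpart of the new `needConversionInPython` check (a sketch over `pyspark.sql.types`, not part of the patch): only dates, timestamps, structs, UDTs, and containers of those take the slower conversion path.

```python
from pyspark.sql.types import (ArrayType, DateType, MapType, StructType,
                               TimestampType, UserDefinedType)

def need_conversion_in_python(dt):
    # Mirrors EvaluatePython.needConversionInPython in the hunk above.
    if isinstance(dt, (DateType, TimestampType, StructType, UserDefinedType)):
        return True
    if isinstance(dt, ArrayType):
        return need_conversion_in_python(dt.elementType)
    if isinstance(dt, MapType):
        return (need_conversion_in_python(dt.keyType) or
                need_conversion_in_python(dt.valueType))
    return False
```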
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index c486ce18e8..0934cd135d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.sql.catalyst.expressions.Expression
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -47,10 +49,9 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
}
}
- private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
- expr.collect {
- case udf: PythonUDF if canEvaluateInPython(udf) => udf
- }
+ private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
+ case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
+ case e => e.children.flatMap(collectEvaluatableUDF)
}
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -59,45 +60,43 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+ val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
if (udfs.isEmpty) {
// If there aren't any, we are done.
plan
} else {
- // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
- // If there is more than one, we will add another evaluation operator in a subsequent pass.
- udfs.find(_.resolved) match {
- case Some(udf) =>
- var evaluation: EvaluatePython = null
-
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- // Other cases are disallowed as they are ambiguous or would require a cartesian
- // product.
- if (udf.references.subsetOf(child.outputSet)) {
- evaluation = EvaluatePython(udf, child)
- evaluation
- } else if (udf.references.intersect(child.outputSet).nonEmpty) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- child
- }
- }
-
- assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
-
- // Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- plan.output,
- plan.transformExpressions {
- case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
- }.withNewChildren(newChildren))
-
- case None =>
- // If there is no Python UDF that is resolved, skip this round.
- plan
+ val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Pick the UDF we are going to evaluate
+ val validUdfs = udfs.filter { case udf =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ udf.references.subsetOf(child.outputSet)
+ }
+ if (validUdfs.nonEmpty) {
+ val evaluation = EvaluatePython(validUdfs, child)
+ attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
+ evaluation
+ } else {
+ child
+ }
}
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ udfs.filterNot(attributeMap.contains).foreach { udf =>
+ if (udf.references.subsetOf(plan.inputSet)) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+ }
+ }
+
+ // Trim away the new UDF value if it was only used for filtering or something.
+ logical.Project(
+ plan.output,
+ plan.transformExpressions {
+ case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
+ }.withNewChildren(newChildren))
}
}
}
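The rewritten rule assigns each extracted UDF to the child whose output covers its references, batching all such UDFs into one EvaluatePython per child, and fails for UDFs that cannot be assigned. A toy sketch of that grouping (sets of attribute names stand in for `AttributeSet`; all names here are hypothetical, not Spark APIs):

```python
def group_udfs_by_child(udf_refs, child_outputs):
    """udf_refs: {udf_name: set of referenced attributes}
    child_outputs: {child_name: set of output attributes}
    Returns which UDFs each child can evaluate, plus the UDFs left over."""
    assigned = {}
    placed = set()
    for child, outputs in child_outputs.items():
        valid = [u for u, refs in udf_refs.items() if refs <= outputs]
        if valid:
            assigned[child] = valid
            placed.update(valid)
    unplaced = [u for u in udf_refs if u not in placed]
    return assigned, unplaced

# A UDF whose references span two children ends up in `unplaced`, which
# corresponds to the "requires attributes from more than one child" error.
```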