diff options
author | Davies Liu <davies.liu@gmail.com> | 2014-09-03 19:08:39 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-09-03 19:08:39 -0700 |
commit | c5cbc49233193836b321cb6b77ce69dae798570b (patch) | |
tree | cbab071adad31a6492e99674a9666d4e15c8283f /sql | |
parent | 248067adbe90f93c7d5e23aa61b3072dfdf48a8a (diff) | |
download | spark-c5cbc49233193836b321cb6b77ce69dae798570b.tar.gz spark-c5cbc49233193836b321cb6b77ce69dae798570b.tar.bz2 spark-c5cbc49233193836b321cb6b77ce69dae798570b.zip |
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF
After this patch, broadcast variables can be used in Python UDFs.
Author: Davies Liu <davies.liu@gmail.com>
Closes #2243 from davies/udf_broadcast and squashes the following commits:
7b88861 [Davies Liu] support broadcast in UDF
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala | 3 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 3 |
2 files changed, 5 insertions, 1 deletion
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 0b48e9e659..0ea1105f08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.util.{List => JList, Map => JMap}
 
 import org.apache.spark.Accumulator
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf}
 import org.apache.spark.sql.execution.PythonUDF
@@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
@@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration {
         envVars,
         pythonIncludes,
         pythonExec,
+        broadcastVars,
         accumulator,
         dataType,
         e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3dc8be2456..0977da3e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -42,6 +42,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {
@@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
         udf.pythonIncludes,
         false,
         udf.pythonExec,
-        Seq[Broadcast[Array[Byte]]](),
+        udf.broadcastVars,
         udf.accumulator
       ).mapPartitions { iter =>
       val pickle = new Unpickler
```