author    Davies Liu <davies.liu@gmail.com>          2014-09-03 19:08:39 -0700
committer Michael Armbrust <michael@databricks.com>  2014-09-03 19:08:39 -0700
commit c5cbc49233193836b321cb6b77ce69dae798570b (patch)
tree   cbab071adad31a6492e99674a9666d4e15c8283f /sql/core
parent 248067adbe90f93c7d5e23aa61b3072dfdf48a8a (diff)
[SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF
After this patch, broadcast variables can be used in Python UDFs.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2243 from davies/udf_broadcast and squashes the following commits:

7b88861 [Davies Liu] support broadcast in UDF
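For context, a minimal PySpark sketch of what this change enables: a SQL UDF that closes over a broadcast variable. This is not code from the patch; it assumes the Spark 1.x API of that era (`SQLContext.registerFunction`, `inferSchema`, types exported from `pyspark.sql`), and the app name, UDF name, and table data are hypothetical.

    # Hypothetical usage sketch (not part of the patch): a SQL UDF that
    # reads a broadcast variable, which this change makes work.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, IntegerType

    sc = SparkContext(appName="udf-broadcast-demo")
    sqlContext = SQLContext(sc)

    # Broadcast a small lookup dict to the executors.
    lookup = sc.broadcast({"a": 1, "b": 2})

    # The lambda captures `lookup`; before this patch the broadcast was not
    # forwarded to the Python worker, so lookup.value failed inside the UDF.
    sqlContext.registerFunction(
        "translate", lambda k: lookup.value.get(k, -1), IntegerType())

    rdd = sc.parallelize([{"key": "a"}, {"key": "b"}, {"key": "c"}])
    sqlContext.inferSchema(rdd).registerTempTable("kv")
    print(sqlContext.sql("SELECT translate(key) FROM kv").collect())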
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala       3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala  3
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 0b48e9e659..0ea1105f08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.util.{List => JList, Map => JMap}
 
 import org.apache.spark.Accumulator
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf}
 import org.apache.spark.sql.execution.PythonUDF
@@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
@@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration {
       envVars,
       pythonIncludes,
       pythonExec,
+      broadcastVars,
       accumulator,
       dataType,
       e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3dc8be2456..0977da3e85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -42,6 +42,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {
@@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
-      Seq[Broadcast[Array[Byte]]](),
+      udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>
       val pickle = new Unpickler