author     Wenchen Fan <wenchen@databricks.com>    2016-02-24 12:44:54 -0800
committer  Davies Liu <davies.liu@gmail.com>       2016-02-24 12:44:54 -0800
commit     a60f91284ceee64de13f04559ec19c13a820a133 (patch)
tree       68d7d84620835d5e66cc3f94771a11655c4cbe2b
parent     f92f53faeea020d80638a06752d69ca7a949cdeb (diff)
[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code
## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. It is annoying to pass so many parameters around in so many places. This PR bundles the Python function together with its context, which simplifies some PySpark code and makes the logic clearer.

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11342 from cloud-fan/python-clean.
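At a glance, the patch folds the seven pieces of Python-function context into a single `PythonFunction` value that every JVM entry point now accepts. The sketch below is a paraphrase of the class added in `PythonRDD.scala` (see the diff further down); the field comments are informal glosses, not text from the patch:

```scala
package org.apache.spark.api.python  // paraphrase of the class added in PythonRDD.scala

import java.util.{List => JList, Map => JMap}

import org.apache.spark.Accumulator
import org.apache.spark.broadcast.Broadcast

private[spark] case class PythonFunction(
    command: Array[Byte],                              // the pickled Python command to run
    envVars: JMap[String, String],                     // environment for the Python worker
    pythonIncludes: JList[String],                     // extra files for the Python path
    pythonExec: String,                                // interpreter executable to launch
    pythonVer: String,                                 // interpreter version string
    broadcastVars: JList[Broadcast[PythonBroadcast]],  // broadcast variables used by the command
    accumulator: Accumulator[JList[Array[Byte]]])      // channel for accumulator updates from Python

// The JVM entry points then shrink to:
//   new PythonRDD(parent, func: PythonFunction, preservePartitoning)
//   new PythonRunner(func: PythonFunction, bufferSize, reuse_worker)
//   PythonUDF(name, func: PythonFunction, dataType, children)
//   UserDefinedPythonFunction(name, func: PythonFunction, dataType)
```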
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala                                    37
-rw-r--r--  python/pyspark/rdd.py                                                                               23
-rw-r--r--  python/pyspark/sql/context.py                                                                        2
-rw-r--r--  python/pyspark/sql/functions.py                                                                      8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala                                   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala            8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala                       13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala       15
8 files changed, 51 insertions(+), 63 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f12e2dfafa..05d1c31a08 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
private[spark] class PythonRDD(
parent: RDD[_],
- command: Array[Byte],
- envVars: JMap[String, String],
- pythonIncludes: JList[String],
- preservePartitoning: Boolean,
- pythonExec: String,
- pythonVer: String,
- broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]])
+ func: PythonFunction,
+ preservePartitoning: Boolean)
extends RDD[Array[Byte]](parent) {
val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -64,29 +58,37 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val runner = new PythonRunner(
- command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
- bufferSize, reuse_worker)
+ val runner = new PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
-
/**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for a Python function, contains all necessary context to run the function in Python
+ * runner.
*/
-private[spark] class PythonRunner(
+private[spark] case class PythonFunction(
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+private[spark] class PythonRunner(
+ func: PythonFunction,
bufferSize: Int,
reuse_worker: Boolean)
extends Logging {
+ private val envVars = func.envVars
+ private val pythonExec = func.pythonExec
+ private val accumulator = func.accumulator
+
def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
@@ -225,6 +227,11 @@ private[spark] class PythonRunner(
@volatile private var _exception: Exception = null
+ private val pythonVer = func.pythonVer
+ private val pythonIncludes = func.pythonIncludes
+ private val broadcastVars = func.broadcastVars
+ private val command = func.command
+
setDaemon(true)
/** Contains the exception thrown while writing the parent iterator to the Python process. */
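Before the Python-side changes below, here is a short usage sketch of the new JVM API. It is not part of the patch: the object name is hypothetical, the environment/includes/interpreter values are placeholders for what PySpark actually sends over Py4J, and the accumulator is taken as a parameter because its setup is untouched by this change.

```scala
package org.apache.spark.api.python  // hypothetical example; placed here because the classes are private[spark]

import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}

import org.apache.spark.Accumulator
import org.apache.spark.broadcast.Broadcast

private[spark] object PythonFunctionUsageSketch {

  /** Bundle a pickled command with its context, as the Python side now does through one wrapper. */
  def wrap(
      pickledCommand: Array[Byte],
      broadcastVars: JList[Broadcast[PythonBroadcast]],
      accumulator: Accumulator[JList[Array[Byte]]]): PythonFunction = {
    PythonFunction(
      command = pickledCommand,
      envVars = new JHashMap[String, String](),   // placeholder: PySpark fills this from the SparkContext
      pythonIncludes = new JArrayList[String](),  // placeholder: paths added via addPyFile
      pythonExec = "python",                      // placeholder: interpreter chosen by PySpark
      pythonVer = "2.7",                          // placeholder: version reported by PySpark
      broadcastVars = broadcastVars,
      accumulator = accumulator)
  }

  /** The runner now takes the single wrapper instead of seven separate context arguments. */
  def runnerFor(func: PythonFunction): PythonRunner =
    new PythonRunner(func, bufferSize = 65536, reuse_worker = true)
}
```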
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4eaf589ad5..37574cea0b 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2309,7 +2309,7 @@ class RDD(object):
yield row
-def _prepare_for_python_RDD(sc, command, obj=None):
+def _prepare_for_python_RDD(sc, command):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
@@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
return pickled_command, broadcast_vars, env, includes
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+ assert deserializer, "deserializer should not be empty"
+ assert serializer, "serializer should not be empty"
+ command = (func, profiler, deserializer, serializer)
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+ sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
class PipelinedRDD(RDD):
"""
@@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD):
else:
profiler = None
- command = (self.func, profiler, self._prev_jrdd_deserializer,
- self._jrdd_deserializer)
- pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
- python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- bytearray(pickled_cmd),
- env, includes, self.preservesPartitioning,
- self.ctx.pythonExec, self.ctx.pythonVer,
- bvars, self.ctx._javaAccumulator)
+ wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
+ self._jrdd_deserializer, profiler)
+ python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
+ self.preservesPartitioning)
self._jrdd_val = python_rdd.asJavaRDD()
if profiler:
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 89bf1443a6..87e32c04ea 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -29,7 +29,7 @@ else:
from py4j.protocol import Py4JError
from pyspark import since
-from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6894c27338..b30cc6799e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
from itertools import imap as map
from pyspark import since, SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import _wrap_function, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1645,16 +1645,14 @@ class UserDefinedFunction(object):
f, returnType = self.func, self.returnType # put them in closure `func`
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
sc = SparkContext.getOrCreate()
- pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ wrapped_func = _wrap_function(sc, func, ser, ser)
ctx = SQLContext.getOrCreate(sc)
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
- name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
- broadcast_vars, sc._javaAccumulator, jdt)
+ name, wrapped_func, jdt)
return judf
def __del__(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index ecfc170bee..de01cbcb0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
s"""
| Registering new PythonUDF:
| name: $name
- | command: ${udf.command.toSeq}
- | envVars: ${udf.envVars}
- | pythonIncludes: ${udf.pythonIncludes}
- | pythonExec: ${udf.pythonExec}
+ | command: ${udf.func.command.toSeq}
+ | envVars: ${udf.func.envVars}
+ | pythonIncludes: ${udf.func.pythonIncludes}
+ | pythonExec: ${udf.func.pythonExec}
| dataType: ${udf.dataType}
""".stripMargin)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 00df019527..c65a7bcff8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// Output iterator for results from Python.
val outputIterator = new PythonRunner(
- udf.command,
- udf.envVars,
- udf.pythonIncludes,
- udf.pythonExec,
- udf.pythonVer,
- udf.broadcastVars,
- udf.accumulator,
+ udf.func,
bufferSize,
reuseWorker
).compute(inputIterator, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 9aff0be716..0aa2785cb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -17,9 +17,8 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
+import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
@@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType
*/
case class PythonUDF(
name: String,
- command: Array[Byte],
- envVars: java.util.Map[String, String],
- pythonIncludes: java.util.List[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[java.util.List[Array[Byte]]],
+ func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
extends Expression with Unevaluable with NonSQLExpression with Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 79ac1c85c0..d301874c22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.Accumulator
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.DataType
@@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType
*/
case class UserDefinedPythonFunction(
name: String,
- command: Array[Byte],
- envVars: java.util.Map[String, String],
- pythonIncludes: java.util.List[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[java.util.List[Array[Byte]]],
+ func: PythonFunction,
dataType: DataType) {
def builder(e: Seq[Expression]): PythonUDF = {
- PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
- accumulator, dataType, e)
+ PythonUDF(name, func, dataType, e)
}
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */