author    Wenchen Fan <wenchen@databricks.com>    2016-02-24 12:44:54 -0800
committer Davies Liu <davies.liu@gmail.com>       2016-02-24 12:44:54 -0800
commit    a60f91284ceee64de13f04559ec19c13a820a133 (patch)
tree      68d7d84620835d5e66cc3f94771a11655c4cbe2b /sql
parent    f92f53faeea020d80638a06752d69ca7a949cdeb (diff)
[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code
## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it's annoying to pass around so many parameters in so many places. This PR abstracts the Python function along with its context into a single object, to simplify some PySpark code and make the logic clearer.

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11342 from cloud-fan/python-clean.
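For orientation, here is a minimal sketch of the kind of carrier case class this change relies on, inferred purely from the seven fields the diff below removes from `PythonUDF` and `UserDefinedPythonFunction`; the actual `PythonFunction` lives in `org.apache.spark.api.python` and its real definition may differ in detail.

```scala
import org.apache.spark.Accumulator
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast

// Sketch only: one object bundling a serialized Python function with the
// execution context that previously had to be threaded through every call site.
case class PythonFunction(
    command: Array[Byte],                                       // pickled Python function
    envVars: java.util.Map[String, String],                     // environment for the Python worker
    pythonIncludes: java.util.List[String],                     // files added to the worker's PYTHONPATH
    pythonExec: String,                                         // Python binary to launch
    pythonVer: String,                                          // Python version string
    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],  // broadcasts referenced by the function
    accumulator: Accumulator[java.util.List[Array[Byte]]])      // collects updates sent back by the worker

// Call sites then pass the single bundle instead of seven parameters, e.g.:
//   PythonUDF(name, func, dataType, children)
//   new PythonRunner(func, bufferSize, reuseWorker)
```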
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala                            |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala     |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala                 | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala | 15
4 files changed, 11 insertions(+), 33 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index ecfc170bee..de01cbcb0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
s"""
| Registering new PythonUDF:
| name: $name
- | command: ${udf.command.toSeq}
- | envVars: ${udf.envVars}
- | pythonIncludes: ${udf.pythonIncludes}
- | pythonExec: ${udf.pythonExec}
+ | command: ${udf.func.command.toSeq}
+ | envVars: ${udf.func.envVars}
+ | pythonIncludes: ${udf.func.pythonIncludes}
+ | pythonExec: ${udf.func.pythonExec}
| dataType: ${udf.dataType}
""".stripMargin)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 00df019527..c65a7bcff8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// Output iterator for results from Python.
val outputIterator = new PythonRunner(
- udf.command,
- udf.envVars,
- udf.pythonIncludes,
- udf.pythonExec,
- udf.pythonVer,
- udf.broadcastVars,
- udf.accumulator,
+ udf.func,
bufferSize,
reuseWorker
).compute(inputIterator, context.partitionId(), context)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 9aff0be716..0aa2785cb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -17,9 +17,8 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
+import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
@@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType
*/
case class PythonUDF(
name: String,
- command: Array[Byte],
- envVars: java.util.Map[String, String],
- pythonIncludes: java.util.List[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[java.util.List[Array[Byte]]],
+ func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
extends Expression with Unevaluable with NonSQLExpression with Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 79ac1c85c0..d301874c22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.Accumulator
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.Column
import org.apache.spark.sql.types.DataType
@@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType
*/
case class UserDefinedPythonFunction(
name: String,
- command: Array[Byte],
- envVars: java.util.Map[String, String],
- pythonIncludes: java.util.List[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[java.util.List[Array[Byte]]],
+ func: PythonFunction,
dataType: DataType) {
def builder(e: Seq[Expression]): PythonUDF = {
- PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
- accumulator, dataType, e)
+ PythonUDF(name, func, dataType, e)
}
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */