diff options
Diffstat (limited to 'sql')
4 files changed, 11 insertions, 33 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ecfc170bee..de01cbcb0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { s""" | Registering new PythonUDF: | name: $name - | command: ${udf.command.toSeq} - | envVars: ${udf.envVars} - | pythonIncludes: ${udf.pythonIncludes} - | pythonExec: ${udf.pythonExec} + | command: ${udf.func.command.toSeq} + | envVars: ${udf.func.envVars} + | pythonIncludes: ${udf.func.pythonIncludes} + | pythonExec: ${udf.func.pythonExec} | dataType: ${udf.dataType} """.stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index 00df019527..c65a7bcff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: // Output iterator for results from Python. 
val outputIterator = new PythonRunner( - udf.command, - udf.envVars, - udf.pythonIncludes, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator, + udf.func, bufferSize, reuseWorker ).compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9aff0be716..0aa2785cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.{Accumulator, Logging} -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.Logging +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} import org.apache.spark.sql.types.DataType @@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType */ case class PythonUDF( name: String, - command: Array[Byte], - envVars: java.util.Map[String, String], - pythonIncludes: java.util.List[String], - pythonExec: String, - pythonVer: String, - broadcastVars: java.util.List[Broadcast[PythonBroadcast]], - accumulator: Accumulator[java.util.List[Array[Byte]]], + func: PythonFunction, dataType: DataType, children: Seq[Expression]) extends Expression with Unevaluable with NonSQLExpression with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 79ac1c85c0..d301874c22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.Accumulator -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.Column import org.apache.spark.sql.types.DataType @@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType */ case class UserDefinedPythonFunction( name: String, - command: Array[Byte], - envVars: java.util.Map[String, String], - pythonIncludes: java.util.List[String], - pythonExec: String, - pythonVer: String, - broadcastVars: java.util.List[Broadcast[PythonBroadcast]], - accumulator: Accumulator[java.util.List[Array[Byte]]], + func: PythonFunction, dataType: DataType) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, - accumulator, dataType, e) + PythonUDF(name, func, dataType, e) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */