diff options
author | Davies Liu <davies@databricks.com> | 2015-07-20 12:14:47 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-07-20 12:14:47 -0700 |
commit | 9f913c4fd6f0f223fd378e453d5b9a87beda1ac4 (patch) | |
tree | b0e89b1b7f4c0617c7bab6b314b161fc5240c95b /sql | |
parent | 02181fb6d14833448fb5c501045655213d3cf340 (diff) | |
download | spark-9f913c4fd6f0f223fd378e453d5b9a87beda1ac4.tar.gz spark-9f913c4fd6f0f223fd378e453d5b9a87beda1ac4.tar.bz2 spark-9f913c4fd6f0f223fd378e453d5b9a87beda1ac4.zip |
[SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type
This PR also remove the duplicated code between registerFunction and UserDefinedFunction.
cc JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #7450 from davies/fix_return_type and squashes the following commits:
e80bf9f [Davies Liu] remove debugging code
f94b1f6 [Davies Liu] fix mima
8f9c58b [Davies Liu] convert returned object from UDF into internal type
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala | 44 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala | 10 |
2 files changed, 15 insertions, 39 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index d35d37d017..7cd7421a51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.{Accumulator, Logging} -import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType /** @@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { private val functionRegistry = sqlContext.functionRegistry - protected[sql] def registerPython( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - stringDataType: String): Unit = { + protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( s""" | Registering new PythonUDF: | name: $name - | command: ${command.toSeq} - | envVars: $envVars - | pythonIncludes: $pythonIncludes - | pythonExec: $pythonExec - | dataType: $stringDataType + | command: ${udf.command.toSeq} + | envVars: ${udf.envVars} + | pythonIncludes: ${udf.pythonIncludes} + | pythonExec: ${udf.pythonExec} + | dataType: ${udf.dataType} """.stripMargin) - - val dataType = sqlContext.parseDataType(stringDataType) - - def builder(e: Seq[Expression]): PythonUDF = - PythonUDF( - name, - command, - envVars, - pythonIncludes, - pythonExec, - pythonVer, - broadcastVars, - accumulator, - dataType, - e) - - functionRegistry.registerFunction(name, builder) + functionRegistry.registerFunction(name, udf.builder) } // scalastyle:off diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index b14e00ab9b..0f8cd280b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction( accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType) { + def builder(e: Seq[Expression]): PythonUDF = { + PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, + accumulator, dataType, e) + } + /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ def apply(exprs: Column*): Column = { - val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, - broadcastVars, accumulator, dataType, exprs.map(_.expr)) + val udf = builder(exprs.map(_.expr)) Column(udf) } } |