author    Davies Liu <davies@databricks.com>    2015-07-20 12:14:47 -0700
committer Davies Liu <davies.liu@gmail.com>    2015-07-20 12:14:47 -0700
commit  9f913c4fd6f0f223fd378e453d5b9a87beda1ac4 (patch)
tree    b0e89b1b7f4c0617c7bab6b314b161fc5240c95b
parent  02181fb6d14833448fb5c501045655213d3cf340 (diff)
[SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type
This PR also removes the duplicated code between registerFunction and UserDefinedFunction. cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #7450 from davies/fix_return_type and squashes the following commits:

e80bf9f [Davies Liu] remove debugging code
f94b1f6 [Davies Liu] fix mima
8f9c58b [Davies Liu] convert returned object from UDF into internal type
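A minimal usage sketch of the fixed behavior, mirroring the new test added to python/pyspark/sql/tests.py below. ExamplePoint and ExamplePointUDT are the existing test helpers from that file; df is assumed to be a DataFrame with a 'point' column of type ExamplePointUDT:

    from pyspark.sql.functions import UserDefinedFunction
    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

    # Before this fix, a Python UDF returning a UDT instance failed, because
    # the returned object was never converted into Spark's internal type.
    # With the fix, the value round-trips:
    udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
    df.select(udf2(df.point)).first()[0]  # ExamplePoint(2.0, 3.0)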
-rw-r--r--  project/MimaExcludes.scala                                              |  4
-rw-r--r--  python/pyspark/sql/context.py                                           | 16
-rw-r--r--  python/pyspark/sql/functions.py                                         | 15
-rw-r--r--  python/pyspark/sql/tests.py                                             |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala     | 44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala | 10
6 files changed, 32 insertions(+), 61 deletions(-)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd85254749..a2595ff6c2 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -69,7 +69,9 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"),
// local function inside a method
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
+ "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
) ++ Seq(
// SPARK-8479 Add numNonzeros and numActives to Matrix.
ProblemFilters.exclude[MissingMethodProblem](
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c93a15bada..abb6522dde 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,6 +34,7 @@ from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.utils import install_exception_handler
+from pyspark.sql.functions import UserDefinedFunction
try:
import pandas
@@ -191,19 +192,8 @@ class SQLContext(object):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(_c0=4)]
"""
- func = lambda _, it: map(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
- self._ssql_ctx.udf().registerPython(name,
- bytearray(pickled_cmd),
- env,
- includes,
- self._sc.pythonExec,
- self._sc.pythonVer,
- bvars,
- self._sc._javaAccumulator,
- returnType.json())
+ udf = UserDefinedFunction(f, returnType, name)
+ self._ssql_ctx.udf().registerPython(name, udf._judf)
def _inferSchemaFromList(self, data):
"""
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index fd5a3ba8ad..031745a1c4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -801,23 +801,24 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
- def __init__(self, func, returnType):
+ def __init__(self, func, returnType, name=None):
self.func = func
self.returnType = returnType
self._broadcast = None
- self._judf = self._create_judf()
+ self._judf = self._create_judf(name)
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: map(lambda x: f(*x), it)
+ def _create_judf(self, name):
+ f, returnType = self.func, self.returnType # put them in closure `func`
+ func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
- fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
- judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+ if name is None:
+ name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+ judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
sc.pythonExec, sc.pythonVer, broadcast_vars,
sc._javaAccumulator, jdt)
return judf
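The key change above is wrapping each UDF result in returnType.toInternal, so Python-side values are converted into Spark's internal representation before being shipped back to the JVM. A hedged illustration of what that conversion does for a UDT, again using the test helpers; the exact internal value depends on the UDT's serialize implementation:

    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

    udt = ExamplePointUDT()
    # For a UDT, toInternal conceptually serializes the user object into its
    # underlying SQL type (for ExamplePoint, a pair of doubles).
    internal = udt.toInternal(ExamplePoint(1.0, 2.0))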
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7a55d801e4..ea821f486f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -417,12 +417,14 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_udf_with_udt(self):
- from pyspark.sql.tests import ExamplePoint
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df = self.sc.parallelize([row]).toDF()
self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+ self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index d35d37d017..7cd7421a51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap}
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType
/**
@@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
private val functionRegistry = sqlContext.functionRegistry
- protected[sql] def registerPython(
- name: String,
- command: Array[Byte],
- envVars: JMap[String, String],
- pythonIncludes: JList[String],
- pythonExec: String,
- pythonVer: String,
- broadcastVars: JList[Broadcast[PythonBroadcast]],
- accumulator: Accumulator[JList[Array[Byte]]],
- stringDataType: String): Unit = {
+ protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
log.debug(
s"""
| Registering new PythonUDF:
| name: $name
- | command: ${command.toSeq}
- | envVars: $envVars
- | pythonIncludes: $pythonIncludes
- | pythonExec: $pythonExec
- | dataType: $stringDataType
+ | command: ${udf.command.toSeq}
+ | envVars: ${udf.envVars}
+ | pythonIncludes: ${udf.pythonIncludes}
+ | pythonExec: ${udf.pythonExec}
+ | dataType: ${udf.dataType}
""".stripMargin)
-
- val dataType = sqlContext.parseDataType(stringDataType)
-
- def builder(e: Seq[Expression]): PythonUDF =
- PythonUDF(
- name,
- command,
- envVars,
- pythonIncludes,
- pythonExec,
- pythonVer,
- broadcastVars,
- accumulator,
- dataType,
- e)
-
- functionRegistry.registerFunction(name, builder)
+ functionRegistry.registerFunction(name, udf.builder)
}
// scalastyle:off
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index b14e00ab9b..0f8cd280b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Accumulator
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType
@@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction(
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType) {
+ def builder(e: Seq[Expression]): PythonUDF = {
+ PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
+ accumulator, dataType, e)
+ }
+
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
def apply(exprs: Column*): Column = {
- val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
- broadcastVars, accumulator, dataType, exprs.map(_.expr))
+ val udf = builder(exprs.map(_.expr))
Column(udf)
}
}