author    Michael Armbrust <michael@databricks.com>  2014-08-02 16:33:48 -0700
committer Michael Armbrust <michael@databricks.com>  2014-08-02 16:33:48 -0700
commit 158ad0bba9382fd494b4789b5628a9cec00cfa19 (patch)
tree   761acc1c8694a167043fc8f45bfa49447d6c1f2d /python/pyspark/sql.py
parent 4c477117bb1ffef463776c86f925d35036f96b7a (diff)
[SPARK-2097][SQL] UDF Support
This patch adds the ability to register lambda functions written in Python, Java, or Scala as UDFs for use in SQL or HiveQL.

Scala:
```scala
registerFunction("strLenScala", (_: String).length)
sql("SELECT strLenScala('test')")
```

Python:
```python
sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType())
sqlCtx.sql("SELECT strLenPython('test')")
```

Java:
```java
sqlContext.registerFunction("stringLengthJava", new UDF1<String, Integer>() {
  @Override
  public Integer call(String str) throws Exception {
    return str.length();
  }
}, DataType.IntegerType);
sqlContext.sql("SELECT stringLengthJava('test')");
```

Author: Michael Armbrust <michael@databricks.com>

Closes #1063 from marmbrus/udfs and squashes the following commits:

9eda0fe [Michael Armbrust] newline
747c05e [Michael Armbrust] Add some scala UDF tests.
d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
005d684 [Michael Armbrust] Fix naming and formatting.
d14dac8 [Michael Armbrust] Fix last line of autogened java files.
8135c48 [Michael Armbrust] Move UDF unit tests to pyspark.
40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable.
7a83101 [Michael Armbrust] Drop toString
795fd15 [Michael Armbrust] Try to avoid capturing SQLContext.
e54fb45 [Michael Armbrust] Docs and tests.
437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments.
01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
8e6c932 [Michael Armbrust] WIP
3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
6237c8d [Michael Armbrust] WIP
2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs.
0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python.
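As a quick illustration of the default return type behavior: when no return type is passed, results are converted to strings. A minimal sketch, assuming `sqlCtx` is an initialized `SQLContext` and mirroring the doctests added in the diff below:

```python
# Sketch, assuming sqlCtx is an initialized SQLContext.
# With no explicit returnType, StringType() is the default and the
# result is converted to a string:
sqlCtx.registerFunction("strLen", lambda x: len(x))
sqlCtx.sql("SELECT strLen('test')").collect()       # [Row(c0=u'4')]

# With an explicit IntegerType(), the function must produce an int:
sqlCtx.registerFunction("strLenInt", lambda x: len(x), IntegerType())
sqlCtx.sql("SELECT strLenInt('test')").collect()    # [Row(c0=4)]
```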
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--  python/pyspark/sql.py  39
1 file changed, 38 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f840475ffa..e7c35ac1ff 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -28,9 +28,13 @@ from array import array
from operator import itemgetter
from pyspark.rdd import RDD, PipelinedRDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
+
+from itertools import chain, ifilter, imap
from py4j.protocol import Py4JError
+from py4j.java_collections import ListConverter, MapConverter
+
__all__ = [
"StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
@@ -932,6 +936,39 @@ class SQLContext:
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext
+ def registerFunction(self, name, f, returnType=StringType()):
+ """Registers a lambda function as a UDF so it can be used in SQL statements.
+
+ In addition to a name and the function itself, the return type can be optionally specified.
+ When the return type is not given it defaults to a string and conversion will automatically
+ be done. For any other return type, the produced object must match the specified type.
+
+ >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
+ >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+ [Row(c0=u'4')]
+ >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+ [Row(c0=4)]
+ >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+ >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+ [Row(c0=5)]
+ """
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ command = (func,
+ BatchedSerializer(PickleSerializer(), 1024),
+ BatchedSerializer(PickleSerializer(), 1024))
+ env = MapConverter().convert(self._sc.environment,
+ self._sc._gateway._gateway_client)
+ includes = ListConverter().convert(self._sc._python_includes,
+ self._sc._gateway._gateway_client)
+ self._ssql_ctx.registerPython(name,
+ bytearray(CloudPickleSerializer().dumps(command)),
+ env,
+ includes,
+ self._sc.pythonExec,
+ self._sc._javaAccumulator,
+ str(returnType))
+
def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}s.