diff options
author | Michael Armbrust <michael@databricks.com> | 2014-08-02 16:33:48 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-08-02 16:33:48 -0700 |
commit | 158ad0bba9382fd494b4789b5628a9cec00cfa19 (patch) | |
tree | 761acc1c8694a167043fc8f45bfa49447d6c1f2d /python | |
parent | 4c477117bb1ffef463776c86f925d35036f96b7a (diff) | |
download | spark-158ad0bba9382fd494b4789b5628a9cec00cfa19.tar.gz spark-158ad0bba9382fd494b4789b5628a9cec00cfa19.tar.bz2 spark-158ad0bba9382fd494b4789b5628a9cec00cfa19.zip |
[SPARK-2097][SQL] UDF Support
This patch adds the ability to register lambda functions written in Python, Java or Scala as UDFs for use in SQL or HiveQL.
Scala:
```scala
registerFunction("strLenScala", (_: String).length)
sql("SELECT strLenScala('test')")
```
Python:
```python
sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType())
sqlCtx.sql("SELECT strLenPython('test')")
```
Java:
```java
sqlContext.registerFunction("stringLengthJava", new UDF1<String, Integer>() {
Override
public Integer call(String str) throws Exception {
return str.length();
}
}, DataType.IntegerType);
sqlContext.sql("SELECT stringLengthJava('test')");
```
Author: Michael Armbrust <michael@databricks.com>
Closes #1063 from marmbrus/udfs and squashes the following commits:
9eda0fe [Michael Armbrust] newline
747c05e [Michael Armbrust] Add some scala UDF tests.
d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
005d684 [Michael Armbrust] Fix naming and formatting.
d14dac8 [Michael Armbrust] Fix last line of autogened java files.
8135c48 [Michael Armbrust] Move UDF unit tests to pyspark.
40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable.
7a83101 [Michael Armbrust] Drop toString
795fd15 [Michael Armbrust] Try to avoid capturing SQLContext.
e54fb45 [Michael Armbrust] Docs and tests.
437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments.
01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
8e6c932 [Michael Armbrust] WIP
3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
6237c8d [Michael Armbrust] WIP
2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs.
0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python.
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql.py | 39 |
1 files changed, 38 insertions, 1 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f840475ffa..e7c35ac1ff 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -28,9 +28,13 @@ from array import array from operator import itemgetter from pyspark.rdd import RDD, PipelinedRDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer + +from itertools import chain, ifilter, imap from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter + __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -932,6 +936,39 @@ class SQLContext: self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + [Row(c0=5)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + command = (func, + BatchedSerializer(PickleSerializer(), 1024), + BatchedSerializer(PickleSerializer(), 1024)) + env = MapConverter().convert(self._sc.environment, + self._sc._gateway._gateway_client) + includes = ListConverter().convert(self._sc._python_includes, + self._sc._gateway._gateway_client) + self._ssql_ctx.registerPython(name, + bytearray(CloudPickleSerializer().dumps(command)), + env, + includes, + self._sc.pythonExec, + self._sc._javaAccumulator, + str(returnType)) + def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}s. |