From dc101b0e4e23dffddbc2f70d14a19fae5d87a328 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 4 Feb 2015 15:55:09 -0800
Subject: [SPARK-5577] Python udf for DataFrame

Author: Davies Liu

Closes #4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
---
 python/pyspark/sql.py | 195 +++++++++++++++++++++++---------------------------
 1 file changed, 91 insertions(+), 104 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index a266cde51d..5b56b36bdc 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -51,7 +51,7 @@ from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ class SQLContext(object):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
         self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_command),
+                                            bytearray(pickled_cmd),
                                             env,
                                             includes,
                                             self._sc.pythonExec,
-                                            broadcast_vars,
+                                            bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
@@ -2077,9 +2064,9 @@ class DataFrame(object):
         """Return all column names and their data types as a list.
 
         >>> df.dtypes
-        [(u'age', 'IntegerType'), (u'name', 'StringType')]
+        [('age', 'integer'), ('name', 'string')]
         """
-        return [(f.name, str(f.dataType)) for f in self.schema().fields]
+        return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
 
     @property
     def columns(self):
@@ -2194,7 +2181,7 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2295,25 +2282,13 @@ class DataFrame(object):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
-    def sample(self, withReplacement, fraction, seed=None):
-        """ Return a new DataFrame by sampling a fraction of rows.
-
-        >>> df.sample(False, 0.5, 10).collect()
-        [Row(age=2, name=u'Alice')]
-        """
-        if seed is None:
-            jdf = self._jdf.sample(withReplacement, fraction)
-        else:
-            jdf = self._jdf.sample(withReplacement, fraction, seed)
-        return DataFrame(jdf, self.sql_ctx)
-
     def addColumn(self, colName, col):
         """ Return a new :class:`DataFrame` by adding a column.
 
         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ class GroupedDataFrame(object):
         group."""
 
 
-SCALA_METHOD_MAPPINGS = {
-    '=': '$eq',
-    '>': '$greater',
-    '<': '$less',
-    '+': '$plus',
-    '-': '$minus',
-    '*': '$times',
-    '/': '$div',
-    '!': '$bang',
-    '@': '$at',
-    '#': '$hash',
-    '%': '$percent',
-    '^': '$up',
-    '&': '$amp',
-    '~': '$tilde',
-    '?': '$qmark',
-    '|': '$bar',
-    '\\': '$bslash',
-    ':': '$colon',
-}
-
-
 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
     return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
     return jcol
 
 
-def _scalaMethod(name):
-    """ Translate operators into methodName in Scala
-
-    >>> _scalaMethod('+')
-    '$plus'
-    >>> _scalaMethod('>=')
-    '$greater$eq'
-    >>> _scalaMethod('cast')
-    'cast'
-    """
-    return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
-
-
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
-        jc = getattr(self._jc, _scalaMethod(name))()
+        jc = getattr(self._jc, name)()
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _dsl_op(name, doc=''):
+    def _(self):
+        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
     """
     def _(self, other):
         jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, _scalaMethod(name))(jc)
+        njc = getattr(self._jc, name)(jc)
         return Column(njc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
     """
     def _(self, other):
         jother = _create_column_from_literal(other)
-        jc = getattr(jother, _scalaMethod(name))(self._jc)
+        jc = getattr(jother, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2513,34 +2461,33 @@ class Column(DataFrame):
         super(Column, self).__init__(jc, sql_ctx)
 
     # arithmetic operators
-    __neg__ = _unary_op("unary_-")
-    __add__ = _bin_op("+")
-    __sub__ = _bin_op("-")
-    __mul__ = _bin_op("*")
-    __div__ = _bin_op("/")
-    __mod__ = _bin_op("%")
-    __radd__ = _bin_op("+")
-    __rsub__ = _reverse_op("-")
-    __rmul__ = _bin_op("*")
-    __rdiv__ = _reverse_op("/")
-    __rmod__ = _reverse_op("%")
-    __abs__ = _unary_op("abs")
+    __neg__ = _dsl_op("negate")
+    __add__ = _bin_op("plus")
+    __sub__ = _bin_op("minus")
+    __mul__ = _bin_op("multiply")
+    __div__ = _bin_op("divide")
+    __mod__ = _bin_op("mod")
+    __radd__ = _bin_op("plus")
+    __rsub__ = _reverse_op("minus")
+    __rmul__ = _bin_op("multiply")
+    __rdiv__ = _reverse_op("divide")
+    __rmod__ = _reverse_op("mod")
 
     # logical operators
-    __eq__ = _bin_op("===")
-    __ne__ = _bin_op("!==")
-    __lt__ = _bin_op("<")
-    __le__ = _bin_op("<=")
-    __ge__ = _bin_op(">=")
-    __gt__ = _bin_op(">")
+    __eq__ = _bin_op("equalTo")
+    __ne__ = _bin_op("notEqual")
+    __lt__ = _bin_op("lt")
+    __le__ = _bin_op("leq")
+    __ge__ = _bin_op("geq")
+    __gt__ = _bin_op("gt")
 
     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators
-    __and__ = _bin_op('&&')
-    __or__ = _bin_op('||')
-    __invert__ = _unary_op('unary_!')
-    __rand__ = _bin_op("&&")
-    __ror__ = _bin_op("||")
+    __and__ = _bin_op('and')
+    __or__ = _bin_op('or')
+    __invert__ = _dsl_op('not')
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")
 
     # container operators
     __contains__ = _bin_op("contains")
@@ -2582,24 +2529,20 @@ class Column(DataFrame):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    # `as` is keyword
     def alias(self, alias):
         """Return an alias for this column
 
-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
 
-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _aggregate_func(name, doc=""):
     return staticmethod(_)
 
 
+class UserDefinedFunction(object):
+    def __init__(self, func, returnType):
+        self.func = func
+        self.returnType = returnType
+        self._broadcast = None
+        self._judf = self._create_judf()
+
+    def _create_judf(self):
+        f = self.func  # put it in closure `func`
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        jdt = ssql_ctx.parseDataType(self.returnType.json())
+        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+                                                 includes, sc.pythonExec, broadcast_vars,
+                                                 sc._javaAccumulator, jdt)
+        return judf
+
+    def __del__(self):
+        if self._broadcast is not None:
+            self._broadcast.unpersist()
+            self._broadcast = None
+
+    def __call__(self, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        return Column(jc)
+
+
 class Dsl(object):
     """
     A collection of builtin aggregators
@@ -2659,7 +2636,7 @@ class Dsl(object):
         """ Return a new Column for distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ class Dsl(object):
         """ Return a new Column for approximate distinct count of (col, *cols)
 
        >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
        [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ class Dsl(object):
         jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
         return Column(jc)
 
+    @staticmethod
+    def udf(f, returnType=StringType()):
+        """Create a user defined function (UDF)
+
+        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+        >>> df.select(slen(df.name).alias('slen')).collect()
+        [Row(slen=5), Row(slen=3)]
+        """
+        return UserDefinedFunction(f, returnType)
+
 
 def _test():
     import doctest
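
For orientation, a minimal usage sketch of the two entry points this patch wires up: `SQLContext.registerFunction` for UDFs called from SQL text, and `Dsl.udf` for UDFs applied to DataFrame columns. This is not part of the patch; it assumes a live SparkContext `sc` (as the module's doctests do) and the single-file pyspark.sql layout of this era, where `Row`, `IntegerType`, `Dsl`, and `SQLContext` are all importable from `pyspark.sql`. The `people` table and the column names are illustrative.

    from pyspark.sql import SQLContext, Dsl, Row, IntegerType

    sqlCtx = SQLContext(sc)  # `sc` is an already-running SparkContext
    df = sqlCtx.inferSchema(
        sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]))

    # SQL path: registerFunction() pickles the closure with
    # _prepare_for_python_RDD and hands it to the JVM via registerPython(),
    # after which it is callable from SQL text.
    sqlCtx.registerFunction('strlen', lambda s: len(s), IntegerType())
    df.registerTempTable('people')
    sqlCtx.sql('SELECT strlen(name) FROM people').collect()

    # DataFrame path: Dsl.udf() wraps the closure in a UserDefinedFunction,
    # which builds a JVM-side UserDefinedPythonFunction; calling it on
    # Columns yields a new Column expression.
    slen = Dsl.udf(lambda s: len(s), IntegerType())
    df.select(slen(df.name).alias('slen')).collect()
    # [Row(slen=5), Row(slen=3)]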
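On the "track life cycle of broadcast" commit in the squash list: pickled commands larger than roughly 1 MB (the threshold visible in the removed registerFunction body) are shipped to executors through a broadcast variable by `_prepare_for_python_RDD`, which records the handle on the object passed to it. `UserDefinedFunction` initializes `_broadcast` to None for exactly this purpose and unpersists the broadcast in `__del__`, so dropping the last Python reference to a UDF also releases its serialized payload on the cluster.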