author     Davies Liu <davies@databricks.com>    2015-02-04 15:55:09 -0800
committer  Reynold Xin <rxin@databricks.com>     2015-02-04 15:55:09 -0800
commit     dc101b0e4e23dffddbc2f70d14a19fae5d87a328 (patch)
tree       e436271c351a64caa4727661cd6143ba6e415fa6 /python
parent     e0490e271d078aa55d7c7583e2ba80337ed1b0c4 (diff)
[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <davies@databricks.com>

Closes #4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/rdd.py   38
-rw-r--r--  python/pyspark/sql.py  195
2 files changed, 113 insertions, 120 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2f8a0edfe9..6e029bf7f1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2162,6 +2162,25 @@ class RDD(object):
yield row
+def _prepare_for_python_RDD(sc, command, obj=None):
+ # the serialized command will be compressed by broadcast
+ ser = CloudPickleSerializer()
+ pickled_command = ser.dumps(command)
+ if len(pickled_command) > (1 << 20): # 1M
+ broadcast = sc.broadcast(pickled_command)
+ pickled_command = ser.dumps(broadcast)
+ # tracking the life cycle by obj
+ if obj is not None:
+ obj._broadcast = broadcast
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in sc._pickled_broadcast_vars],
+ sc._gateway._gateway_client)
+ sc._pickled_broadcast_vars.clear()
+ env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+ includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+ return pickled_command, broadcast_vars, env, includes
+
+
class PipelinedRDD(RDD):
"""
@@ -2228,25 +2247,12 @@ class PipelinedRDD(RDD):
command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
- # the serialized command will be compressed by broadcast
- ser = CloudPickleSerializer()
- pickled_command = ser.dumps(command)
- if len(pickled_command) > (1 << 20): # 1M
- self._broadcast = self.ctx.broadcast(pickled_command)
- pickled_command = ser.dumps(self._broadcast)
- broadcast_vars = ListConverter().convert(
- [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
- self.ctx._gateway._gateway_client)
- self.ctx._pickled_broadcast_vars.clear()
- env = MapConverter().convert(self.ctx.environment,
- self.ctx._gateway._gateway_client)
- includes = ListConverter().convert(self.ctx._python_includes,
- self.ctx._gateway._gateway_client)
+ pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- bytearray(pickled_command),
+ bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator)
+ bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
if profiler:
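The new _prepare_for_python_RDD helper factors the command-serialization logic out of PipelinedRDD so the SQL code below can reuse it: it pickles the command with CloudPickleSerializer, and if the payload exceeds 1 MB it ships it through a broadcast variable and pickles the (small) broadcast handle instead, attaching the broadcast to `obj` so its lifetime can be tracked. A minimal sketch of that size-threshold idea, using the standard pickle module and a hypothetical broadcast callable in place of the real SparkContext machinery:

    import pickle

    ONE_MB = 1 << 20  # same threshold used by _prepare_for_python_RDD

    def prepare_command(command, broadcast):
        """Pickle `command`; large payloads are replaced by a pickled broadcast handle."""
        payload = pickle.dumps(command)
        if len(payload) > ONE_MB:
            handle = broadcast(payload)     # hypothetical broadcast() helper
            payload = pickle.dumps(handle)  # workers unpickle the handle, then fetch the value
        return payload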
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index a266cde51d..5b56b36bdc 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -51,7 +51,7 @@ from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
CloudPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ class SQLContext(object):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
- command = (func, None,
- AutoBatchedSerializer(PickleSerializer()),
- AutoBatchedSerializer(PickleSerializer()))
- ser = CloudPickleSerializer()
- pickled_command = ser.dumps(command)
- if len(pickled_command) > (1 << 20): # 1M
- broadcast = self._sc.broadcast(pickled_command)
- pickled_command = ser.dumps(broadcast)
- broadcast_vars = ListConverter().convert(
- [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
- self._sc._gateway._gateway_client)
- self._sc._pickled_broadcast_vars.clear()
- env = MapConverter().convert(self._sc.environment,
- self._sc._gateway._gateway_client)
- includes = ListConverter().convert(self._sc._python_includes,
- self._sc._gateway._gateway_client)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
self._ssql_ctx.udf().registerPython(name,
- bytearray(pickled_command),
+ bytearray(pickled_cmd),
env,
includes,
self._sc.pythonExec,
- broadcast_vars,
+ bvars,
self._sc._javaAccumulator,
returnType.json())
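With the duplicated serialization code removed, registerFunction now just builds the (func, None, ser, ser) command tuple and delegates to _prepare_for_python_RDD. A usage sketch matching the doctest above, assuming an existing SparkContext named `sc`:

    from pyspark.sql import SQLContext, IntegerType

    sqlCtx = SQLContext(sc)
    sqlCtx.registerFunction("strLen", lambda s: len(s), IntegerType())
    sqlCtx.sql("SELECT strLen('test')").collect()   # [Row(c0=4)]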
@@ -2077,9 +2064,9 @@ class DataFrame(object):
"""Return all column names and their data types as a list.
>>> df.dtypes
- [(u'age', 'IntegerType'), (u'name', 'StringType')]
+ [('age', 'integer'), ('name', 'string')]
"""
- return [(f.name, str(f.dataType)) for f in self.schema().fields]
+ return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
@property
def columns(self):
@@ -2194,7 +2181,7 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
- >>> df.select(df.name, (df.age + 10).As('age')).collect()
+ >>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
if not cols:
@@ -2295,25 +2282,13 @@ class DataFrame(object):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
- def sample(self, withReplacement, fraction, seed=None):
- """ Return a new DataFrame by sampling a fraction of rows.
-
- >>> df.sample(False, 0.5, 10).collect()
- [Row(age=2, name=u'Alice')]
- """
- if seed is None:
- jdf = self._jdf.sample(withReplacement, fraction)
- else:
- jdf = self._jdf.sample(withReplacement, fraction, seed)
- return DataFrame(jdf, self.sql_ctx)
-
def addColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
>>> df.addColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
- return self.select('*', col.As(colName))
+ return self.select('*', col.alias(colName))
# Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ class GroupedDataFrame(object):
group."""
-SCALA_METHOD_MAPPINGS = {
- '=': '$eq',
- '>': '$greater',
- '<': '$less',
- '+': '$plus',
- '-': '$minus',
- '*': '$times',
- '/': '$div',
- '!': '$bang',
- '@': '$at',
- '#': '$hash',
- '%': '$percent',
- '^': '$up',
- '&': '$amp',
- '~': '$tilde',
- '?': '$qmark',
- '|': '$bar',
- '\\': '$bslash',
- ':': '$colon',
-}
-
-
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
return jcol
-def _scalaMethod(name):
- """ Translate operators into methodName in Scala
-
- >>> _scalaMethod('+')
- '$plus'
- >>> _scalaMethod('>=')
- '$greater$eq'
- >>> _scalaMethod('cast')
- 'cast'
- """
- return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
-
-
def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
- jc = getattr(self._jc, _scalaMethod(name))()
+ jc = getattr(self._jc, name)()
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _dsl_op(name, doc=''):
+ def _(self):
+ jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
"""
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
- njc = getattr(self._jc, _scalaMethod(name))(jc)
+ njc = getattr(self._jc, name)(jc)
return Column(njc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
"""
def _(self, other):
jother = _create_column_from_literal(other)
- jc = getattr(jother, _scalaMethod(name))(self._jc)
+ jc = getattr(jother, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -2513,34 +2461,33 @@ class Column(DataFrame):
super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
- __neg__ = _unary_op("unary_-")
- __add__ = _bin_op("+")
- __sub__ = _bin_op("-")
- __mul__ = _bin_op("*")
- __div__ = _bin_op("/")
- __mod__ = _bin_op("%")
- __radd__ = _bin_op("+")
- __rsub__ = _reverse_op("-")
- __rmul__ = _bin_op("*")
- __rdiv__ = _reverse_op("/")
- __rmod__ = _reverse_op("%")
- __abs__ = _unary_op("abs")
+ __neg__ = _dsl_op("negate")
+ __add__ = _bin_op("plus")
+ __sub__ = _bin_op("minus")
+ __mul__ = _bin_op("multiply")
+ __div__ = _bin_op("divide")
+ __mod__ = _bin_op("mod")
+ __radd__ = _bin_op("plus")
+ __rsub__ = _reverse_op("minus")
+ __rmul__ = _bin_op("multiply")
+ __rdiv__ = _reverse_op("divide")
+ __rmod__ = _reverse_op("mod")
# logistic operators
- __eq__ = _bin_op("===")
- __ne__ = _bin_op("!==")
- __lt__ = _bin_op("<")
- __le__ = _bin_op("<=")
- __ge__ = _bin_op(">=")
- __gt__ = _bin_op(">")
+ __eq__ = _bin_op("equalTo")
+ __ne__ = _bin_op("notEqual")
+ __lt__ = _bin_op("lt")
+ __le__ = _bin_op("leq")
+ __ge__ = _bin_op("geq")
+ __gt__ = _bin_op("gt")
# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
- __and__ = _bin_op('&&')
- __or__ = _bin_op('||')
- __invert__ = _unary_op('unary_!')
- __rand__ = _bin_op("&&")
- __ror__ = _bin_op("||")
+ __and__ = _bin_op('and')
+ __or__ = _bin_op('or')
+ __invert__ = _dsl_op('not')
+ __rand__ = _bin_op("and")
+ __ror__ = _bin_op("or")
# container operators
__contains__ = _bin_op("contains")
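With the $plus-style name mangling gone, the Python operators now dispatch to the named Column methods exposed over py4j (plus, minus, equalTo, ...), while operators with no instance method on Column (negate, not) go through the Dsl object via _dsl_op. A Spark-free sketch of that dispatch pattern, with a toy JvmColumn class standing in for the py4j proxy (all names here are hypothetical):

    class JvmColumn(object):
        """Stand-in for the py4j Column proxy; methods just record the expression."""
        def __init__(self, expr):
            self.expr = expr
        def plus(self, other):
            return JvmColumn("(%s + %s)" % (self.expr, other))
        def equalTo(self, other):
            return JvmColumn("(%s = %s)" % (self.expr, other))

    def _bin_op(name):
        def _(self, other):
            jc = other._jc if isinstance(other, Col) else other
            return Col(getattr(self._jc, name)(jc))  # named method, no $plus mangling
        return _

    class Col(object):
        def __init__(self, jc):
            self._jc = jc
        __add__ = _bin_op("plus")
        __eq__ = _bin_op("equalTo")

    print((Col(JvmColumn("age")) + 10)._jc.expr)   # (age + 10)
    print((Col(JvmColumn("age")) == 2)._jc.expr)   # (age = 2)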
@@ -2582,24 +2529,20 @@ class Column(DataFrame):
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
- # `as` is keyword
def alias(self, alias):
"""Return a alias for this column
- >>> df.age.As("age2").collect()
- [Row(age2=2), Row(age2=5)]
>>> df.age.alias("age2").collect()
[Row(age2=2), Row(age2=5)]
"""
return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
- As = alias
def cast(self, dataType):
""" Convert the column into type `dataType`
- >>> df.select(df.age.cast("string").As('ages')).collect()
+ >>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
- >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+ >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _aggregate_func(name, doc=""):
return staticmethod(_)
+class UserDefinedFunction(object):
+ def __init__(self, func, returnType):
+ self.func = func
+ self.returnType = returnType
+ self._broadcast = None
+ self._judf = self._create_judf()
+
+ def _create_judf(self):
+ f = self.func # put it in closure `func`
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ sc = SparkContext._active_spark_context
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(self.returnType.json())
+ judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+ includes, sc.pythonExec, broadcast_vars,
+ sc._javaAccumulator, jdt)
+ return judf
+
+ def __del__(self):
+ if self._broadcast is not None:
+ self._broadcast.unpersist()
+ self._broadcast = None
+
+ def __call__(self, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
class Dsl(object):
"""
A collections of builtin aggregators
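The new UserDefinedFunction class wraps the row-wise Python function into a partition-wise function, serializes it through _prepare_for_python_RDD (passing `self` so that a >1 MB command's broadcast is tied to the UDF and unpersisted in __del__, per the "track life cycle of broadcast" commit), and builds a JVM UserDefinedPythonFunction. A minimal, Spark-free sketch of the wrapping step (imap is the Python 2 lazy map used in the diff; a generator expression stands in for it here):

    def wrap(f):
        # mirrors `func = lambda _, it: imap(lambda x: f(*x), it)` from _create_judf:
        # first argument is the partition index, second an iterator of argument tuples
        return lambda split, it: (f(*row) for row in it)

    slen = wrap(lambda s: len(s))
    print(list(slen(0, [("Alice",), ("Bob",)])))   # [5, 3]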
@@ -2659,7 +2636,7 @@ class Dsl(object):
""" Return a new Column for distinct count of (col, *cols)
>>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+ >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ class Dsl(object):
""" Return a new Column for approxiate distinct count of (col, *cols)
>>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+ >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ class Dsl(object):
jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
return Column(jc)
+ @staticmethod
+ def udf(f, returnType=StringType()):
+ """Create a user defined function (UDF)
+
+ >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+ >>> df.select(slen(df.name).alias('slen')).collect()
+ [Row(slen=5), Row(slen=3)]
+ """
+ return UserDefinedFunction(f, returnType)
+
def _test():
import doctest