path: root/python/pyspark
author     Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
commit     e98dfe627c5d0201464cdd0f363f391ea84c389a (patch)
tree       794beea739eb04bf2e0926f9b0e19ffacb94ba08 /python/pyspark
parent     0ce4e430a81532dc317136f968f28742e087d840 (diff)
[SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
- The old implicit would convert RDDs directly to DataFrames, and that added too many methods.
- toDataFrame -> toDF
- Dsl -> functions
- implicits moved into SQLContext.implicits
- addColumn -> withColumn
- renameColumn -> withColumnRenamed

Python changes:
- toDataFrame -> toDF
- Dsl -> functions package
- addColumn -> withColumn
- renameColumn -> withColumnRenamed
- add toDF functions to RDD on SQLContext init
- add flatMap to DataFrame

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4556 from rxin/SPARK-5752 and squashes the following commits:

5ef9910 [Reynold Xin] More fix
61d3fca [Reynold Xin] Merge branch 'df5' of github.com:davies/spark into SPARK-5752
ff5832c [Reynold Xin] Fix python
749c675 [Reynold Xin] count(*) fixes.
5806df0 [Reynold Xin] Fix build break again.
d941f3d [Reynold Xin] Fixed explode compilation break.
fe1267a [Davies Liu] flatMap
c4afb8e [Reynold Xin] style
d9de47f [Davies Liu] add comment
b783994 [Davies Liu] add comment for toDF
e2154e5 [Davies Liu] schema() -> schema
3a1004f [Davies Liu] Dsl -> functions, toDF()
fb256af [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
0dd74eb [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
97dd47c [Davies Liu] fix mistake
6168f74 [Davies Liu] fix test
1fc0199 [Davies Liu] fix test
a075cd5 [Davies Liu] clean up, toPandas
663d314 [Davies Liu] add test for agg('*')
9e214d5 [Reynold Xin] count(*) fixes.
1ed7136 [Reynold Xin] Fix build break again.
921b2e3 [Reynold Xin] Fixed explode compilation break.
14698d4 [Davies Liu] flatMap
ba3e12d [Reynold Xin] style
d08c92d [Davies Liu] add comment
5c8b524 [Davies Liu] add comment for toDF
a4e5e66 [Davies Liu] schema() -> schema
d377fc9 [Davies Liu] Dsl -> functions, toDF()
6b3086c [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
807e8b1 [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
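
For orientation, a minimal before/after sketch of the renamed Python API described above. It assumes a live SparkContext named `sc`; the sample data and column names are illustrative, not taken from this patch.

    from pyspark.sql import SQLContext, Row
    from pyspark.sql import functions as F             # was: from pyspark.sql import Dsl

    sqlCtx = SQLContext(sc)                             # constructing it also installs RDD.toDF

    df = sc.parallelize([Row(name=u'Alice', age=1),
                         Row(name=u'Bob', age=5)]).toDF()   # was: sqlCtx.createDataFrame(rdd)

    df.withColumn('age2', df.age + 2).collect()         # was: df.addColumn(...)
    df.withColumnRenamed('age', 'years').collect()      # was: df.renameColumn(...)
    df.agg(F.min(df.age)).collect()                     # was: df.agg(Dsl.min(df.age))
    df.schema                                           # schema is now a property, not a method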
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/mllib/tests.py      2
-rw-r--r--  python/pyspark/sql/__init__.py     3
-rw-r--r--  python/pyspark/sql/context.py     34
-rw-r--r--  python/pyspark/sql/dataframe.py  221
-rw-r--r--  python/pyspark/sql/functions.py  170
-rw-r--r--  python/pyspark/sql/tests.py       38
6 files changed, 270 insertions, 198 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 49e5c9d58e..06207a076e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -335,7 +335,7 @@ class VectorUDTTests(PySparkTestCase):
sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
srdd = sqlCtx.inferSchema(rdd)
- schema = srdd.schema()
+ schema = srdd.schema
field = [f for f in schema.fields if f.name == "features"][0]
self.assertEqual(field.dataType, self.udt)
vectors = srdd.map(lambda p: p.features).collect()
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 0a5ba00393..b9ffd6945e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -34,9 +34,8 @@ public classes of Spark SQL:
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.types import Row
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
- 'Dsl',
]
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 082f1b691b..7683c1b4df 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -38,6 +38,25 @@ except ImportError:
__all__ = ["SQLContext", "HiveContext"]
+def _monkey_patch_RDD(sqlCtx):
+ def toDF(self, schema=None, sampleRatio=None):
+ """
+ Convert current :class:`RDD` into a :class:`DataFrame`
+
+ This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)`
+
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd.toDF().collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ return sqlCtx.createDataFrame(self, schema, sampleRatio)
+
+ RDD.toDF = toDF
+
+
class SQLContext(object):
"""Main entry point for Spark SQL functionality.
@@ -49,15 +68,20 @@ class SQLContext(object):
def __init__(self, sparkContext, sqlContext=None):
"""Create a new SQLContext.
+ It will add a method called `toDF` to :class:`RDD`, which could be
+ used to convert an RDD into a DataFrame, it's a shorthand for
+ :func:`SQLContext.createDataFrame`.
+
:param sparkContext: The SparkContext to wrap.
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new
SQLContext in the JVM, instead we make all calls to this object.
>>> from datetime import datetime
+ >>> sqlCtx = SQLContext(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.createDataFrame(allTypes)
+ >>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
@@ -70,6 +94,7 @@ class SQLContext(object):
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
self._scala_SQLContext = sqlContext
+ _monkey_patch_RDD(self)
@property
def _ssql_ctx(self):
@@ -442,7 +467,7 @@ class SQLContext(object):
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema)
>>> sqlCtx.registerRDDAsTable(df3, "table2")
>>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
@@ -495,7 +520,7 @@ class SQLContext(object):
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema)
>>> sqlCtx.registerRDDAsTable(df3, "table2")
>>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
@@ -800,7 +825,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
- globs['df'] = sqlCtx.createDataFrame(rdd)
+ _monkey_patch_RDD(sqlCtx)
+ globs['df'] = rdd.toDF()
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b6f052ee44..1438fe5285 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -21,21 +21,19 @@ import warnings
import random
import os
from tempfile import NamedTemporaryFile
-from itertools import imap
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _prepare_for_python_RDD
-from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
- UTF8Deserializer
+from pyspark.rdd import RDD
+from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
-__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]
class DataFrame(object):
@@ -76,6 +74,7 @@ class DataFrame(object):
self.sql_ctx = sql_ctx
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
+ self._schema = None # initialized lazily
@property
def rdd(self):
@@ -86,7 +85,7 @@ class DataFrame(object):
if not hasattr(self, '_lazy_rdd'):
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
- schema = self.schema()
+ schema = self.schema
def applySchema(it):
cls = _create_cls(schema)
@@ -216,14 +215,17 @@ class DataFrame(object):
self._sc._gateway._gateway_client)
self._jdf.save(source, jmode, joptions)
+ @property
def schema(self):
"""Returns the schema of this DataFrame (represented by
a L{StructType}).
- >>> df.schema()
+ >>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
"""
- return _parse_datatype_json_string(self._jdf.schema().json())
+ if self._schema is None:
+ self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ return self._schema
def printSchema(self):
"""Prints out the schema in the tree format.
@@ -284,7 +286,7 @@ class DataFrame(object):
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
- cls = _create_cls(self.schema())
+ cls = _create_cls(self.schema)
return [cls(r) for r in rs]
def limit(self, num):
@@ -310,14 +312,26 @@ class DataFrame(object):
return self.limit(num).collect()
def map(self, f):
- """ Return a new RDD by applying a function to each Row, it's a
- shorthand for df.rdd.map()
+ """ Return a new RDD by applying a function to each Row
+
+ It's a shorthand for df.rdd.map()
>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
"""
return self.rdd.map(f)
+ def flatMap(self, f):
+ """ Return a new RDD by first applying a function to all elements of this,
+ and then flattening the results.
+
+ It's a shorthand for df.rdd.flatMap()
+
+ >>> df.flatMap(lambda p: p.name).collect()
+ [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
+ """
+ return self.rdd.flatMap(f)
+
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition.
@@ -378,21 +392,6 @@ class DataFrame(object):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
- # def takeSample(self, withReplacement, num, seed=None):
- # """Return a fixed-size sampled subset of this DataFrame.
- #
- # >>> df = sqlCtx.inferSchema(rdd)
- # >>> df.takeSample(False, 2, 97)
- # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
- # """
- # seed = seed if seed is not None else random.randint(0, sys.maxint)
- # with SCCallSiteSync(self.context) as css:
- # bytesInJava = self._jdf \
- # .takeSampleToPython(withReplacement, num, long(seed)) \
- # .iterator()
- # cls = _create_cls(self.schema())
- # return map(cls, self._collect_iterator_through_file(bytesInJava))
-
@property
def dtypes(self):
"""Return all column names and their data types as a list.
@@ -400,7 +399,7 @@ class DataFrame(object):
>>> df.dtypes
[('age', 'int'), ('name', 'string')]
"""
- return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields]
+ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
def columns(self):
@@ -409,7 +408,7 @@ class DataFrame(object):
>>> df.columns
[u'age', u'name']
"""
- return [f.name for f in self.schema().fields]
+ return [f.name for f in self.schema.fields]
def join(self, other, joinExprs=None, joinType=None):
"""
@@ -586,8 +585,8 @@ class DataFrame(object):
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age#0)=5)]
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.min(df.age)).collect()
+ >>> from pyspark.sql import functions as F
+ >>> df.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=2)]
"""
return self.groupBy().agg(*exprs)
@@ -616,18 +615,18 @@ class DataFrame(object):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
- def addColumn(self, colName, col):
+ def withColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
- >>> df.addColumn('age2', df.age + 2).collect()
+ >>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.alias(colName))
- def renameColumn(self, existing, new):
+ def withColumnRenamed(self, existing, new):
""" Rename an existing column to a new name
- >>> df.renameColumn('age', 'age2').collect()
+ >>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
"""
cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
@@ -635,11 +634,11 @@ class DataFrame(object):
for c in self.columns]
return self.select(*cols)
- def to_pandas(self):
+ def toPandas(self):
"""
Collect all the rows and return a `pandas.DataFrame`.
- >>> df.to_pandas() # doctest: +SKIP
+ >>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
@@ -687,10 +686,11 @@ class GroupedData(object):
name to aggregate methods.
>>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"age": "max"}).collect()
- [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
- >>> from pyspark.sql import Dsl
- >>> gdf.agg(Dsl.min(df.age)).collect()
+ >>> gdf.agg({"*": "count"}).collect()
+ [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+
+ >>> from pyspark.sql import functions as F
+ >>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
"""
assert exprs, "exprs should not be empty"
@@ -742,12 +742,12 @@ class GroupedData(object):
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.lit(literal)
+ return sc._jvm.functions.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.col(name)
+ return sc._jvm.functions.col(name)
def _to_java_column(col):
@@ -767,9 +767,9 @@ def _unary_op(name, doc="unary operator"):
return _
-def _dsl_op(name, doc=''):
+def _func_op(name, doc=''):
def _(self):
- jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+ jc = getattr(self._sc._jvm.functions, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -818,7 +818,7 @@ class Column(DataFrame):
super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
- __neg__ = _dsl_op("negate")
+ __neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
@@ -842,7 +842,7 @@ class Column(DataFrame):
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
- __invert__ = _dsl_op('not')
+ __invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
@@ -920,11 +920,11 @@ class Column(DataFrame):
else:
return 'Column<%s>' % self._jdf.toString()
- def to_pandas(self):
+ def toPandas(self):
"""
Return a pandas.Series from the column
- >>> df.age.to_pandas() # doctest: +SKIP
+ >>> df.age.toPandas() # doctest: +SKIP
0 2
1 5
dtype: int64
@@ -934,123 +934,6 @@ class Column(DataFrame):
return pd.Series(data)
-def _aggregate_func(name, doc=""):
- """ Create a function for aggregator by name"""
- def _(col):
- sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
- return Column(jc)
- _.__name__ = name
- _.__doc__ = doc
- return staticmethod(_)
-
-
-class UserDefinedFunction(object):
- def __init__(self, func, returnType):
- self.func = func
- self.returnType = returnType
- self._broadcast = None
- self._judf = self._create_judf()
-
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: imap(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
- pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
- judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
- includes, sc.pythonExec, broadcast_vars,
- sc._javaAccumulator, jdt)
- return judf
-
- def __del__(self):
- if self._broadcast is not None:
- self._broadcast.unpersist()
- self._broadcast = None
-
- def __call__(self, *cols):
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
-
-class Dsl(object):
- """
- A collections of builtin aggregators
- """
- DSLS = {
- 'lit': 'Creates a :class:`Column` of literal value.',
- 'col': 'Returns a :class:`Column` based on the given column name.',
- 'column': 'Returns a :class:`Column` based on the given column name.',
- 'upper': 'Converts a string expression to upper case.',
- 'lower': 'Converts a string expression to upper case.',
- 'sqrt': 'Computes the square root of the specified float value.',
- 'abs': 'Computes the absolutle value.',
-
- 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
- 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
- 'first': 'Aggregate function: returns the first value in a group.',
- 'last': 'Aggregate function: returns the last value in a group.',
- 'count': 'Aggregate function: returns the number of items in a group.',
- 'sum': 'Aggregate function: returns the sum of all values in the expression.',
- 'avg': 'Aggregate function: returns the average of the values in a group.',
- 'mean': 'Aggregate function: returns the average of the values in a group.',
- 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
- }
-
- for _name, _doc in DSLS.items():
- locals()[_name] = _aggregate_func(_name, _doc)
- del _name, _doc
-
- @staticmethod
- def countDistinct(col, *cols):
- """ Return a new Column for distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
- [Row(c=2)]
-
- >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
- @staticmethod
- def approxCountDistinct(col, rsd=None):
- """ Return a new Column for approxiate distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- if rsd is None:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
- else:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
- return Column(jc)
-
- @staticmethod
- def udf(f, returnType=StringType()):
- """Create a user defined function (UDF)
-
- >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
- >>> df.select(slen(df.name).alias('slen')).collect()
- [Row(slen=5), Row(slen=3)]
- """
- return UserDefinedFunction(f, returnType)
-
-
def _test():
import doctest
from pyspark.context import SparkContext
@@ -1059,11 +942,9 @@ def _test():
globs = pyspark.sql.dataframe.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
- rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
- globs['df'] = sqlCtx.inferSchema(rdd2)
- globs['df2'] = sqlCtx.inferSchema(rdd3)
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
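
Beyond the renames already listed in the commit message, the notable additions in dataframe.py are flatMap, the dict-based count(*) aggregate, and the camelCase toPandas. A short sketch, assuming the same doctest-style `df` with `age` and `name` columns (toPandas requires pandas to be installed):

    df.flatMap(lambda row: row.name).collect()            # new; shorthand for df.rdd.flatMap(...)
    df.groupBy(df.name).agg({"*": "count"}).collect()     # count(*) now works in dict-style agg
    df.toPandas()                                         # was df.to_pandas(); returns a pandas.DataFrame
    df.age.toPandas()                                     # was Column.to_pandas(); returns a pandas.Series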
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
new file mode 100644
index 0000000000..39aa550eeb
--- /dev/null
+++ b/python/pyspark/sql/functions.py
@@ -0,0 +1,170 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin functions
+"""
+
+from itertools import imap
+
+from py4j.java_collections import ListConverter
+
+from pyspark import SparkContext
+from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql.types import StringType
+from pyspark.sql.dataframe import Column, _to_java_column
+
+
+__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
+
+
+def _create_function(name, doc=""):
+ """ Create a function for aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return _
+
+
+_functions = {
+ 'lit': 'Creates a :class:`Column` of literal value.',
+ 'col': 'Returns a :class:`Column` based on the given column name.',
+ 'column': 'Returns a :class:`Column` based on the given column name.',
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to lower case.',
+ 'sqrt': 'Computes the square root of the specified float value.',
+ 'abs': 'Computes the absolute value.',
+
+ 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+ 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+ 'first': 'Aggregate function: returns the first value in a group.',
+ 'last': 'Aggregate function: returns the last value in a group.',
+ 'count': 'Aggregate function: returns the number of items in a group.',
+ 'sum': 'Aggregate function: returns the sum of all values in the expression.',
+ 'avg': 'Aggregate function: returns the average of the values in a group.',
+ 'mean': 'Aggregate function: returns the average of the values in a group.',
+ 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+}
+
+
+for _name, _doc in _functions.items():
+ globals()[_name] = _create_function(_name, _doc)
+del _name, _doc
+__all__ += _functions.keys()
+
+
+def countDistinct(col, *cols):
+ """ Return a new Column for distinct count of `col` or `cols`
+
+ >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
+ [Row(c=2)]
+
+ >>> df.agg(countDistinct("age", "name").alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
+ jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+def approxCountDistinct(col, rsd=None):
+ """ Return a new Column for approximate distinct count of `col`
+
+ >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
+
+
+class UserDefinedFunction(object):
+ """
+ User defined function in Python
+ """
+ def __init__(self, func, returnType):
+ self.func = func
+ self.returnType = returnType
+ self._broadcast = None
+ self._judf = self._create_judf()
+
+ def _create_judf(self):
+ f = self.func # put it in closure `func`
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ sc = SparkContext._active_spark_context
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(self.returnType.json())
+ judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+ includes, sc.pythonExec, broadcast_vars,
+ sc._javaAccumulator, jdt)
+ return judf
+
+ def __del__(self):
+ if self._broadcast is not None:
+ self._broadcast.unpersist()
+ self._broadcast = None
+
+ def __call__(self, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+def udf(f, returnType=StringType()):
+ """Create a user defined function (UDF)
+
+ >>> slen = udf(lambda s: len(s), IntegerType())
+ >>> df.select(slen(df.name).alias('slen')).collect()
+ [Row(slen=5), Row(slen=3)]
+ """
+ return UserDefinedFunction(f, returnType)
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext
+ import pyspark.sql.dataframe
+ globs = pyspark.sql.dataframe.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.dataframe, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
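
The former Dsl static methods are now plain module-level functions. A usage sketch against the same doctest-style `df` (with `age` and `name` columns):

    from pyspark.sql import functions as F
    from pyspark.sql.types import IntegerType

    df.agg(F.min(df.age), F.max(df.age)).collect()              # generated aggregate helpers
    df.agg(F.countDistinct(df.age, df.name).alias('c')).collect()
    df.agg(F.approxCountDistinct(df.age).alias('c')).collect()

    slen = F.udf(lambda s: len(s), IntegerType())               # udf() replaces Dsl.udf
    df.select(slen(df.name).alias('slen')).collect()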
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 43e5c3a1b0..aa80bca346 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -96,7 +96,7 @@ class SQLTests(ReusedPySparkTestCase):
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.createDataFrame(rdd)
+ cls.df = rdd.toDF()
@classmethod
def tearDownClass(cls):
@@ -138,7 +138,7 @@ class SQLTests(ReusedPySparkTestCase):
df = self.sqlCtx.jsonRDD(rdd)
df.count()
df.collect()
- df.schema()
+ df.schema
# cache and checkpoint
self.assertFalse(df.is_cached)
@@ -155,11 +155,11 @@ class SQLTests(ReusedPySparkTestCase):
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema())
+ df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.createDataFrame(rdd, df.schema())
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
@@ -195,7 +195,7 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(1, result.head()[0])
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
- self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual(df.schema, df2.schema)
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
df2.registerTempTable("test2")
@@ -204,8 +204,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.sc.parallelize(d).toDF()
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -213,8 +212,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.createDataFrame(rdd)
+ df = self.sc.parallelize([row]).toDF()
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
@@ -223,9 +221,8 @@ class SQLTests(ReusedPySparkTestCase):
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.createDataFrame(rdd)
- schema = df.schema()
+ df = self.sc.parallelize([row]).toDF()
+ schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
df.registerTempTable("labeled_point")
@@ -238,15 +235,14 @@ class SQLTests(ReusedPySparkTestCase):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.createDataFrame(rdd, schema)
+ df = rdd.toDF(schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.createDataFrame(rdd)
+ df0 = self.sc.parallelize([row]).toDF()
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
@@ -280,10 +276,11 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
- from pyspark.sql import Dsl
- self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
- self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
- self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ from pyspark.sql import functions
+ self.assertEqual((0, u'99'),
+ tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
def test_save_and_load(self):
df = self.df
@@ -339,8 +336,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
- rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.inferSchema(rdd)
+ cls.df = cls.sc.parallelize(cls.testData).toDF()
@classmethod
def tearDownClass(cls):
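
The test fixtures above now build their DataFrames through the monkey-patched RDD.toDF rather than sqlCtx.createDataFrame / inferSchema. A self-contained sketch of that pattern outside the suite (class and app names are made up for illustration):

    import unittest
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    class ToDFSmokeTest(unittest.TestCase):
        def setUp(self):
            self.sc = SparkContext('local[2]', 'toDF-smoke')
            self.sqlCtx = SQLContext(self.sc)    # installs RDD.toDF

        def tearDown(self):
            self.sc.stop()

        def test_toDF(self):
            df = self.sc.parallelize([Row(key=i, value=str(i)) for i in range(10)]).toDF()
            self.assertEqual(10, df.count())

    if __name__ == "__main__":
        unittest.main()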