author    Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
committer Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
commit    e98dfe627c5d0201464cdd0f363f391ea84c389a (patch)
tree      794beea739eb04bf2e0926f9b0e19ffacb94ba08 /python/pyspark/sql/dataframe.py
parent    0ce4e430a81532dc317136f968f28742e087d840 (diff)
[SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
- The old implicit would convert RDDs directly to DataFrames, and that added too many methods.
- toDataFrame -> toDF
- Dsl -> functions
- implicits moved into SQLContext.implicits
- addColumn -> withColumn
- renameColumn -> withColumnRenamed

Python changes:
- toDataFrame -> toDF
- Dsl -> functions package
- addColumn -> withColumn
- renameColumn -> withColumnRenamed
- add toDF functions to RDD on SQLContext init
- add flatMap to DataFrame

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4556 from rxin/SPARK-5752 and squashes the following commits:

5ef9910 [Reynold Xin] More fix
61d3fca [Reynold Xin] Merge branch 'df5' of github.com:davies/spark into SPARK-5752
ff5832c [Reynold Xin] Fix python
749c675 [Reynold Xin] count(*) fixes.
5806df0 [Reynold Xin] Fix build break again.
d941f3d [Reynold Xin] Fixed explode compilation break.
fe1267a [Davies Liu] flatMap
c4afb8e [Reynold Xin] style
d9de47f [Davies Liu] add comment
b783994 [Davies Liu] add comment for toDF
e2154e5 [Davies Liu] schema() -> schema
3a1004f [Davies Liu] Dsl -> functions, toDF()
fb256af [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
0dd74eb [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
97dd47c [Davies Liu] fix mistake
6168f74 [Davies Liu] fix test
1fc0199 [Davies Liu] fix test
a075cd5 [Davies Liu] clean up, toPandas
663d314 [Davies Liu] add test for agg('*')
9e214d5 [Reynold Xin] count(*) fixes.
1ed7136 [Reynold Xin] Fix build break again.
921b2e3 [Reynold Xin] Fixed explode compilation break.
14698d4 [Davies Liu] flatMap
ba3e12d [Reynold Xin] style
d08c92d [Davies Liu] add comment
5c8b524 [Davies Liu] add comment for toDF
a4e5e66 [Davies Liu] schema() -> schema
d377fc9 [Davies Liu] Dsl -> functions, toDF()
6b3086c [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
807e8b1 [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
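The renamed Python API can be exercised end to end with a short script assembled from the doctests in this patch. This is a minimal sketch, assuming a Spark build that contains this commit; the local[4] setup mirrors the _test() harness in the diff below, and the app name and variable names (example, sqlCtx, df) are illustrative only:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row
    from pyspark.sql import functions as F       # formerly pyspark.sql.Dsl

    sc = SparkContext('local[4]', 'example')      # app name is arbitrary
    sqlCtx = SQLContext(sc)                       # constructing it registers RDD.toDF()

    # RDD -> DataFrame now goes through an explicit toDF() call, not an implicit.
    df = sc.parallelize([Row(name='Alice', age=2),
                         Row(name='Bob', age=5)]).toDF()

    df.schema                                     # now a property, formerly schema()
    df.withColumn('age2', df.age + 2)             # formerly addColumn
    df.withColumnRenamed('age', 'age2')           # formerly renameColumn
    df.agg(F.min(df.age)).collect()               # Dsl.min -> functions.min
    df.flatMap(lambda p: p.name).collect()        # new shorthand for df.rdd.flatMap
    df.toPandas()                                 # formerly to_pandas; needs pandas installed

Routing the conversion through an explicit toDF() that is added to RDD only when a SQLContext is created keeps the conversion visible at the call site and avoids piling extra methods onto RDD, which is the motivation stated in the first bullet above.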
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--    python/pyspark/sql/dataframe.py    221
1 file changed, 51 insertions, 170 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b6f052ee44..1438fe5285 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -21,21 +21,19 @@ import warnings
import random
import os
from tempfile import NamedTemporaryFile
-from itertools import imap
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _prepare_for_python_RDD
-from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
- UTF8Deserializer
+from pyspark.rdd import RDD
+from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
-__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]
class DataFrame(object):
@@ -76,6 +74,7 @@ class DataFrame(object):
self.sql_ctx = sql_ctx
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
+ self._schema = None # initialized lazily
@property
def rdd(self):
@@ -86,7 +85,7 @@ class DataFrame(object):
if not hasattr(self, '_lazy_rdd'):
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
- schema = self.schema()
+ schema = self.schema
def applySchema(it):
cls = _create_cls(schema)
@@ -216,14 +215,17 @@ class DataFrame(object):
self._sc._gateway._gateway_client)
self._jdf.save(source, jmode, joptions)
+ @property
def schema(self):
"""Returns the schema of this DataFrame (represented by
a L{StructType}).
- >>> df.schema()
+ >>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
"""
- return _parse_datatype_json_string(self._jdf.schema().json())
+ if self._schema is None:
+ self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ return self._schema
def printSchema(self):
"""Prints out the schema in the tree format.
@@ -284,7 +286,7 @@ class DataFrame(object):
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
- cls = _create_cls(self.schema())
+ cls = _create_cls(self.schema)
return [cls(r) for r in rs]
def limit(self, num):
@@ -310,14 +312,26 @@ class DataFrame(object):
return self.limit(num).collect()
def map(self, f):
- """ Return a new RDD by applying a function to each Row, it's a
- shorthand for df.rdd.map()
+ """ Return a new RDD by applying a function to each Row
+
+ It's a shorthand for df.rdd.map()
>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
"""
return self.rdd.map(f)
+ def flatMap(self, f):
+ """ Return a new RDD by first applying a function to all elements of this,
+ and then flattening the results.
+
+ It's a shorthand for df.rdd.flatMap()
+
+ >>> df.flatMap(lambda p: p.name).collect()
+ [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
+ """
+ return self.rdd.flatMap(f)
+
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition.
@@ -378,21 +392,6 @@ class DataFrame(object):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
- # def takeSample(self, withReplacement, num, seed=None):
- # """Return a fixed-size sampled subset of this DataFrame.
- #
- # >>> df = sqlCtx.inferSchema(rdd)
- # >>> df.takeSample(False, 2, 97)
- # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
- # """
- # seed = seed if seed is not None else random.randint(0, sys.maxint)
- # with SCCallSiteSync(self.context) as css:
- # bytesInJava = self._jdf \
- # .takeSampleToPython(withReplacement, num, long(seed)) \
- # .iterator()
- # cls = _create_cls(self.schema())
- # return map(cls, self._collect_iterator_through_file(bytesInJava))
-
@property
def dtypes(self):
"""Return all column names and their data types as a list.
@@ -400,7 +399,7 @@ class DataFrame(object):
>>> df.dtypes
[('age', 'int'), ('name', 'string')]
"""
- return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields]
+ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
def columns(self):
@@ -409,7 +408,7 @@ class DataFrame(object):
>>> df.columns
[u'age', u'name']
"""
- return [f.name for f in self.schema().fields]
+ return [f.name for f in self.schema.fields]
def join(self, other, joinExprs=None, joinType=None):
"""
@@ -586,8 +585,8 @@ class DataFrame(object):
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age#0)=5)]
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.min(df.age)).collect()
+ >>> from pyspark.sql import functions as F
+ >>> df.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=2)]
"""
return self.groupBy().agg(*exprs)
@@ -616,18 +615,18 @@ class DataFrame(object):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
- def addColumn(self, colName, col):
+ def withColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
- >>> df.addColumn('age2', df.age + 2).collect()
+ >>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.alias(colName))
- def renameColumn(self, existing, new):
+ def withColumnRenamed(self, existing, new):
""" Rename an existing column to a new name
- >>> df.renameColumn('age', 'age2').collect()
+ >>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
"""
cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
@@ -635,11 +634,11 @@ class DataFrame(object):
for c in self.columns]
return self.select(*cols)
- def to_pandas(self):
+ def toPandas(self):
"""
Collect all the rows and return a `pandas.DataFrame`.
- >>> df.to_pandas() # doctest: +SKIP
+ >>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
@@ -687,10 +686,11 @@ class GroupedData(object):
name to aggregate methods.
>>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"age": "max"}).collect()
- [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
- >>> from pyspark.sql import Dsl
- >>> gdf.agg(Dsl.min(df.age)).collect()
+ >>> gdf.agg({"*": "count"}).collect()
+ [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+
+ >>> from pyspark.sql import functions as F
+ >>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
"""
assert exprs, "exprs should not be empty"
@@ -742,12 +742,12 @@ class GroupedData(object):
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.lit(literal)
+ return sc._jvm.functions.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.col(name)
+ return sc._jvm.functions.col(name)
def _to_java_column(col):
@@ -767,9 +767,9 @@ def _unary_op(name, doc="unary operator"):
return _
-def _dsl_op(name, doc=''):
+def _func_op(name, doc=''):
def _(self):
- jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+ jc = getattr(self._sc._jvm.functions, name)(self._jc)
return Column(jc, self.sql_ctx)
_.__doc__ = doc
return _
@@ -818,7 +818,7 @@ class Column(DataFrame):
super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
- __neg__ = _dsl_op("negate")
+ __neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
@@ -842,7 +842,7 @@ class Column(DataFrame):
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
- __invert__ = _dsl_op('not')
+ __invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
@@ -920,11 +920,11 @@ class Column(DataFrame):
else:
return 'Column<%s>' % self._jdf.toString()
- def to_pandas(self):
+ def toPandas(self):
"""
Return a pandas.Series from the column
- >>> df.age.to_pandas() # doctest: +SKIP
+ >>> df.age.toPandas() # doctest: +SKIP
0 2
1 5
dtype: int64
@@ -934,123 +934,6 @@ class Column(DataFrame):
return pd.Series(data)
-def _aggregate_func(name, doc=""):
- """ Create a function for aggregator by name"""
- def _(col):
- sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
- return Column(jc)
- _.__name__ = name
- _.__doc__ = doc
- return staticmethod(_)
-
-
-class UserDefinedFunction(object):
- def __init__(self, func, returnType):
- self.func = func
- self.returnType = returnType
- self._broadcast = None
- self._judf = self._create_judf()
-
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: imap(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
- pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
- judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
- includes, sc.pythonExec, broadcast_vars,
- sc._javaAccumulator, jdt)
- return judf
-
- def __del__(self):
- if self._broadcast is not None:
- self._broadcast.unpersist()
- self._broadcast = None
-
- def __call__(self, *cols):
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
-
-class Dsl(object):
- """
- A collections of builtin aggregators
- """
- DSLS = {
- 'lit': 'Creates a :class:`Column` of literal value.',
- 'col': 'Returns a :class:`Column` based on the given column name.',
- 'column': 'Returns a :class:`Column` based on the given column name.',
- 'upper': 'Converts a string expression to upper case.',
- 'lower': 'Converts a string expression to upper case.',
- 'sqrt': 'Computes the square root of the specified float value.',
- 'abs': 'Computes the absolutle value.',
-
- 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
- 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
- 'first': 'Aggregate function: returns the first value in a group.',
- 'last': 'Aggregate function: returns the last value in a group.',
- 'count': 'Aggregate function: returns the number of items in a group.',
- 'sum': 'Aggregate function: returns the sum of all values in the expression.',
- 'avg': 'Aggregate function: returns the average of the values in a group.',
- 'mean': 'Aggregate function: returns the average of the values in a group.',
- 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
- }
-
- for _name, _doc in DSLS.items():
- locals()[_name] = _aggregate_func(_name, _doc)
- del _name, _doc
-
- @staticmethod
- def countDistinct(col, *cols):
- """ Return a new Column for distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
- [Row(c=2)]
-
- >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
- @staticmethod
- def approxCountDistinct(col, rsd=None):
- """ Return a new Column for approxiate distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- if rsd is None:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
- else:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
- return Column(jc)
-
- @staticmethod
- def udf(f, returnType=StringType()):
- """Create a user defined function (UDF)
-
- >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
- >>> df.select(slen(df.name).alias('slen')).collect()
- [Row(slen=5), Row(slen=3)]
- """
- return UserDefinedFunction(f, returnType)
-
-
def _test():
import doctest
from pyspark.context import SparkContext
@@ -1059,11 +942,9 @@ def _test():
globs = pyspark.sql.dataframe.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
- rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
- globs['df'] = sqlCtx.inferSchema(rdd2)
- globs['df2'] = sqlCtx.inferSchema(rdd3)
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)