author    Reynold Xin <rxin@databricks.com>    2015-02-19 12:09:44 -0800
committer Michael Armbrust <michael@databricks.com>    2015-02-19 12:09:44 -0800
commit    8ca3418e1b3e2687e75a08c185d17045a97279fb (patch)
tree      df3de8114cedc6d02a4e656734c865ce5c1e1cb7 /python
parent    94cdb05ff7e6b8fc5b3a574202ba8bc8e5bbe689 (diff)
[SPARK-5904][SQL] DataFrame API fixes.
1. Column is no longer a DataFrame to simplify class hierarchy.
2. Don't use varargs on abstract methods (see Scala compiler bug SI-9013).

Author: Reynold Xin <rxin@databricks.com>

Closes #4686 from rxin/SPARK-5904 and squashes the following commits:

fd9b199 [Reynold Xin] Fixed Python tests.
df25cef [Reynold Xin] Non final.
5221530 [Reynold Xin] [SPARK-5904][SQL] DataFrame API fixes.
Diffstat (limited to 'python')
-rw-r--r--    python/pyspark/sql/dataframe.py    56
1 file changed, 20 insertions(+), 36 deletions(-)
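The user-visible effect of change (1) is that a Column is now a plain expression object rather than a collectable DataFrame, which is why the doctests in the diff below move every collect() onto a select(). A minimal before/after sketch (the local-mode setup and the two-row df mirror the doctests; the sqlCtx name is illustrative, not part of the patch):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "spark-5904-sketch")
    sqlCtx = SQLContext(sc)
    df = sqlCtx.createDataFrame([Row(name=u'Alice', age=2),
                                 Row(name=u'Bob', age=5)])

    # Before this patch, Column inherited from DataFrame, so this worked:
    #   df['age'].collect()                   # [Row(age=2), Row(age=5)]
    # After it, a Column is only an expression and must be selected first:
    print(df.select(df['age']).collect())     # [Row(age=2), Row(age=5)]
    print(df.select(df.age).collect())        # [Row(age=2), Row(age=5)]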
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c68c97e926..010c38f93b 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -546,7 +546,7 @@ class DataFrame(object):
def __getitem__(self, item):
""" Return the column by given name
- >>> df['age'].collect()
+ >>> df.select(df['age']).collect()
[Row(age=2), Row(age=5)]
>>> df[ ["name", "age"]].collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
@@ -555,7 +555,7 @@ class DataFrame(object):
"""
if isinstance(item, basestring):
jc = self._jdf.apply(item)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, list):
@@ -566,13 +566,13 @@ class DataFrame(object):
def __getattr__(self, name):
""" Return the column by given name
- >>> df.age.collect()
+ >>> df.select(df.age).collect()
[Row(age=2), Row(age=5)]
"""
if name.startswith("__"):
raise AttributeError(name)
jc = self._jdf.apply(name)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
def select(self, *cols):
""" Selecting a set of expressions.
@@ -698,7 +698,7 @@ class DataFrame(object):
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
"""
- cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
+ cols = [Column(_to_java_column(c)).alias(new)
if c == existing else c
for c in self.columns]
return self.select(*cols)
@@ -873,15 +873,16 @@ def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
jc = getattr(self._jc, name)()
- return Column(jc, self.sql_ctx)
+ return Column(jc)
_.__doc__ = doc
return _
def _func_op(name, doc=''):
def _(self):
- jc = getattr(self._sc._jvm.functions, name)(self._jc)
- return Column(jc, self.sql_ctx)
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.functions, name)(self._jc)
+ return Column(jc)
_.__doc__ = doc
return _
@@ -892,7 +893,7 @@ def _bin_op(name, doc="binary operator"):
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
njc = getattr(self._jc, name)(jc)
- return Column(njc, self.sql_ctx)
+ return Column(njc)
_.__doc__ = doc
return _
@@ -903,12 +904,12 @@ def _reverse_op(name, doc="binary operator"):
def _(self, other):
jother = _create_column_from_literal(other)
jc = getattr(jother, name)(self._jc)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
_.__doc__ = doc
return _
-class Column(DataFrame):
+class Column(object):
"""
A column in a DataFrame.
@@ -924,9 +925,8 @@ class Column(DataFrame):
1 / df.colName
"""
- def __init__(self, jc, sql_ctx=None):
+ def __init__(self, jc):
self._jc = jc
- super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
__neg__ = _func_op("negate")
@@ -975,7 +975,7 @@ class Column(DataFrame):
:param startPos: start position (int or Column)
:param length: length of the substring (int or Column)
- >>> df.name.substr(1, 3).collect()
+ >>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col=u'Ali'), Row(col=u'Bob')]
"""
if type(startPos) != type(length):
@@ -986,7 +986,7 @@ class Column(DataFrame):
jc = self._jc.substr(startPos._jc, length._jc)
else:
raise TypeError("Unexpected type: %s" % type(startPos))
- return Column(jc, self.sql_ctx)
+ return Column(jc)
__getslice__ = substr
@@ -1000,10 +1000,10 @@ class Column(DataFrame):
def alias(self, alias):
"""Return a alias for this column
- >>> df.age.alias("age2").collect()
+ >>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
"""
- return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+ return Column(getattr(self._jc, "as")(alias))
def cast(self, dataType):
""" Convert the column into type `dataType`
@@ -1013,34 +1013,18 @@ class Column(DataFrame):
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
- if self.sql_ctx is None:
- sc = SparkContext._active_spark_context
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- else:
- ssql_ctx = self.sql_ctx._ssql_ctx
if isinstance(dataType, basestring):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
def __repr__(self):
return 'Column<%s>' % self._jdf.toString().encode('utf8')
- def toPandas(self):
- """
- Return a pandas.Series from the column
-
- >>> df.age.toPandas() # doctest: +SKIP
- 0 2
- 1 5
- dtype: int64
- """
- import pandas as pd
- data = [c for c, in self.collect()]
- return pd.Series(data)
-
def _test():
import doctest
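Since the patch also deletes Column.toPandas (a Column can no longer be collected on its own), the same pandas Series can still be produced by selecting the column into a one-column DataFrame first. A hedged sketch, reusing the sc/sqlCtx/df setup from the example above:

    import pandas as pd

    # Column.toPandas() is gone; collect a one-column selection and
    # flatten the resulting Rows into a pandas Series instead.
    ages = pd.Series([row.age for row in df.select(df.age).collect()])
    print(ages)
    # 0    2
    # 1    5
    # dtype: int64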