From 8ca3418e1b3e2687e75a08c185d17045a97279fb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 19 Feb 2015 12:09:44 -0800 Subject: [SPARK-5904][SQL] DataFrame API fixes. 1. Column is no longer a DataFrame to simplify class hierarchy. 2. Don't use varargs on abstract methods (see Scala compiler bug SI-9013). Author: Reynold Xin Closes #4686 from rxin/SPARK-5904 and squashes the following commits: fd9b199 [Reynold Xin] Fixed Python tests. df25cef [Reynold Xin] Non final. 5221530 [Reynold Xin] [SPARK-5904][SQL] DataFrame API fixes. --- python/pyspark/sql/dataframe.py | 56 +++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 36 deletions(-) (limited to 'python') diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c68c97e926..010c38f93b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -546,7 +546,7 @@ class DataFrame(object): def __getitem__(self, item): """ Return the column by given name - >>> df['age'].collect() + >>> df.select(df['age']).collect() [Row(age=2), Row(age=5)] >>> df[ ["name", "age"]].collect() [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] @@ -555,7 +555,7 @@ class DataFrame(object): """ if isinstance(item, basestring): jc = self._jdf.apply(item) - return Column(jc, self.sql_ctx) + return Column(jc) elif isinstance(item, Column): return self.filter(item) elif isinstance(item, list): @@ -566,13 +566,13 @@ class DataFrame(object): def __getattr__(self, name): """ Return the column by given name - >>> df.age.collect() + >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ if name.startswith("__"): raise AttributeError(name) jc = self._jdf.apply(name) - return Column(jc, self.sql_ctx) + return Column(jc) def select(self, *cols): """ Selecting a set of expressions. @@ -698,7 +698,7 @@ class DataFrame(object): >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ - cols = [Column(_to_java_column(c), self.sql_ctx).alias(new) + cols = [Column(_to_java_column(c)).alias(new) if c == existing else c for c in self.columns] return self.select(*cols) @@ -873,15 +873,16 @@ def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): jc = getattr(self._jc, name)() - return Column(jc, self.sql_ctx) + return Column(jc) _.__doc__ = doc return _ def _func_op(name, doc=''): def _(self): - jc = getattr(self._sc._jvm.functions, name)(self._jc) - return Column(jc, self.sql_ctx) + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) _.__doc__ = doc return _ @@ -892,7 +893,7 @@ def _bin_op(name, doc="binary operator"): def _(self, other): jc = other._jc if isinstance(other, Column) else other njc = getattr(self._jc, name)(jc) - return Column(njc, self.sql_ctx) + return Column(njc) _.__doc__ = doc return _ @@ -903,12 +904,12 @@ def _reverse_op(name, doc="binary operator"): def _(self, other): jother = _create_column_from_literal(other) jc = getattr(jother, name)(self._jc) - return Column(jc, self.sql_ctx) + return Column(jc) _.__doc__ = doc return _ -class Column(DataFrame): +class Column(object): """ A column in a DataFrame. @@ -924,9 +925,8 @@ class Column(DataFrame): 1 / df.colName """ - def __init__(self, jc, sql_ctx=None): + def __init__(self, jc): self._jc = jc - super(Column, self).__init__(jc, sql_ctx) # arithmetic operators __neg__ = _func_op("negate") @@ -975,7 +975,7 @@ class Column(DataFrame): :param startPos: start position (int or Column) :param length: length of the substring (int or Column) - >>> df.name.substr(1, 3).collect() + >>> df.select(df.name.substr(1, 3).alias("col")).collect() [Row(col=u'Ali'), Row(col=u'Bob')] """ if type(startPos) != type(length): @@ -986,7 +986,7 @@ class Column(DataFrame): jc = self._jc.substr(startPos._jc, length._jc) else: raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc, self.sql_ctx) + return Column(jc) __getslice__ = substr @@ -1000,10 +1000,10 @@ class Column(DataFrame): def alias(self, alias): """Return a alias for this column - >>> df.age.alias("age2").collect() + >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] """ - return Column(getattr(self._jc, "as")(alias), self.sql_ctx) + return Column(getattr(self._jc, "as")(alias)) def cast(self, dataType): """ Convert the column into type `dataType` @@ -1013,34 +1013,18 @@ class Column(DataFrame): >>> df.select(df.age.cast(StringType()).alias('ages')).collect() [Row(ages=u'2'), Row(ages=u'5')] """ - if self.sql_ctx is None: - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - else: - ssql_ctx = self.sql_ctx._ssql_ctx if isinstance(dataType, basestring): jc = self._jc.cast(dataType) elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) - return Column(jc, self.sql_ctx) + return Column(jc) def __repr__(self): return 'Column<%s>' % self._jdf.toString().encode('utf8') - def toPandas(self): - """ - Return a pandas.Series from the column - - >>> df.age.toPandas() # doctest: +SKIP - 0 2 - 1 5 - dtype: int64 - """ - import pandas as pd - data = [c for c, in self.collect()] - return pd.Series(data) - def _test(): import doctest -- cgit v1.2.3