aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-02-04 11:34:46 -0800
committerReynold Xin <rxin@databricks.com>2015-02-04 11:34:46 -0800
commitac0b2b788ff144970d6fdbdc445367772770458d (patch)
tree65a7875644f6ef8617c6e53e413fbce321cbb33e /python
parent38a416f0360fa68fc445af14910fb253ff9ad493 (diff)
downloadspark-ac0b2b788ff144970d6fdbdc445367772770458d.tar.gz
spark-ac0b2b788ff144970d6fdbdc445367772770458d.tar.bz2
spark-ac0b2b788ff144970d6fdbdc445367772770458d.zip
[SPARK-5588] [SQL] support select/filter by SQL expression
``` df.selectExpr('a + 1', 'abs(age)') df.filter('age > 3') df[ df.age > 3 ] df[ ['age', 'name'] ] ``` Author: Davies Liu <davies@databricks.com> Closes #4359 from davies/select_expr and squashes the following commits: d99856b [Davies Liu] support select/filter by SQL expression
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql.py53
1 files changed, 43 insertions, 10 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 74305dea74..a266cde51d 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2128,7 +2128,7 @@ class DataFrame(object):
raise ValueError("should sort by at least one column")
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)
sortBy = sort
@@ -2159,13 +2159,20 @@ class DataFrame(object):
>>> df['age'].collect()
[Row(age=2), Row(age=5)]
+ >>> df[ ["name", "age"]].collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df[ df.age > 3 ].collect()
+ [Row(age=5, name=u'Bob')]
"""
if isinstance(item, basestring):
jc = self._jdf.apply(item)
return Column(jc, self.sql_ctx)
-
- # TODO projection
- raise IndexError
+ elif isinstance(item, Column):
+ return self.filter(item)
+ elif isinstance(item, list):
+ return self.select(*item)
+ else:
+ raise IndexError("unexpected index: %s" % item)
def __getattr__(self, name):
""" Return the column by given name
@@ -2194,18 +2201,44 @@ class DataFrame(object):
cols = ["*"]
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ def selectExpr(self, *expr):
+ """
+ Selects a set of SQL expressions. This is a variant of
+ `select` that accepts SQL expressions.
+
+ >>> df.selectExpr("age * 2", "abs(age)").collect()
+ [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+ """
+ jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+ jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
return DataFrame(jdf, self.sql_ctx)
def filter(self, condition):
- """ Filtering rows using the given condition.
+ """ Filtering rows using the given condition, which could be
+ Column expression or string of SQL expression.
+
+ where() is an alias for filter().
>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name=u'Alice')]
+
+ >>> df.filter("age > 3").collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where("age = 2").collect()
+ [Row(age=2, name=u'Alice')]
"""
- return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+ if isinstance(condition, basestring):
+ jdf = self._jdf.filter(condition)
+ elif isinstance(condition, Column):
+ jdf = self._jdf.filter(condition._jc)
+ else:
+ raise TypeError("condition should be string or Column")
+ return DataFrame(jdf, self.sql_ctx)
where = filter
@@ -2223,7 +2256,7 @@ class DataFrame(object):
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)
def agg(self, *exprs):
@@ -2338,7 +2371,7 @@ class GroupedDataFrame(object):
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)
@dfapi
@@ -2633,7 +2666,7 @@ class Dsl(object):
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.Dsl.toColumns(jcols))
+ sc._jvm.PythonUtils.toSeq(jcols))
return Column(jc)
@staticmethod