From ac0b2b788ff144970d6fdbdc445367772770458d Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 4 Feb 2015 11:34:46 -0800
Subject: [SPARK-5588] [SQL] support select/filter by SQL expression

```
df.selectExpr('a + 1', 'abs(age)')
df.filter('age > 3')
df[ df.age > 3 ]
df[ ['age', 'name'] ]
```

Author: Davies Liu

Closes #4359 from davies/select_expr and squashes the following commits:

d99856b [Davies Liu] support select/filter by SQL expression
---
 python/pyspark/sql.py | 53 +++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 43 insertions(+), 10 deletions(-)

(limited to 'python')

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 74305dea74..a266cde51d 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2128,7 +2128,7 @@ class DataFrame(object):
             raise ValueError("should sort by at least one column")
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     sortBy = sort
@@ -2159,13 +2159,20 @@ class DataFrame(object):
 
         >>> df['age'].collect()
         [Row(age=2), Row(age=5)]
+        >>> df[ ["name", "age"]].collect()
+        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+        >>> df[ df.age > 3 ].collect()
+        [Row(age=5, name=u'Bob')]
         """
         if isinstance(item, basestring):
             jc = self._jdf.apply(item)
             return Column(jc, self.sql_ctx)
-
-        # TODO projection
-        raise IndexError
+        elif isinstance(item, Column):
+            return self.filter(item)
+        elif isinstance(item, list):
+            return self.select(*item)
+        else:
+            raise IndexError("unexpected index: %s" % item)
 
     def __getattr__(self, name):
         """ Return the column by given name
@@ -2194,18 +2201,44 @@ class DataFrame(object):
             cols = ["*"]
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        return DataFrame(jdf, self.sql_ctx)
+
+    def selectExpr(self, *expr):
+        """
+        Selects a set of SQL expressions. This is a variant of
+        `select` that accepts SQL expressions.
+
+        >>> df.selectExpr("age * 2", "abs(age)").collect()
+        [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+        """
+        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
-        """ Filtering rows using the given condition.
+        """ Filtering rows using the given condition, which could be
+        Column expression or string of SQL expression.
+
+        where() is an alias for filter().
 
         >>> df.filter(df.age > 3).collect()
         [Row(age=5, name=u'Bob')]
         >>> df.where(df.age == 2).collect()
         [Row(age=2, name=u'Alice')]
+
+        >>> df.filter("age > 3").collect()
+        [Row(age=5, name=u'Bob')]
+        >>> df.where("age = 2").collect()
+        [Row(age=2, name=u'Alice')]
         """
-        return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+        if isinstance(condition, basestring):
+            jdf = self._jdf.filter(condition)
+        elif isinstance(condition, Column):
+            jdf = self._jdf.filter(condition._jc)
+        else:
+            raise TypeError("condition should be string or Column")
+        return DataFrame(jdf, self.sql_ctx)
 
     where = filter
 
@@ -2223,7 +2256,7 @@ class DataFrame(object):
         """
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
@@ -2338,7 +2371,7 @@ class GroupedDataFrame(object):
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
             jcols = ListConverter().convert([c._jc for c in exprs[1:]],
                                             self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
@@ -2633,7 +2666,7 @@ class Dsl(object):
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         sc._gateway._gateway_client)
         jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
-                                       sc._jvm.Dsl.toColumns(jcols))
+                                       sc._jvm.PythonUtils.toSeq(jcols))
         return Column(jc)
 
     @staticmethod
-- 
cgit v1.2.3