 core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala | 11
 python/pyspark/sql.py                                             | 53
 sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala            | 11
 3 files changed, 53 insertions(+), 22 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index b7cfc8bd9c..acbaba6791 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -17,8 +17,10 @@
package org.apache.spark.api.python
-import java.io.{File, InputStream, IOException, OutputStream}
+import java.io.File
+import java.util.{List => JList}
+import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext
@@ -44,4 +46,11 @@ private[spark] object PythonUtils {
def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
sc.parallelize(List("a", null, "b"))
}
+
+ /**
+ * Convert a java.util.List[T] into a Seq[T] (for calling Scala APIs that take varargs)
+ */
+ def toSeq[T](cols: JList[T]): Seq[T] = {
+ cols.toList.toSeq
+ }
}
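
For reference, here is the Py4J round trip that every Python call site below performs (a minimal sketch assembled from the hunks in this patch; df is assumed to be a pyspark DataFrame and cols a list of column names or Columns):

    # Convert a Python list into a java.util.List over the Py4J gateway,
    # then into a scala.collection.Seq via the new PythonUtils.toSeq helper,
    # so it can be passed where a Scala API expects varargs.
    from py4j.java_collections import ListConverter

    jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                    df._sc._gateway._gateway_client)
    jseq = df._sc._jvm.PythonUtils.toSeq(jcols)
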
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 74305dea74..a266cde51d 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2128,7 +2128,7 @@ class DataFrame(object):
raise ValueError("should sort by at least one column")
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)
sortBy = sort
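
With that conversion in place, sort accepts any number of Column arguments; a usage sketch with the doctest fixture df used throughout this file:

    >>> df.sort(df.age).collect()
    [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
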
@@ -2159,13 +2159,20 @@ class DataFrame(object):
>>> df['age'].collect()
[Row(age=2), Row(age=5)]
+ >>> df[["name", "age"]].collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df[df.age > 3].collect()
+ [Row(age=5, name=u'Bob')]
"""
if isinstance(item, basestring):
jc = self._jdf.apply(item)
return Column(jc, self.sql_ctx)
-
- # TODO projection
- raise IndexError
+ elif isinstance(item, Column):
+ return self.filter(item)
+ elif isinstance(item, list):
+ return self.select(*item)
+ else:
+ raise IndexError("unexpected index: %s" % item)
def __getattr__(self, name):
""" Return the column by given name
@@ -2194,18 +2201,44 @@ class DataFrame(object):
cols = ["*"]
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ def selectExpr(self, *expr):
+ """
+ Selects a set of SQL expressions. This is a variant of
+ `select` that accepts SQL expressions.
+
+ >>> df.selectExpr("age * 2", "abs(age)").collect()
+ [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+ """
+ jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+ jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
return DataFrame(jdf, self.sql_ctx)
def filter(self, condition):
- """ Filtering rows using the given condition.
+ """ Filtering rows using the given condition, which could be
+ Column expression or string of SQL expression.
+
+ where() is an alias for filter().
>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name=u'Alice')]
+
+ >>> df.filter("age > 3").collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where("age = 2").collect()
+ [Row(age=2, name=u'Alice')]
"""
- return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+ if isinstance(condition, basestring):
+ jdf = self._jdf.filter(condition)
+ elif isinstance(condition, Column):
+ jdf = self._jdf.filter(condition._jc)
+ else:
+ raise TypeError("condition should be string or Column")
+ return DataFrame(jdf, self.sql_ctx)
where = filter
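
Because where is an alias for filter, the string and Column forms can be chained interchangeably; a sketch with the same fixture:

    >>> df.filter("age > 2").where(df.name == u'Bob').collect()
    [Row(age=5, name=u'Bob')]
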
@@ -2223,7 +2256,7 @@ class DataFrame(object):
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return GroupedDataFrame(jdf, self.sql_ctx)
def agg(self, *exprs):
@@ -2338,7 +2371,7 @@ class GroupedDataFrame(object):
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)
@dfapi
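
GroupedDataFrame.agg forwards the first Column directly and the rest through toSeq, so any Column-valued aggregate works; for example, using the countDistinct helper updated later in this diff (a sketch; the generated result column name is omitted since it depends on the Catalyst version):

    from pyspark.sql import Dsl

    # one aggregate Column per name group; countDistinct returns a Column
    df.groupBy(df.name).agg(Dsl.countDistinct(df.age)).collect()
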
@@ -2633,7 +2666,7 @@ class Dsl(object):
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.Dsl.toColumns(jcols))
+ sc._jvm.PythonUtils.toSeq(jcols))
return Column(jc)
@staticmethod
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index 8cf59f0a1f..50f442dd87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -17,11 +17,8 @@
package org.apache.spark.sql
-import java.util.{List => JList}
-
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
-import scala.collection.JavaConversions._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -169,14 +166,6 @@ object Dsl {
/** Computes the absolute value. */
def abs(e: Column): Column = Abs(e.expr)
- /**
- * This is a private API for Python
- * TODO: move this to a private package
- */
- def toColumns(cols: JList[Column]): Seq[Column] = {
- cols.toList.toSeq
- }
-
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
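
The net effect for every Python call site is a one-line change, moving the list-to-Seq helper out of the public Scala Dsl object and into the Python-private PythonUtils:

    # before this patch (helper exposed on the public SQL Dsl object):
    jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
    # after this patch (helper lives in org.apache.spark.api.python.PythonUtils):
    jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))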