-rw-r--r--  python/pyspark/sql.py                                                      |  5
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala              | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala          | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala     |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala         | 12
6 files changed, 67 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 268c7ef97c..74305dea74 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2126,10 +2126,9 @@ class DataFrame(object):
"""
if not cols:
raise ValueError("should sort by at least one column")
- jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
- jdf = self._jdf.sort(_to_java_column(cols[0]),
- self._sc._jvm.Dsl.toColumns(jcols))
+ jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
return DataFrame(jdf, self.sql_ctx)
sortBy = sort
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 5c006e9d4c..a9bd079c70 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -36,6 +36,16 @@ import org.apache.spark.sql.types._
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
class SqlParser extends AbstractSparkSQLParser {
+
+ def parseExpression(input: String): Expression = {
+ // Initialize the Keywords.
+ lexical.initialize(reservedWords)
+ phrase(expression)(new lexical.Scanner(input)) match {
+ case Success(plan, _) => plan
+ case failureOrError => sys.error(failureOrError.toString)
+ }
+ }
+
// Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
// properties via reflection the class in runtime for constructing the SqlLexical object
protected val ABS = Keyword("ABS")
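
The new parseExpression entry point can also be exercised on its own. A minimal sketch, not part of the patch, assuming spark-catalyst is on the classpath and using an illustrative expression string:

    import org.apache.spark.sql.catalyst.SqlParser
    import org.apache.spark.sql.catalyst.expressions.Expression

    val parser = new SqlParser()
    // Parses a single SQL expression (not a full statement) into an unresolved
    // Catalyst Expression tree; a parse failure aborts via sys.error.
    val parsed: Expression = parser.parseExpression("abs(colC) + 1")
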
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 732b685558..a4997fb293 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -173,7 +173,7 @@ trait DataFrame extends RDDApi[Row] {
* }}}
*/
@scala.annotation.varargs
- def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+ def sort(sortExprs: Column*): DataFrame
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
@@ -187,7 +187,7 @@ trait DataFrame extends RDDApi[Row] {
* This is an alias of the `sort` function.
*/
@scala.annotation.varargs
- def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
+ def orderBy(sortExprs: Column*): DataFrame
/**
* Selects column based on the column name and return it as a [[Column]].
@@ -237,6 +237,17 @@ trait DataFrame extends RDDApi[Row] {
def select(col: String, cols: String*): DataFrame
/**
+ * Selects a set of SQL expressions. This is a variant of `select` that accepts
+ * SQL expressions.
+ *
+ * {{{
+ * df.selectExpr("colA", "colB as newName", "abs(colC)")
+ * }}}
+ */
+ @scala.annotation.varargs
+ def selectExpr(exprs: String*): DataFrame
+
+ /**
* Filters rows using the given condition.
* {{{
* // The following are equivalent:
@@ -248,6 +259,14 @@ trait DataFrame extends RDDApi[Row] {
def filter(condition: Column): DataFrame
/**
+ * Filters rows using the given SQL expression.
+ * {{{
+ * peopleDf.filter("age > 15")
+ * }}}
+ */
+ def filter(conditionExpr: String): DataFrame
+
+ /**
* Filters rows using the given condition. This is an alias for `filter`.
* {{{
* // The following are equivalent:
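
Taken together, the trait changes leave sort/orderBy purely varargs and add string-based selectExpr and filter overloads. A hedged usage sketch; the DataFrame df and the column names colA, colB, colC, age are illustrative, not from the patch:

    df.sort(df.col("colA").desc, df.col("colB"))           // any number of sort expressions
    df.orderBy(df.col("colA"))                              // alias of sort
    df.selectExpr("colA", "colB as newName", "abs(colC)")   // SQL expressions as strings
    df.filter("age > 15")                                   // SQL-expression form of filter
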
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index a52bfa59a1..c702adcb65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -124,11 +124,11 @@ private[sql] class DataFrameImpl protected[sql](
}
override def sort(sortCol: String, sortCols: String*): DataFrame = {
- orderBy(apply(sortCol), sortCols.map(apply) :_*)
+ sort((sortCol +: sortCols).map(apply) :_*)
}
- override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
- val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
+ override def sort(sortExprs: Column*): DataFrame = {
+ val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
@@ -143,8 +143,8 @@ private[sql] class DataFrameImpl protected[sql](
sort(sortCol, sortCols :_*)
}
- override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
- sort(sortExpr, sortExprs :_*)
+ override def orderBy(sortExprs: Column*): DataFrame = {
+ sort(sortExprs :_*)
}
override def col(colName: String): Column = colName match {
@@ -179,10 +179,20 @@ private[sql] class DataFrameImpl protected[sql](
select((col +: cols).map(Column(_)) :_*)
}
+ override def selectExpr(exprs: String*): DataFrame = {
+ select(exprs.map { expr =>
+ Column(new SqlParser().parseExpression(expr))
+ } :_*)
+ }
+
override def filter(condition: Column): DataFrame = {
Filter(condition.expr, logicalPlan)
}
+ override def filter(conditionExpr: String): DataFrame = {
+ filter(Column(new SqlParser().parseExpression(conditionExpr)))
+ }
+
override def where(condition: Column): DataFrame = {
filter(condition)
}
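
Both new overloads route the string through a fresh SqlParser and wrap the parsed expression in a Column, so the string forms behave like their Column-based counterparts. A minimal sketch of that delegation, not part of the patch, assuming the same private[sql] access to Column's expression-based factory that DataFrameImpl uses above:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.SqlParser

    // Hypothetical helper mirroring the two overloads above.
    def exprColumn(sqlText: String): Column =
      Column(new SqlParser().parseExpression(sqlText))

    // df.selectExpr("abs(colC)")  behaves like  df.select(exprColumn("abs(colC)"))
    // df.filter("key > 90")       behaves like  df.filter(exprColumn("key > 90"))
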
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index ba5c7355b4..6b032d3d69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -66,11 +66,11 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def sort(sortCol: String, sortCols: String*): DataFrame = err()
- override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+ override def sort(sortExprs: Column*): DataFrame = err()
override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()
- override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = err()
+ override def orderBy(sortExprs: Column*): DataFrame = err()
override def col(colName: String): Column = err()
@@ -80,8 +80,12 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def select(col: String, cols: String*): DataFrame = err()
+ override def selectExpr(exprs: String*): DataFrame = err()
+
override def filter(condition: Column): DataFrame = err()
+ override def filter(conditionExpr: String): DataFrame = err()
+
override def where(condition: Column): DataFrame = err()
override def apply(condition: Column): DataFrame = err()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 19d4f34e56..e588555ad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -47,6 +47,18 @@ class DataFrameSuite extends QueryTest {
testData.collect().toSeq)
}
+ test("selectExpr") {
+ checkAnswer(
+ testData.selectExpr("abs(key)", "value"),
+ testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq)
+ }
+
+ test("filterExpr") {
+ checkAnswer(
+ testData.filter("key > 90"),
+ testData.collect().filter(_.getInt(0) > 90).toSeq)
+ }
+
test("repartition") {
checkAnswer(
testData.select('key).repartition(10).select('key),