path: root/sql/core
author     Ximo Guanter Gonzalbez <ximo@tid.es>          2014-07-02 10:03:44 -0700
committer  Michael Armbrust <michael@databricks.com>     2014-07-02 10:03:44 -0700
commit     5c6ec94da1bacd8e65a43acb92b6721493484e7b (patch)
tree       fda1ef7aff92578679cb344ded16bc40930efe32 /sql/core
parent     6596392da0fc0fee89e22adfca239a3477dfcbab (diff)
SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG
**Description** This patch enables using the `.select()` function in SchemaRDD with functions such as `Sum`, `Count`, and others.

**Testing** Unit tests added.

Author: Ximo Guanter Gonzalbez <ximo@tid.es>

Closes #1211 from edrevo/add-expression-support-in-select and squashes the following commits:

fe4a1e1 [Ximo Guanter Gonzalbez] Extend SQL DSL to functions
e1d344a [Ximo Guanter Gonzalbez] SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG
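For context, a minimal sketch of the DSL calls this change enables, lifted from the test suite below; `sqlContext`, `testData`, and `testData2` are assumed to exist as in that suite, so treat the snippet as illustrative rather than standalone:

```scala
import sqlContext._  // assumed: Symbol-to-column implicits plus the sum/avg/count DSL helpers

// Before this patch, select() accepted only NamedExpressions (plain column references).
// It now accepts arbitrary Expressions, including aggregates and derived values:
testData.select(sum('value), avg('value), count(1))  // aggregates over the whole RDD
testData2.select('a + 'b, 'a < 'b)                   // unnamed, derived expressions
```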
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala       9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala   32
2 files changed, 33 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 7c0efb4566..8f9f54f610 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -133,8 +133,13 @@ class SchemaRDD(
*
* @group Query
*/
- def select(exprs: NamedExpression*): SchemaRDD =
- new SchemaRDD(sqlContext, Project(exprs, logicalPlan))
+ def select(exprs: Expression*): SchemaRDD = {
+ val aliases = exprs.zipWithIndex.map {
+ case (ne: NamedExpression, _) => ne
+ case (e, i) => Alias(e, s"c$i")()
+ }
+ new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
+ }
/**
* Filters the output, only returning those rows where `condition` evaluates to true.
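The new `select` keeps expressions that are already named and wraps everything else in an `Alias` whose name comes from its position in the select list (the `s"c$i"` above), so the projected schema always has column names. An illustrative sketch, reusing `testData2` from the test suite and assuming the same implicits as the earlier snippet:

```scala
// 'a is already a NamedExpression and keeps its name.
// 'a + 'b is unnamed, so the new select wraps it as Alias('a + 'b, "c1")()
// because it sits at index 1 of the select list.
testData2.select('a, 'a + 'b)  // resulting column names: a, c1
```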
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index e4a64a7a48..04ac008682 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -60,6 +60,26 @@ class DslQuerySuite extends QueryTest {
Seq(Seq("1")))
}
+ test("select with functions") {
+ checkAnswer(
+ testData.select(sum('value), avg('value), count(1)),
+ Seq(Seq(5050.0, 50.5, 100)))
+
+ checkAnswer(
+ testData2.select('a + 'b, 'a < 'b),
+ Seq(
+ Seq(2, false),
+ Seq(3, true),
+ Seq(3, false),
+ Seq(4, false),
+ Seq(4, false),
+ Seq(5, false)))
+
+ checkAnswer(
+ testData2.select(sumDistinct('a)),
+ Seq(Seq(6)))
+ }
+
test("sorting") {
checkAnswer(
testData2.orderBy('a.asc, 'b.asc),
@@ -110,17 +130,17 @@ class DslQuerySuite extends QueryTest {
test("average") {
checkAnswer(
- testData2.groupBy()(Average('a)),
+ testData2.groupBy()(avg('a)),
2.0)
}
test("null average") {
checkAnswer(
- testData3.groupBy()(Average('b)),
+ testData3.groupBy()(avg('b)),
2.0)
checkAnswer(
- testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
+ testData3.groupBy()(avg('b), countDistinct('b)),
(2.0, 1) :: Nil)
}
@@ -130,17 +150,17 @@ class DslQuerySuite extends QueryTest {
test("null count") {
checkAnswer(
- testData3.groupBy('a)('a, Count('b)),
+ testData3.groupBy('a)('a, count('b)),
Seq((1,0), (2, 1))
)
checkAnswer(
- testData3.groupBy('a)('a, Count('a + 'b)),
+ testData3.groupBy('a)('a, count('a + 'b)),
Seq((1,0), (2, 1))
)
checkAnswer(
- testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
+ testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
(2, 1, 2, 2, 1) :: Nil
)
}