diff options
author | Kan Zhang <kzhang@apache.org> | 2014-05-25 00:06:42 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-05-25 00:06:42 -0700 |
commit | 6052db9dc10c996215658485e805200e4f0cf549 (patch) | |
tree | c5575603da300b1dbdcce95f7bdff9457bc26094 /sql/core | |
parent | 6e9fb6320bec3371bc9c010ccbc1b915f500486b (diff) | |
download | spark-6052db9dc10c996215658485e805200e4f0cf549.tar.gz spark-6052db9dc10c996215658485e805200e4f0cf549.tar.bz2 spark-6052db9dc10c996215658485e805200e4f0cf549.zip |
[SPARK-1822] SchemaRDD.count() should use query optimizer
Author: Kan Zhang <kzhang@apache.org>
Closes #841 from kanzhang/SPARK-1822 and squashes the following commits:
2f8072a [Kan Zhang] [SPARK-1822] Minor style update
cf4baa4 [Kan Zhang] [SPARK-1822] Adding Scaladoc
e67c910 [Kan Zhang] [SPARK-1822] SchemaRDD.count() should use optimizer
Diffstat (limited to 'sql/core')
3 files changed, 16 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 2569815ebb..452da3d023 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -276,6 +276,15 @@ class SchemaRDD( /** * :: Experimental :: + * Overriding base RDD implementation to leverage query optimizer + */ + @Experimental + override def count(): Long = { + groupBy()(Count(Literal(1))).collect().head.getLong(0) + } + + /** + * :: Experimental :: * Applies the given Generator, or table generating function, to this relation. * * @param generator A table generating function. The API for such functions is likely to change diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index f43e98d614..233132a2fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -108,10 +108,7 @@ class DslQuerySuite extends QueryTest { } test("count") { - checkAnswer( - testData2.groupBy()(Count(1)), - testData2.count() - ) + assert(testData2.count() === testData2.map(_ => 1).count()) } test("null count") { @@ -126,6 +123,10 @@ class DslQuerySuite extends QueryTest { ) } + test("zero count") { + assert(testData4.count() === 0) + } + test("inner join where, one match per row") { checkAnswer( upperCaseData.join(lowerCaseData, Inner).where('n === 'N), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 1aca387252..b1eecb4dd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -47,6 +47,8 @@ object TestData { (1, null) :: (2, 2) :: Nil) + val testData4 = logical.LocalRelation('a.int, 'b.int) + case class UpperCaseData(N: Int, L: String) val upperCaseData = TestSQLContext.sparkContext.parallelize( |