From d66642e3978a76977414c2fdaedebaad35662667 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 25 May 2014 01:44:49 -0700
Subject: SPARK-1822: Some minor cleanup work on SchemaRDD.count()

Minor cleanup following #841.

Author: Reynold Xin

Closes #868 from rxin/schema-count and squashes the following commits:

5442651 [Reynold Xin] SPARK-1822: Some minor cleanup work on SchemaRDD.count()
---
 python/pyspark/sql.py                                            | 5 ++++-
 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala     | 8 ++++----
 sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 2 +-
 sql/core/src/test/scala/org/apache/spark/sql/TestData.scala      | 2 +-
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f2001afae4..fa4b9c7b68 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -323,7 +323,10 @@ class SchemaRDD(RDD):
 
     def count(self):
         """
-        Return the number of elements in this RDD.
+        Return the number of elements in this RDD. Unlike the base RDD
+        implementation of count, this implementation leverages the query
+        optimizer to compute the count on the SchemaRDD, which supports
+        features such as filter pushdown.
 
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.count()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 452da3d023..9883ebc0b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -276,12 +276,12 @@ class SchemaRDD(
 
   /**
    * :: Experimental ::
-   * Overriding base RDD implementation to leverage query optimizer
+   * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this
+   * implementation leverages the query optimizer to compute the count on the SchemaRDD, which
+   * supports features such as filter pushdown.
    */
   @Experimental
-  override def count(): Long = {
-    groupBy()(Count(Literal(1))).collect().head.getLong(0)
-  }
+  override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
 
   /**
    * :: Experimental ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 233132a2fe..94ba13b14b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -124,7 +124,7 @@ class DslQuerySuite extends QueryTest {
   }
 
   test("zero count") {
-    assert(testData4.count() === 0)
+    assert(emptyTableData.count() === 0)
   }
 
   test("inner join where, one match per row") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index b1eecb4dd3..944f520e43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,7 +47,7 @@ object TestData {
     (1, null) ::
     (2, 2) :: Nil)
 
-  val testData4 = logical.LocalRelation('a.int, 'b.int)
+  val emptyTableData = logical.LocalRelation('a.int, 'b.int)
 
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =
--
cgit v1.2.3
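
For context, a minimal sketch of how the overridden SchemaRDD.count() is exercised from user code. This assumes the Spark 1.0-era SQLContext.createSchemaRDD API; the Person case class, the local master setting, and the example data are hypothetical and only illustrate the call path, they are not part of this patch.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Hypothetical example schema; any RDD of case classes would do.
case class Person(name: String, age: Int)

object SchemaCountSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "SchemaCountSketch")
    val sqlContext = new SQLContext(sc)

    // Turn an RDD of case classes into a SchemaRDD backed by the Catalyst planner.
    val people = sqlContext.createSchemaRDD(
      sc.parallelize(Seq(Person("alice", 30), Person("bob", 25), Person("carol", 41))))

    // With this patch, count() on a SchemaRDD is the one-liner
    //   groupBy()(Count(Literal(1))).collect().head.getLong(0)
    // i.e. it is planned as an aggregate by the query optimizer (which supports
    // features such as filter pushdown) rather than running the base RDD's count().
    println(people.count()) // 3

    sc.stop()
  }
}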