about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql.py5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/TestData.scala2
4 files changed, 10 insertions, 7 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f2001afae4..fa4b9c7b68 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -323,7 +323,10 @@ class SchemaRDD(RDD):
def count(self):
"""
- Return the number of elements in this RDD.
+ Return the number of elements in this RDD. Unlike the base RDD
+ implementation of count, this implementation leverages the query
+ optimizer to compute the count on the SchemaRDD, which supports
+ features such as filter pushdown.
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.count()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 452da3d023..9883ebc0b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -276,12 +276,12 @@ class SchemaRDD(
/**
* :: Experimental ::
- * Overriding base RDD implementation to leverage query optimizer
+ * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this
+ * implementation leverages the query optimizer to compute the count on the SchemaRDD, which
+ * supports features such as filter pushdown.
*/
@Experimental
- override def count(): Long = {
- groupBy()(Count(Literal(1))).collect().head.getLong(0)
- }
+ override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
/**
* :: Experimental ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 233132a2fe..94ba13b14b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -124,7 +124,7 @@ class DslQuerySuite extends QueryTest {
}
test("zero count") {
- assert(testData4.count() === 0)
+ assert(emptyTableData.count() === 0)
}
test("inner join where, one match per row") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index b1eecb4dd3..944f520e43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,7 +47,7 @@ object TestData {
(1, null) ::
(2, 2) :: Nil)
- val testData4 = logical.LocalRelation('a.int, 'b.int)
+ val emptyTableData = logical.LocalRelation('a.int, 'b.int)
case class UpperCaseData(N: Int, L: String)
val upperCaseData =