author     Kan Zhang <kzhang@apache.org>    2014-05-25 00:06:42 -0700
committer  Reynold Xin <rxin@apache.org>    2014-05-25 00:06:42 -0700
commit     6052db9dc10c996215658485e805200e4f0cf549 (patch)
tree       c5575603da300b1dbdcce95f7bdff9457bc26094
parent     6e9fb6320bec3371bc9c010ccbc1b915f500486b (diff)
[SPARK-1822] SchemaRDD.count() should use query optimizer
Author: Kan Zhang <kzhang@apache.org>

Closes #841 from kanzhang/SPARK-1822 and squashes the following commits:

2f8072a [Kan Zhang] [SPARK-1822] Minor style update
cf4baa4 [Kan Zhang] [SPARK-1822] Adding Scaladoc
e67c910 [Kan Zhang] [SPARK-1822] SchemaRDD.count() should use optimizer
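As a rough sketch of the behaviour this patch targets: after the change, SchemaRDD.count() is planned as a COUNT(1) aggregate by the Catalyst optimizer instead of falling back to the generic RDD count over deserialized rows. The SparkContext `sc`, case class `Person`, and RDD `people` below are illustrative, not part of the patch.

    // Illustrative usage only, assuming a Spark 1.0-era SQLContext and an
    // existing SparkContext `sc`.
    import org.apache.spark.sql.SQLContext

    case class Person(name: String, age: Int)

    val sqlContext = new SQLContext(sc)
    import sqlContext.createSchemaRDD   // implicit RDD -> SchemaRDD conversion

    val people = sc.parallelize(Seq(Person("a", 1), Person("b", 2), Person("c", 3)))
    val schemaRDD = createSchemaRDD(people)

    // Computed via the optimized aggregate plan; returns Long after this patch.
    val n: Long = schemaRDD.count()     // 3L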
-rw-r--r--  python/pyspark/sql.py                                                                     14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala     6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala                               9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala                           9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                                2
5 files changed, 32 insertions, 8 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index bbe69e7d8f..f2001afae4 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -268,7 +268,7 @@ class SchemaRDD(RDD):
def _jrdd(self):
"""
Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the
- L{pyspark.rdd.RDD} super class (map, count, etc.).
+ L{pyspark.rdd.RDD} super class (map, filter, etc.).
"""
if not hasattr(self, '_lazy_jrdd'):
self._lazy_jrdd = self._toPython()._jrdd
@@ -321,6 +321,18 @@ class SchemaRDD(RDD):
"""
self._jschema_rdd.saveAsTable(tableName)
+ def count(self):
+ """
+ Return the number of elements in this RDD.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.count()
+ 3L
+ >>> srdd.count() == srdd.map(lambda x: x).count()
+ True
+ """
+ return self._jschema_rdd.count()
+
def _toPython(self):
# We have to import the Row class explicitly, so that the reference Pickler has is
# pyspark.sql.Row instead of __main__.Row
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 5dbaaa3b0c..1bcd4e2276 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -151,7 +151,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
- override def dataType = IntegerType
+ override def dataType = LongType
override def toString = s"COUNT($child)"
override def asPartial: SplitEvaluation = {
@@ -295,12 +295,12 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- var count: Int = _
+ var count: Long = _
override def update(input: Row): Unit = {
val evaluatedExpr = expr.map(_.eval(input))
if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
- count += 1
+ count += 1L
}
}
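The widening of the count accumulator from Int to Long matters once a relation has more than Int.MaxValue rows; a minimal sketch (not part of the patch) of the wrap-around an Int counter would suffer:

    // Illustrative only: an Int counter overflows past 2^31 - 1, whereas the
    // Long counter used by CountFunction after this patch keeps counting.
    var intCount: Int = Int.MaxValue
    intCount += 1                        // wraps to -2147483648

    var longCount: Long = Int.MaxValue.toLong
    longCount += 1L                      // 2147483648L, as expected for COUNT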
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 2569815ebb..452da3d023 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -276,6 +276,15 @@ class SchemaRDD(
/**
* :: Experimental ::
+ * Overriding base RDD implementation to leverage query optimizer
+ */
+ @Experimental
+ override def count(): Long = {
+ groupBy()(Count(Literal(1))).collect().head.getLong(0)
+ }
+
+ /**
+ * :: Experimental ::
* Applies the given Generator, or table generating function, to this relation.
*
* @param generator A table generating function. The API for such functions is likely to change
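The new override is exactly the ungrouped aggregation a user could already write with the DSL by hand. Expressed from user code (a sketch, assuming the `schemaRDD` from the earlier example), it reads:

    // Equivalent hand-written form of the new count(): an ungrouped COUNT(1)
    // aggregate planned by Catalyst; the single result row holds the Long count.
    import org.apache.spark.sql.catalyst.expressions.{Count, Literal}

    val manualCount: Long =
      schemaRDD.groupBy()(Count(Literal(1))).collect().head.getLong(0)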
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index f43e98d614..233132a2fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -108,10 +108,7 @@ class DslQuerySuite extends QueryTest {
}
test("count") {
- checkAnswer(
- testData2.groupBy()(Count(1)),
- testData2.count()
- )
+ assert(testData2.count() === testData2.map(_ => 1).count())
}
test("null count") {
@@ -126,6 +123,10 @@ class DslQuerySuite extends QueryTest {
)
}
+ test("zero count") {
+ assert(testData4.count() === 0)
+ }
+
test("inner join where, one match per row") {
checkAnswer(
upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 1aca387252..b1eecb4dd3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,6 +47,8 @@ object TestData {
(1, null) ::
(2, 2) :: Nil)
+ val testData4 = logical.LocalRelation('a.int, 'b.int)
+
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
TestSQLContext.sparkContext.parallelize(