about summary refs log tree commit diff
path: root/sql/hive
diff options
context:
space:
mode:
Diffstat (limited to 'sql/hive')
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 104
1 file changed, 93 insertions(+), 11 deletions(-)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4b35c8fd83..7b5aa4763f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
+import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterAll
-import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
@@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
Nil)
}
+ test("null literal") {
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | AVG(null),
+ | COUNT(null),
+ | FIRST(null),
+ | LAST(null),
+ | MAX(null),
+ | MIN(null),
+ | SUM(null)
+ """.stripMargin),
+ Row(null, 0, null, null, null, null, null) :: Nil)
+ }
+
test("only do grouping") {
checkAnswer(
sqlContext.sql(
@@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
|SELECT avg(value) FROM agg1
""".stripMargin),
Row(11.125) :: Nil)
-
- checkAnswer(
- sqlContext.sql(
- """
- |SELECT avg(null)
- """.stripMargin),
- Row(null) :: Nil)
}
test("udaf") {
@@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
| max(distinct value1)
|FROM agg2
""".stripMargin),
- Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+ Row(-60, 70.0, 101.0/9.0, 5.6, 100))
checkAnswer(
sqlContext.sql(
@@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
Row(3, null, 3.0, null, null, null) ::
Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT
+ | count(value1),
+ | count(*),
+ | count(1),
+ | count(DISTINCT value1),
+ | key
+ |FROM agg2
+ |GROUP BY key
+ """.stripMargin),
+ Row(3, 3, 3, 2, 1) ::
+ Row(3, 4, 4, 2, 2) ::
+ Row(0, 2, 2, 0, 3) ::
+ Row(3, 4, 4, 3, null) :: Nil)
}
test("test count") {
@@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
|FROM agg1
|GROUP BY key
""".stripMargin).queryExecution.executedPlan.collect {
- case agg: aggregate.Aggregate => agg
+ case agg: aggregate.SortBasedAggregate => agg
+ case agg: aggregate.TungstenAggregate => agg
}
val message =
"We should fallback to the old aggregation code path if " +
@@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite {
sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
}
}
+
+class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
+
+ var originalUnsafeEnabled: Boolean = _
+
+ override def beforeAll(): Unit = {
+ originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true")
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
+ sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt")
+ }
+
+ override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ (0 to 2).foreach { fallbackStartsAt =>
+ sqlContext.setConf(
+ "spark.sql.TungstenAggregate.testFallbackStartsAt",
+ fallbackStartsAt.toString)
+
+ // Create a new df to make sure its physical operator picks up
+ // spark.sql.TungstenAggregate.testFallbackStartsAt.
+ val newActual = DataFrame(sqlContext, actual.logicalPlan)
+
+ QueryTest.checkAnswer(newActual, expectedAnswer) match {
+ case Some(errorMessage) =>
+ val newErrorMessage =
+ s"""
+ |The following aggregation query failed when using TungstenAggregate with
+ |controlled fallback (it falls back to sort-based aggregation once it has processed
+ |$fallbackStartsAt input rows). The query is
+ |${actual.queryExecution}
+ |
+ |$errorMessage
+ """.stripMargin
+
+ fail(newErrorMessage)
+ case None =>
+ }
+ }
+ }
+
+ // Override it to make sure we call the actually overridden checkAnswer.
+ override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(df, Seq(expectedAnswer))
+ }
+
+ // Override it to make sure we call the actually overridden checkAnswer.
+ override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+ checkAnswer(df, expectedAnswer.collect())
+ }
+}