aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala48
3 files changed, 70 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index c069da016f..ecde9c5713 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -266,7 +266,18 @@ case class GeneratedAggregate(
val joinedRow = new JoinedRow3
- if (groupingExpressions.isEmpty) {
+ if (!iter.hasNext) {
+ // This is an empty input, so return early so that we do not allocate data structures
+ // that won't be cleaned up (see SPARK-8357).
+ if (groupingExpressions.isEmpty) {
+ // This is a global aggregate, so return an empty aggregation buffer.
+ val resultProjection = resultProjectionBuilder()
+ Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
+ } else {
+ // This is a grouped aggregate, so return an empty iterator.
+ Iterator[InternalRow]()
+ }
+ } else if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
var currentRow: InternalRow = null
@@ -280,6 +291,7 @@ case class GeneratedAggregate(
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
} else if (unsafeEnabled) {
+ assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 61d5f2061a..beee10173f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -648,6 +648,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(2, 1, 2, 2, 1))
}
+ test("count of empty table") {
+ withTempTable("t") {
+ Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t")
+ checkAnswer(
+ sql("select count(a) from t"),
+ Row(0))
+ }
+ }
+
test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
new file mode 100644
index 0000000000..20def6bef0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+
+class AggregateSuite extends SparkPlanTest {
+
+ test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
+ val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
+ val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
+ try {
+ TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+ TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+ val df = Seq.empty[(Int, Int)].toDF("a", "b")
+ checkAnswer(
+ df,
+ GeneratedAggregate(
+ partial = true,
+ Seq(df.col("b").expr),
+ Seq(Alias(Count(df.col("a").expr), "cnt")()),
+ unsafeEnabled = true,
+ _: SparkPlan),
+ Seq.empty
+ )
+ } finally {
+ TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
+ TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
+ }
+ }
+}