aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-04-15 13:15:58 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-15 13:15:58 -0700
commitcf38fe04f8782ff4573ae106ec0de8e8d183cb2b (patch)
treef8637a5e836799c45ea2b719973ad0696edb30d0 /sql
parent85842760dc4616577162f44cc0fa9db9bd23bd9c (diff)
downloadspark-cf38fe04f8782ff4573ae106ec0de8e8d183cb2b.tar.gz
spark-cf38fe04f8782ff4573ae106ec0de8e8d183cb2b.tar.bz2
spark-cf38fe04f8782ff4573ae106ec0de8e8d183cb2b.zip
[SPARK-6844][SQL] Clean up accumulators used in InMemoryRelation when it is uncached
JIRA: https://issues.apache.org/jira/browse/SPARK-6844 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #5475 from viirya/cache_memory_leak and squashes the following commits: 0b41235 [Liang-Chi Hsieh] fix style. dc1d5d5 [Liang-Chi Hsieh] For comments. 78af229 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cache_memory_leak 26c9bb6 [Liang-Chi Hsieh] Add configuration to enable in-memory table scan accumulators. 1c3b06e [Liang-Chi Hsieh] Clean up accumulators used in InMemoryRelation when it is uncached.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala47
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala2
4 files changed, 55 insertions, 14 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index ca4a127120..18584c2dcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
require(dataIndex >= 0, s"Table $query is not cached.")
- cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ cachedData(dataIndex).cachedRepresentation.uncache(blocking)
cachedData.remove(dataIndex)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 6eee0c86d6..d9b6fb43ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import org.apache.spark.Accumulator
+import org.apache.spark.{Accumulable, Accumulator, Accumulators}
import org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
+import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
@@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation(
child: SparkPlan,
tableName: Option[String])(
private var _cachedColumnBuffers: RDD[CachedBatch] = null,
- private var _statistics: Statistics = null)
+ private var _statistics: Statistics = null,
+ private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null)
extends LogicalPlan with MultiInstanceRelation {
- private val batchStats =
- child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+ private val batchStats: Accumulable[ArrayBuffer[Row], Row] =
+ if (_batchStats == null) {
+ child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row])
+ } else {
+ _batchStats
+ }
val partitionStatistics = new PartitionStatistics(output)
@@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation(
def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
InMemoryRelation(
newOutput, useCompression, batchSize, storageLevel, child, tableName)(
- _cachedColumnBuffers, statisticsToBePropagated)
+ _cachedColumnBuffers, statisticsToBePropagated, batchStats)
}
override def children: Seq[LogicalPlan] = Seq.empty
@@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation(
child,
tableName)(
_cachedColumnBuffers,
- statisticsToBePropagated).asInstanceOf[this.type]
+ statisticsToBePropagated,
+ batchStats).asInstanceOf[this.type]
}
def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
override protected def otherCopyArgs: Seq[AnyRef] =
- Seq(_cachedColumnBuffers, statisticsToBePropagated)
+ Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
+
+ private[sql] def uncache(blocking: Boolean): Unit = {
+ Accumulators.remove(batchStats.id)
+ cachedColumnBuffers.unpersist(blocking)
+ _cachedColumnBuffers = null
+ }
}
private[sql] case class InMemoryColumnarTableScan(
@@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan(
}
}
+ lazy val enableAccumulators: Boolean =
+ sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
+
// Accumulators used for testing purposes
- val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
- val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+ lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
+ lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
override def execute(): RDD[Row] = {
- readPartitions.setValue(0)
- readBatches.setValue(0)
+ if (enableAccumulators) {
+ readPartitions.setValue(0)
+ readBatches.setValue(0)
+ }
relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
val partitionFilter = newPredicate(
@@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan(
}
}
- if (rows.hasNext) {
+ if (rows.hasNext && enableAccumulators) {
readPartitions += 1
}
@@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan(
logInfo(s"Skipping partition based on stats $statsString")
false
} else {
- readBatches += 1
+ if (enableAccumulators) {
+ readBatches += 1
+ }
true
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index f7b5f08beb..01e3b86710 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps}
import org.scalatest.concurrent.Eventually._
+import org.apache.spark.Accumulators
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.test.TestSQLContext._
@@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest {
sql("Clear CACHE")
assert(cacheManager.isEmpty)
}
+
+ test("Clear accumulators when uncacheTable to prevent memory leaking") {
+ val accsSize = Accumulators.originals.size
+
+ sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
+ sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
+ cacheTable("t1")
+ cacheTable("t2")
+ sql("SELECT * FROM t1").count()
+ sql("SELECT * FROM t2").count()
+ sql("SELECT * FROM t1").count()
+ sql("SELECT * FROM t2").count()
+ uncacheTable("t1")
+ uncacheTable("t2")
+
+ assert(accsSize >= Accumulators.originals.size)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index e57bb06e72..2a0b701cad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
// Enable in-memory partition pruning
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
+ // Enable in-memory table scan accumulators
+ setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {