Diffstat (limited to 'core')
-rw-r--r--  core/pom.xml                                                           5
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala                   12
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala                 6
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala             2
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala          19
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala     22
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala             16
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java                  2
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala           8
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala           2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala  2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala  6
12 files changed, 87 insertions, 15 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 459ef66712..2dfb00d7ec 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -96,6 +96,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-unsafe_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 959aefabd8..0c4d28f786 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
import org.apache.spark.util.{RpcUtils, Utils}
/**
@@ -69,6 +70,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
+ val executorMemoryManager: ExecutorMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
@@ -382,6 +384,15 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
+ val executorMemoryManager: ExecutorMemoryManager = {
+ val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
+ MemoryAllocator.UNSAFE
+ } else {
+ MemoryAllocator.HEAP
+ }
+ new ExecutorMemoryManager(allocator)
+ }
+
val envInstance = new SparkEnv(
executorId,
rpcEnv,
@@ -398,6 +409,7 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
+ executorMemoryManager,
outputCommitCoordinator,
conf)
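
The new executorMemoryManager is created once per SparkEnv and shared by every task on that executor; each task then gets its own TaskMemoryManager layered on top of it (see the Executor.run() hunk below). The allocator choice hinges on a single boolean key. A minimal sketch of opting into off-heap allocation, assuming only the spark.unsafe.offHeap key added above and the standard SparkConf API:

    import org.apache.spark.SparkConf

    // Defaults to MemoryAllocator.HEAP when the key is absent or false.
    val conf = new SparkConf()
      .setAppName("unsafe-demo")            // hypothetical app name
      .set("spark.unsafe.offHeap", "true")  // selects MemoryAllocator.UNSAFE
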
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 7d7fe1a446..d09e17dea0 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
@@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable {
/** ::DeveloperApi:: */
@DeveloperApi
def taskMetrics(): TaskMetrics
+
+ /**
+ * Returns the manager for this task's managed memory.
+ */
+ private[spark] def taskMemoryManager(): TaskMemoryManager
}
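
The new accessor is private[spark], so user code never sees the manager; only Spark internals can reach it through the running task's context. A minimal sketch of that internal usage, assuming the existing TaskContext.get() accessor for the current context:

    // Compiles only inside the org.apache.spark package tree,
    // since taskMemoryManager() is private[spark].
    val memoryManager = TaskContext.get().taskMemoryManager()
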
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 337c8e4ebe..b4d572cb52 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
+ override val taskMemoryManager: TaskMemoryManager,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index f57e215c3f..dd1c48e6cb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
/**
@@ -178,6 +179,7 @@ private[spark] class Executor(
}
override def run(): Unit = {
+ val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
@@ -190,6 +192,7 @@ private[spark] class Executor(
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
@@ -206,7 +209,21 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ val value = try {
+ task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ } finally {
+ // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
+ // when changing this, make sure to update both copies.
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+ if (freedMemory > 0) {
+ val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
+ if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ throw new SparkException(errMsg)
+ } else {
+ logError(errMsg)
+ }
+ }
+ }
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
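
The try/finally wrapper turns every task run into a leak check: whatever the task failed to release is reclaimed, and the leaked size is either logged or escalated to a SparkException. A minimal sketch of the pattern in isolation, using only the TaskMemoryManager calls this diff exercises and assuming the surrounding Executor fields (env, conf, and Logging's logError); runTaskBody is a hypothetical stand-in for the real task.run(...):

    import org.apache.spark.SparkException
    import org.apache.spark.unsafe.memory.TaskMemoryManager

    val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
    try {
      runTaskBody()  // hypothetical: the actual work, e.g. task.run(...)
    } finally {
      // cleanUpAllAllocatedMemory() frees anything still allocated and
      // returns the number of leaked bytes.
      val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
      if (freedMemory > 0) {
        val errMsg = s"Managed memory leak detected; size = $freedMemory bytes"
        if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
          throw new SparkException(errMsg)  // fail fast, e.g. in test suites
        } else {
          logError(errMsg)  // default: report but keep the executor alive
        }
      }
    }
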
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 8c4bff4e83..b7901c06a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -643,8 +644,15 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
- attemptNumber = 0, runningLocally = true)
+ val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+ val taskContext =
+ new TaskContextImpl(
+ job.finalStage.id,
+ job.partitions(0),
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ runningLocally = true)
TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -652,6 +660,16 @@ class DAGScheduler(
} finally {
taskContext.markTaskCompleted()
TaskContext.unset()
+ // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
+ // make sure to update both copies.
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
+ if (freedMemory > 0) {
+ if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
+ throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
+ } else {
+ logError(s"Managed memory leak detected; size = $freedMemory bytes")
+ }
+ }
}
} catch {
case e: Exception =>
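
As the inline note warns, this cleanup block is a near byte-for-byte sibling of the one in Executor.run(). A hedged sketch of how the two call sites could share a single helper; the helper name and signature are hypothetical, not part of this diff:

    import org.apache.spark.{SparkConf, SparkException}
    import org.apache.spark.unsafe.memory.TaskMemoryManager

    // Hypothetical shared helper (would live in a class mixing in Logging).
    private def checkForLeakedMemory(
        taskMemoryManager: TaskMemoryManager,
        conf: SparkConf,
        taskId: Option[Long]): Unit = {
      val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
      if (freedMemory > 0) {
        val tid = taskId.map(id => s", TID = $id").getOrElse("")
        val errMsg = s"Managed memory leak detected; size = $freedMemory bytes$tid"
        if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
          throw new SparkException(errMsg)
        } else {
          logError(errMsg)
        }
      }
    }
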
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index b09b19e2ac..586d1e0620 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
@@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
* @return the result of the task
*/
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
- context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
- taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
+ context = new TaskContextImpl(
+ stageId = stageId,
+ partitionId = partitionId,
+ taskAttemptId = taskAttemptId,
+ attemptNumber = attemptNumber,
+ taskMemoryManager = taskMemoryManager,
+ runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
@@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
}
}
+ private var taskMemoryManager: TaskMemoryManager = _
+
+ def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
+ this.taskMemoryManager = taskMemoryManager
+ }
+
def runTask(context: TaskContext): T
def preferredLocations: Seq[TaskLocation] = Nil
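
Because tasks are serialized to executors, the manager cannot sensibly travel as a constructor argument; it is injected after deserialization and before run() builds the TaskContextImpl that exposes it. A sketch of that order, condensed from the Executor.run() hunk above (loader stands in for the executor's context class loader):

    val task = ser.deserialize[Task[Any]](taskBytes, loader)
    task.setTaskMemoryManager(new TaskMemoryManager(env.executorMemoryManager))
    // run() now threads the manager into the TaskContextImpl it creates.
    val result = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
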
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8a4f2a08fe..34ac9361d4 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1009,7 +1009,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
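
With taskMemoryManager now a constructor parameter, every test that builds a TaskContextImpl by hand must supply one. The test updates below all pass null, which is presumably safe only because these tests never invoke taskMemoryManager(); a sketch of the updated call shape they share:

    // null is tolerable here only because the test never calls
    // taskMemoryManager(); any code path that did would hit an NPE.
    val context = new TaskContextImpl(0, 0, 0, 0, null)
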
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 70529d9216..668ddf9f5f 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index aea76c1adc..85eb2a1d07 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0, 0, null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 057e226916..83ae870124 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0, null)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 37b593b2c5..2080c432d7 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0),
+ new TaskContextImpl(0, 0, 0, 0, null),
transfer,
blockManager,
blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,