aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
blob: b5385c11a926e2132002ede25c6590b9342bbe58 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.scheduler

import java.io.File
import java.net.URL
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.control.NonFatal

import com.google.common.util.concurrent.MoreExecutors
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, anyLong}
import org.mockito.Mockito.{spy, times, verify}
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark._
import org.apache.spark.storage.TaskResultBlockId
import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}


/**
 * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
 *
 * Used to test the case where a BlockManager evicts the task result (or dies) before the
 * TaskResult is retrieved. Only the first result received is removed, so that a retried
 * attempt of the task can eventually succeed.
 */
private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
  extends TaskResultGetter(sparkEnv, scheduler) {

  // Set after the first result has been intercepted; later results are passed straight through.
  var removedResult = false

  // Read by the test thread after the job completes, hence @volatile. True only if the removed
  // block was observed to be gone from the BlockManagerMaster within the timeout below.
  @volatile var removeBlockSuccessfully = false

  override def enqueueSuccessfulTask(
    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer): Unit = {
    if (!removedResult) {
      // Only remove the result once, since we'd like to test the case where the task eventually
      // succeeds. Deserialize from an independent view of the buffer (as MyTaskResultGetter
      // does) so the position handed to the superclass is left exactly as received, instead of
      // rewinding afterwards, which would assume the position started at 0.
      serializer.get().deserialize[TaskResult[_]](serializedData.duplicate()) match {
        case IndirectTaskResult(blockId, _) =>
          sparkEnv.blockManager.master.removeBlock(blockId)
          // removeBlock is asynchronous, so wait until the master no longer reports the block.
          try {
            eventually(timeout(3.seconds), interval(200.milliseconds)) {
              assert(!sparkEnv.blockManager.master.contains(blockId))
            }
            removeBlockSuccessfully = true
          } catch {
            // The block never disappeared in time; record it so the test can assert on it.
            case NonFatal(_) => removeBlockSuccessfully = false
          }
        case _: DirectTaskResult[_] =>
          taskSetManager.abort("Internal error: expect only indirect results")
      }
      removedResult = true
    }
    super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
  }
}


/**
 * A [[TaskResultGetter]] that records every [[DirectTaskResult]] it receives from executors,
 * capturing each one _before_ the superclass modifies it in any way.
 */
private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl)
  extends TaskResultGetter(env, scheduler) {

  // Fetch results on the calling thread so the recorded results can be read synchronously.
  protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor()

  // Every DirectTaskResult received from the executors, in arrival order.
  private val captured = new ArrayBuffer[DirectTaskResult[_]]

  /** The DirectTaskResults received so far, as they looked before any driver-side changes. */
  def taskResults: Seq[DirectTaskResult[_]] = captured

  override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = {
    val closureSer = env.closureSerializer.newInstance()
    // Read from a duplicate view: the superclass still needs `data` with its position intact.
    val snapshot = closureSer.deserialize[DirectTaskResult[_]](data.duplicate())
    captured.append(snapshot)
    super.enqueueSuccessfulTask(tsm, tid, data)
  }
}


/**
 * Tests related to handling task results (both direct and indirect).
 */
class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  // Set the RPC message size to be as small as possible (it must be an integer, so 1 is as small
  // as we can make it) so the tests don't take too long.
  def conf: SparkConf = new SparkConf().set("spark.rpc.message.maxSize", "1")

  test("handling results smaller than max RPC message size") {
    sc = new SparkContext("local", "test", conf)
    val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
    assert(result === 2)
  }

  test("handling results larger than max RPC message size") {
    sc = new SparkContext("local", "test", conf)
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    val result =
      sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x)
    assert(result === 1.to(maxRpcMessageSize).toArray)

    // The only task in this job is expected to have id 0; its result block must be gone once
    // the driver has fetched it.
    val resultBlockId = TaskResultBlockId(0)
    assert(sc.env.blockManager.master.getLocations(resultBlockId).size === 0,
      "Expect result to be removed from the block manager.")
  }

  test("task retried if result missing from block manager") {
    // The second number in "local[1,2]" is the maximum number of task failures. It must be > 1
    // so that the task set isn't aborted after the first attempt's result goes missing.
    sc = new SparkContext("local[1,2]", "test", conf)
    // If this test hangs, it's probably because no resource offers were made after the task
    // failed.
    val scheduler: TaskSchedulerImpl = sc.taskScheduler match {
      case taskScheduler: TaskSchedulerImpl => taskScheduler
      case _ => fail("Expect local cluster to use TaskSchedulerImpl")
    }
    val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
    scheduler.taskResultGetter = resultGetter
    val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
    val result =
      sc.parallelize(Seq(1), 1).map(x => 1.to(maxRpcMessageSize).toArray).reduce((x, y) => x)
    assert(resultGetter.removeBlockSuccessfully)
    assert(result === 1.to(maxRpcMessageSize).toArray)

    // Make sure two tasks were run (one failed one, and a second retried one).
    assert(scheduler.nextTaskId.get() === 2)
  }

  /**
   * Make sure we are using the context classloader when deserializing failed TaskResults instead
   * of the Spark classloader.
   *
   * This test compiles a jar containing an exception and tests that when it is thrown on the
   * executor, enqueueFailedTask can correctly deserialize the failure and identify the thrown
   * exception as the cause.
   *
   * Before this fix, enqueueFailedTask would throw a ClassNotFoundException when deserializing
   * the exception, resulting in an UnknownReason for the TaskEndResult.
   */
  test("failed task deserialized with the correct classloader (SPARK-11195)") {
    // compile a small jar containing an exception that will be thrown on an executor.
    val tempDir = Utils.createTempDir()
    val srcDir = new File(tempDir, "repro/")
    srcDir.mkdirs()
    val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath,
      """package repro;
        |
        |public class MyException extends Exception {
        |}
      """.stripMargin)
    val excFile = TestUtils.createCompiledClass("MyException", srcDir, excSource, Seq.empty)
    val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
    TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("repro"))

    // ensure we reset the classloader after the test completes
    val originalClassLoader = Thread.currentThread.getContextClassLoader
    try {
      // load the exception from the jar
      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
      loader.addURL(jarFile.toURI.toURL)
      Thread.currentThread().setContextClassLoader(loader)
      val excClass: Class[_] = Utils.classForName("repro.MyException")

      // NOTE: we must run the cluster with "local" so that the executor can load the compiled
      // jar.
      sc = new SparkContext("local", "test", conf)
      val rdd = sc.parallelize(Seq(1), 1).map { _ =>
        // Class.newInstance() is deprecated (Java 9+); go through the no-arg constructor.
        val exc = excClass.getConstructor().newInstance().asInstanceOf[Exception]
        throw exc
      }

      // the driver should not have any problems resolving the exception class and determining
      // why the task failed.
      val exceptionMessage = intercept[SparkException] {
        rdd.collect()
      }.getMessage

      val expectedFailure = """(?s).*Lost task.*: repro.MyException.*""".r
      val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r

      assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined)
      assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty)
    } finally {
      Thread.currentThread.setContextClassLoader(originalClassLoader)
    }
  }

  test("task result size is set on the driver, not the executors") {
    import InternalAccumulator._

    // Set up custom TaskResultGetter and TaskSchedulerImpl spy. The getter is installed on the
    // original scheduler but reports results to the spy, so calls to handleSuccessfulTask can be
    // verified below.
    sc = new SparkContext("local", "test", conf)
    val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
    val spyScheduler = spy(scheduler)
    val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler)
    val newDAGScheduler = new DAGScheduler(sc, spyScheduler)
    scheduler.taskResultGetter = resultGetter
    sc.dagScheduler = newDAGScheduler
    sc.taskScheduler = spyScheduler
    sc.taskScheduler.setDAGScheduler(newDAGScheduler)

    // Just run 1 task and capture the corresponding DirectTaskResult
    sc.parallelize(1 to 1, 1).count()
    val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]])
    verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture())

    // When a task finishes, the executor sends a serialized DirectTaskResult to the driver
    // without setting the result size so as to avoid serializing the result again. Instead,
    // the result size is set later in TaskResultGetter on the driver before passing the
    // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult
    // before and after the result size is set.
    assert(resultGetter.taskResults.size === 1)
    val resBefore = resultGetter.taskResults.head
    val resAfter = captor.getValue
    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
    assert(resSizeBefore.exists(_ == 0L))
    assert(resSizeAfter.exists(_.toString.toLong > 0L))
  }

}