aboutsummaryrefslogtreecommitdiff
path: root/core/src/test/scala/spark/ThreadingSuite.scala
blob: ff315b66935dc997631c9accfd8bbf2b8a1d5bb8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package spark

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter

import SparkContext._

/**
 * Global, thread-safe state that tasks launched by ThreadingSuite can read and update.
 *
 * It lives in a top-level object so that code running inside RDD operations refers to it
 * by name instead of capturing test-local variables.
 */
object ThreadingSuiteState {
  /** Count of tasks that have announced themselves as running (see "parallel job execution"). */
  val runningThreads: AtomicInteger = new AtomicInteger(0)

  /** Set to true by a task that gave up waiting for the expected number of running tasks. */
  val failed: AtomicBoolean = new AtomicBoolean(false)

  /** Restores the initial state: no running tasks and no recorded failure. */
  def clear(): Unit = {
    runningThreads.set(0)
    failed.set(false)
  }
}

/**
 * Checks that RDD operations can be issued against one SparkContext from threads other than
 * the test thread, and that jobs submitted from different threads can run at the same time.
 *
 * Each test starts plain `Thread`s and joins them with a `Semaphore`. Every worker releases its
 * permit in a `finally` block and hands any `Throwable` it hit back to the test thread, so a
 * failure inside a worker fails the test instead of leaving it blocked forever in `acquire`.
 *
 * NOTE(review): `sc` is not declared here; it is presumably provided (and cleaned up) by
 * LocalSparkContext -- confirm against that trait.
 */
class ThreadingSuite extends FunSuite with LocalSparkContext {

  test("accessing SparkContext from a different thread") {
    sc = new SparkContext("local", "test")
    val nums = sc.parallelize(1 to 10, 2)
    val sem = new Semaphore(0)
    @volatile var answer1: Int = 0
    @volatile var answer2: Int = 0
    // Failure raised on the worker thread, if any; rethrown on the test thread below.
    @volatile var throwable: Option[Throwable] = None
    new Thread {
      override def run(): Unit = {
        try {
          answer1 = nums.reduce(_ + _)
          answer2 = nums.first()    // This will run "locally" in the current thread
        } catch {
          // Not swallowed: recorded here and rethrown after the join.
          case t: Throwable => throwable = Some(t)
        } finally {
          // Always release, otherwise a failing worker would hang the test in acquire().
          sem.release()
        }
      }
    }.start()
    sem.acquire()
    throwable.foreach { t => throw t }
    assert(answer1 === 55)
    assert(answer2 === 1)
  }

  test("accessing SparkContext from multiple threads") {
    sc = new SparkContext("local", "test")
    val nums = sc.parallelize(1 to 10, 2)
    val sem = new Semaphore(0)
    @volatile var ok = true
    // One of the failures raised on a worker thread, if any; rethrown on the test thread below.
    @volatile var throwable: Option[Throwable] = None
    for (i <- 0 until 10) {
      new Thread {
        override def run(): Unit = {
          try {
            val answer1 = nums.reduce(_ + _)
            if (answer1 != 55) {
              printf("In thread %d: answer1 was %d\n", i, answer1)
              ok = false
            }
            val answer2 = nums.first()    // This will run "locally" in the current thread
            if (answer2 != 1) {
              printf("In thread %d: answer2 was %d\n", i, answer2)
              ok = false
            }
          } catch {
            // Not swallowed: recorded here and rethrown after the join.
            case t: Throwable => throwable = Some(t)
          } finally {
            // Always release, otherwise a failing worker would hang the test in acquire(10).
            sem.release()
          }
        }
      }.start()
    }
    sem.acquire(10)
    throwable.foreach { t => throw t }
    if (!ok) {
      fail("One or more threads got the wrong answer from an RDD operation")
    }
  }

  test("accessing multi-threaded SparkContext from multiple threads") {
    sc = new SparkContext("local[4]", "test")
    val nums = sc.parallelize(1 to 10, 2)
    val sem = new Semaphore(0)
    @volatile var ok = true
    // One of the failures raised on a worker thread, if any; rethrown on the test thread below.
    @volatile var throwable: Option[Throwable] = None
    for (i <- 0 until 10) {
      new Thread {
        override def run(): Unit = {
          try {
            val answer1 = nums.reduce(_ + _)
            if (answer1 != 55) {
              printf("In thread %d: answer1 was %d\n", i, answer1)
              ok = false
            }
            val answer2 = nums.first()    // This will run "locally" in the current thread
            if (answer2 != 1) {
              printf("In thread %d: answer2 was %d\n", i, answer2)
              ok = false
            }
          } catch {
            // Not swallowed: recorded here and rethrown after the join.
            case t: Throwable => throwable = Some(t)
          } finally {
            // Always release, otherwise a failing worker would hang the test in acquire(10).
            sem.release()
          }
        }
      }.start()
    }
    sem.acquire(10)
    throwable.foreach { t => throw t }
    if (!ok) {
      fail("One or more threads got the wrong answer from an RDD operation")
    }
  }

  test("parallel job execution") {
    // This test launches two jobs with two threads each on a 4-core local cluster. Each thread
    // waits until there are 4 threads running at once, to test that both jobs have been launched.
    sc = new SparkContext("local[4]", "test")
    val nums = sc.parallelize(1 to 2, 2)
    val sem = new Semaphore(0)
    // One of the failures raised on a job-submitting thread, if any; rethrown below.
    @volatile var throwable: Option[Throwable] = None
    ThreadingSuiteState.clear()
    for (i <- 0 until 2) {
      new Thread {
        override def run(): Unit = {
          try {
            val ans = nums.map(number => {
              val running = ThreadingSuiteState.runningThreads
              running.getAndIncrement()
              // Poll for up to one second for all four tasks to have checked in.
              val time = System.currentTimeMillis()
              while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
                Thread.sleep(100)
              }
              if (running.get() != 4) {
                println("Waited 1 second without seeing runningThreads = 4 (it was " +
                  running.get() + "); failing test")
                ThreadingSuiteState.failed.set(true)
              }
              number
            }).collect()
            assert(ans.toList === List(1, 2))
          } catch {
            // Covers both job failures and the assertion above; rethrown after the join.
            case t: Throwable => throwable = Some(t)
          } finally {
            // Always release, otherwise a failing worker would hang the test in acquire(2).
            sem.release()
          }
        }
      }.start()
    }
    sem.acquire(2)
    throwable.foreach { t => throw t }
    if (ThreadingSuiteState.failed.get()) {
      fail("One or more threads didn't see runningThreads = 4")
    }
  }
}