aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
blob: d71069444a73f498eac40995755362f9afb4f9d9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.partial

import org.apache.spark._
import org.apache.spark.scheduler.JobListener
import org.apache.spark.rdd.RDD

/**
 * A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
 * This listener waits up to timeout milliseconds and will return a partial answer even if the
 * complete answer is not available by then.
 *
 * This class assumes that the action is performed on an entire RDD[T] via a function that computes
 * a result of type U for each partition, and that the action returns a partial or complete result
 * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt).
 *
 * Within this class, the mutable fields (`finishedTasks`, `failure`, `resultObject`) are only read
 * and written inside `synchronized` blocks on this listener; the same monitor is used for the
 * wait/notify handshake between the callbacks and `awaitResult()`.
 *
 * @tparam T element type of the RDD the action runs on
 * @tparam U per-partition result type that is merged into the evaluator
 * @tparam R type of the (partial or final) value produced by the evaluator
 * @param rdd       RDD whose number of partitions determines how many task results are expected
 * @param func      per-partition function; not invoked by this class itself
 * @param evaluator accumulates per-task results and produces the current estimate
 * @param timeout   maximum time to wait, in milliseconds, counted from construction of the listener
 */
private[spark] class ApproximateActionListener[T, U, R](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    evaluator: ApproximateEvaluator[U, R],
    timeout: Long)
  extends JobListener {

  /** Wall-clock time (ms) at which this listener was created; the timeout is measured from here. */
  val startTime: Long = System.currentTimeMillis()

  /** Number of task results expected, one per partition of `rdd`. */
  val totalTasks: Int = rdd.partitions.size

  /** Number of task results merged into the evaluator so far. */
  var finishedTasks: Int = 0

  /** Set if the job has failed (permanently). */
  var failure: Option[Exception] = None

  /** Set if we've already returned an incomplete PartialResult from `awaitResult()`. */
  var resultObject: Option[PartialResult[R]] = None

  /**
   * Merges one task's result into the evaluator. Once every expected task has reported, completes
   * any PartialResult that was already handed out and wakes up threads blocked in `awaitResult()`.
   *
   * @param index  index passed through to `evaluator.merge`
   * @param result the task's result; cast to U without a runtime check (U is erased)
   */
  override def taskSucceeded(index: Int, result: Any): Unit = {
    synchronized {
      evaluator.merge(index, result.asInstanceOf[U])
      finishedTasks += 1
      if (finishedTasks == totalTasks) {
        // If we had already returned a PartialResult, set its final value
        resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
        // Notify any waiting thread that may have called awaitResult
        this.notifyAll()
      }
    }
  }

  /**
   * Records a permanent job failure and wakes up threads blocked in `awaitResult()`, which will
   * then rethrow `exception`.
   *
   * NOTE(review): a PartialResult that was already returned (`resultObject`) is not informed of
   * the failure here. Whether PartialResult offers a way to signal that is not visible from this
   * file -- confirm and propagate if so.
   *
   * @param exception the cause of the failure
   */
  override def jobFailed(exception: Exception): Unit = {
    synchronized {
      failure = Some(exception)
      this.notifyAll()
    }
  }

  /**
   * Waits for up to timeout milliseconds since the listener was created and then returns a
   * PartialResult with the result so far. This may be complete if the whole job is done.
   *
   * @return a PartialResult built from `evaluator.currentResult()`, constructed with `true` if all
   *         tasks had finished and `false` if the deadline was reached first
   * @throws Exception the exception passed to `jobFailed`, if the job has failed
   */
  def awaitResult(): PartialResult[R] = synchronized {
    // Saturate instead of overflowing: with a very large timeout (e.g. Long.MaxValue),
    // startTime + timeout would wrap around to a negative deadline and we would return an
    // incomplete result immediately instead of waiting.
    val finishTime =
      if (timeout > Long.MaxValue - startTime) Long.MaxValue else startTime + timeout
    // Loop because wait() may return before the deadline (notifyAll from a callback, or a
    // spurious wakeup); every iteration re-checks the state from scratch.
    while (true) {
      val time = System.currentTimeMillis()
      failure.foreach { exception => throw exception }
      if (finishedTasks == totalTasks) {
        return new PartialResult(evaluator.currentResult(), true)
      } else if (time >= finishTime) {
        // Remember the object we hand out so that taskSucceeded can complete it later
        val partial = new PartialResult(evaluator.currentResult(), false)
        resultObject = Some(partial)
        return partial
      } else {
        // finishTime > time here, so the argument is strictly positive (0 would wait forever)
        this.wait(finishTime - time)
      }
    }
    // Should never be reached, but required to keep the compiler happy
    return null
  }
}