aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/spark/DfsShuffle.scala
blob: 7a42bf2d06f624b40ad7d9c17ba01e9849524772 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package spark

import java.io.{EOFException, ObjectInputStream, ObjectOutputStream}
import java.net.URI
import java.util.UUID

import scala.collection.mutable.HashMap

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}


/**
 * A simple implementation of shuffle using a distributed file system.
 *
 * Map side: every input split is combined locally into `numOutputSplits`
 * hash buckets. Each bucket is written as a stream of Java-serialized `(K, C)`
 * pairs to a file named "<inputSplit>-to-<outputSplit>" inside a fresh
 * directory obtained from `DfsShuffle.newTempDirectory()`.
 *
 * Reduce side: for each output split, the files addressed to it from every
 * input split are read back and their combiners are merged per key.
 *
 * TODO: Add support for compression when spark.compress is set to true.
 * NOTE(review): nothing in this class deletes the intermediate directory.
 */
@serializable
class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
  /**
   * Shuffles `input` into `numOutputSplits` partitions, combining values by key.
   *
   * @param input           the key-value RDD to shuffle
   * @param numOutputSplits number of partitions of the result; also the number
   *                        of hash buckets each input split is divided into
   * @param createCombiner  builds a combiner from the first value seen for a
   *                        key within one input split
   * @param mergeValue      folds a further value into an existing combiner
   *                        (map side, within one input split)
   * @param mergeCombiners  merges two combiners for the same key coming from
   *                        different input splits (reduce side)
   * @return an RDD whose i-th partition holds the merged combiners of all keys
   *         hashed to bucket i
   */
  override def compute(input: RDD[(K, V)],
                       numOutputSplits: Int,
                       createCombiner: V => C,
                       mergeValue: (C, V) => C,
                       mergeCombiners: (C, C) => C)
  : RDD[(K, C)] =
  {
    val sc = input.sparkContext
    val dir = DfsShuffle.newTempDirectory()
    logInfo("Intermediate data directory: " + dir)

    val numberedSplitRdd = new NumberedSplitRDD(input)
    val numInputSplits = numberedSplitRdd.splits.size

    // Map side: run a parallel foreach that combines each input split into
    // per-output-split buckets and writes one intermediate file per bucket.
    numberedSplitRdd.foreach((pair: (Int, Iterator[(K, V)])) => {
      val myIndex = pair._1
      val myIterator = pair._2
      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
      for ((k, v) <- myIterator) {
        var bucketId = k.hashCode % numOutputSplits
        if (bucketId < 0) { // Fix bucket ID if hash code was negative
          bucketId += numOutputSplits
        }
        val bucket = buckets(bucketId)
        bucket(k) = bucket.get(k) match {
          case Some(c) => mergeValue(c, v)
          case None => createCombiner(v)
        }
      }
      val fs = DfsShuffle.getFileSystem()
      for (i <- 0 until numOutputSplits) {
        val path = new Path(dir, "%d-to-%d".format(myIndex, i))
        val out = new ObjectOutputStream(fs.create(path, true))
        // Close in a finally block so a failed write cannot leak the stream.
        try {
          buckets(i).foreach(pair => out.writeObject(pair))
        } finally {
          out.close()
        }
      }
    })

    // Reduce side: return an RDD that performs the merge for each output split.
    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
    indexes.flatMap((myIndex: Int) => {
      val combiners = new HashMap[K, C]
      val fs = DfsShuffle.getFileSystem()
      // Read order comes from Utils.shuffle (presumably randomized to spread
      // load across readers; confirm).
      for (i <- Utils.shuffle(0 until numInputSplits)) {
        val path = new Path(dir, "%d-to-%d".format(i, myIndex))
        val inputStream = new ObjectInputStream(fs.open(path))
        // Close in a finally block so a deserialization failure other than
        // end-of-stream cannot leak the stream.
        try {
          // The files carry no record count, so read until end of stream.
          try {
            while (true) {
              val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
              combiners(k) = combiners.get(k) match {
                case Some(oldC) => mergeCombiners(oldC, c)
                case None => c
              }
            }
          } catch {
            case e: EOFException => {}
          }
        } finally {
          inputStream.close()
        }
      }
      combiners
    })
  }
}


/**
 * Companion object of DfsShuffle.
 *
 * It lazily builds one shared Hadoop FileSystem, configured from the
 * spark.dfs and spark.buffer.size system properties. It also creates
 * uniquely named temporary directories that hold shuffle output.
 */
object DfsShuffle {
  /**
   * The shared FileSystem, built on first access.
   *
   * Configuration applied:
   *  - io.file.buffer.size is taken from spark.buffer.size (default "65536").
   *  - dfs.replication is forced to 1.
   *  - The file system URI is taken from spark.dfs (default "file:///").
   *
   * Initialization runs under this object's monitor. If it throws, the next
   * access retries it.
   */
  private lazy val fileSystem: FileSystem = {
    val ioBufferBytes = System.getProperty("spark.buffer.size", "65536").toInt
    val fsUri = System.getProperty("spark.dfs", "file:///")
    val hadoopConf = new Configuration()
    hadoopConf.setInt("io.file.buffer.size", ioBufferBytes)
    hadoopConf.setInt("dfs.replication", 1)
    FileSystem.get(new URI(fsUri), hadoopConf)
  }

  /** Returns the shared FileSystem, initializing it on the first call. */
  def getFileSystem(): FileSystem = fileSystem

  /**
   * Creates and returns a fresh directory named "shuffle-<random UUID>".
   *
   * The directory is placed under spark.dfs.workdir (default "/tmp") on the
   * shared FileSystem.
   *
   * NOTE(review): the Boolean result of mkdirs is ignored, so the path is
   * returned even if creation reported failure.
   */
  def newTempDirectory(): String = {
    val targetFs = getFileSystem()
    val dirName =
      System.getProperty("spark.dfs.workdir", "/tmp") + "/shuffle-" + UUID.randomUUID()
    targetFs.mkdirs(new Path(dirName))
    dirName
  }
}