// core/src/main/scala/spark/storage/BlockStore.scala
package spark.storage

import java.io.{File, FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
import java.util.{LinkedHashMap, UUID}
import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap}

import scala.collection.mutable.ArrayBuffer

import it.unimi.dsi.fastutil.io.FastBufferedOutputStream

import spark.{Utils, Logging, Serializer, SizeEstimator}

/**
 * Abstract class that the BlockManager uses to store blocks, either in
 * memory or on disk.
 */
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
  initLogging()

  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) 

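  /**
   * Store an iterator of values. Returns the data back, either as an
   * iterator (when stored deserialized) or as a ByteBuffer (when stored
   * serialized), so the caller can reuse it without re-reading the block.
   */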
  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel)
  : Either[Iterator[Any], ByteBuffer]

  /**
   * Return the size of a block.
   */
  def getSize(blockId: String): Long

  def getBytes(blockId: String): Option[ByteBuffer]

  def getValues(blockId: String): Option[Iterator[Any]]

  def remove(blockId: String)

  def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)

  def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)

  def clear() { }
}
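
// A minimal sketch (not part of the original file) of how the BlockStore
// contract can be satisfied: a hypothetical no-op store that accepts and
// retains nothing. Illustrative only; BlockManager does not use it.
private[storage] class NoopBlockStore(blockManager: BlockManager)
  extends BlockStore(blockManager) {
  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { }
  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel)
  : Either[Iterator[Any], ByteBuffer] = Left(values)
  def getSize(blockId: String): Long = 0L
  def getBytes(blockId: String): Option[ByteBuffer] = None
  def getValues(blockId: String): Option[Iterator[Any]] = None
  def remove(blockId: String) { }
}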

/**
 * Class to store blocks in memory, either as ArrayBuffers of deserialized
 * Java objects or as serialized ByteBuffers.
 */
class MemoryStore(blockManager: BlockManager, maxMemory: Long) 
  extends BlockStore(blockManager) {

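  // An in-memory entry holds either an ArrayBuffer[Any] of deserialized
  // objects (deserialized = true) or a serialized ByteBuffer; dropPending
  // marks blocks already queued for eviction so they are not selected twice.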
  case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false)
  
  private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
  private var currentMemory = 0L
 
  private val blocksToDrop = new ArrayBlockingQueue[String](10000, true)
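  // Eviction is asynchronous: ensureFreeSpace only enqueues victim block ids
  // on blocksToDrop; this thread dequeues them and asks the BlockManager to
  // drop each block, running until clear() interrupts it.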
  private val blockDropper = new Thread("memory store - block dropper") {
    override def run() {
      try {
        while (true) {
          val blockId = blocksToDrop.take()
          logDebug("Block " + blockId + " ready to be dropped")
          blockManager.dropFromMemory(blockId)
        }
      } catch {
        case ie: InterruptedException => 
          logInfo("Shutting down block dropper")
      }
    }
  }
  blockDropper.start()
  logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory)))

  def freeMemory: Long = maxMemory - currentMemory

  def getSize(blockId: String): Long = memoryStore.synchronized { memoryStore.get(blockId).size }
  
  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
    if (level.deserialized) {
      bytes.rewind()
      val values = dataDeserialize(bytes)
      val elements = new ArrayBuffer[Any]
      elements ++= values
      val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
      ensureFreeSpace(sizeEstimate)
      val entry = new Entry(elements, sizeEstimate, true)
      memoryStore.synchronized { memoryStore.put(blockId, entry) }
      currentMemory += sizeEstimate
      logInfo("Block %s stored as values in memory (estimated size %d, free %d)".format(
        blockId, sizeEstimate, freeMemory))
    } else {
      ensureFreeSpace(bytes.limit)
      val entry = new Entry(bytes, bytes.limit, false)
      memoryStore.synchronized { memoryStore.put(blockId, entry) }
      currentMemory += bytes.limit
      logInfo("Block %s stored as %d bytes in memory (free %d)".format(
        blockId, bytes.limit, freeMemory))
    }
  }

  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel)
  : Either[Iterator[Any], ByteBuffer] = {
    if (level.deserialized) {
      val elements = new ArrayBuffer[Any]
      elements ++= values
      val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
      ensureFreeSpace(sizeEstimate)
      val entry = new Entry(elements, sizeEstimate, true)
      memoryStore.synchronized { memoryStore.put(blockId, entry) }
      currentMemory += sizeEstimate
      logInfo("Block %s stored as values in memory (estimated size %d, free %d)".format(
        blockId, sizeEstimate, freeMemory))
      return Left(elements.iterator) 
    } else {
      val bytes = dataSerialize(values)
      ensureFreeSpace(bytes.limit)
      val entry = new Entry(bytes, bytes.limit, false)
      memoryStore.synchronized { memoryStore.put(blockId, entry) } 
      currentMemory += bytes.limit
      logInfo("Block %s stored as %d bytes in memory (free %d)".format(
        blockId, bytes.limit, freeMemory))
      return Right(bytes)
    }
  }

  def getBytes(blockId: String): Option[ByteBuffer] = {
    throw new UnsupportedOperationException("Not implemented") 
  }

  def getValues(blockId: String): Option[Iterator[Any]] = {
    val entry = memoryStore.synchronized { memoryStore.get(blockId) }
    if (entry == null) {
      None
    } else if (entry.deserialized) {
      Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)
    } else {
      Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer].duplicate()))
    }
  }

  def remove(blockId: String) {
    memoryStore.synchronized {
      val entry = memoryStore.get(blockId) 
      if (entry != null) {
        memoryStore.remove(blockId)
        currentMemory -= entry.size
        logInfo("Block %s of size %d dropped from memory (free %d)".format(
          blockId, entry.size, freeMemory))
      } else {
        logWarning("Block " + blockId + " could not be removed as it doesn't exist")
      }
    }
  }

  override def clear() {
    memoryStore.synchronized {
      memoryStore.clear()
      currentMemory = 0  // reset the accounting along with the map
    }
    blockDropper.interrupt()
    logInfo("MemoryStore cleared")
  }

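  /**
   * Make room for `space` bytes under maxMemory by selecting blocks in
   * least-recently-used order and queueing them for the dropper thread.
   * Eviction happens asynchronously, so the space is not guaranteed to be
   * free by the time this method returns.
   */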
  private def ensureFreeSpace(space: Long) {
    logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
      space, currentMemory, maxMemory))
    
    if (maxMemory - currentMemory < space) {

      val selectedBlocks = new ArrayBuffer[String]()
      var selectedMemory = 0L
      
      memoryStore.synchronized {
        val iter = memoryStore.entrySet().iterator()
        while (maxMemory - (currentMemory - selectedMemory) < space && iter.hasNext) {
          val pair = iter.next()
          val blockId = pair.getKey
          val entry = pair.getValue
          if (!entry.dropPending) {
            selectedBlocks += blockId
            entry.dropPending = true
            logInfo("Block " + blockId + " selected for dropping")
          }
          // Count blocks already pending too; their memory will be freed.
          selectedMemory += entry.size
        }
      }  
      
      logInfo("" + selectedBlocks.size + " new blocks selected for dropping, " + 
        blocksToDrop.size + " blocks pending")
      selectedBlocks.foreach(blocksToDrop.add(_))
    }
  }      
}
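
// A minimal, hypothetical demo (not part of the original file) of the LRU
// property MemoryStore relies on: a java.util.LinkedHashMap constructed with
// accessOrder = true iterates from least- to most-recently accessed, which
// is the order ensureFreeSpace walks when picking blocks to drop.
private object LruOrderDemo {
  def main(args: Array[String]) {
    val map = new LinkedHashMap[String, Int](32, 0.75f, true)
    map.put("a", 1)
    map.put("b", 2)
    map.put("c", 3)
    map.get("a")  // accessing "a" moves it to the most-recently-used end
    val iter = map.keySet.iterator
    while (iter.hasNext) {
      print(iter.next + " ")  // prints: b c a
    }
  }
}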


/**
 * Class to store blocks on disk, spread across one or more local
 * directories.
 */
class DiskStore(blockManager: BlockManager, rootDirs: String) 
  extends BlockStore(blockManager) {

  val MAX_DIR_CREATION_ATTEMPTS: Int = 10
  val localDirs = createLocalDirs()
  var lastLocalDirUsed = 0

  addShutdownHook()

  def getSize(blockId: String): Long = {
    getFile(blockId).length
  }

  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
    logDebug("Attempting to put block " + blockId)
    val startTime = System.currentTimeMillis
    val file = createFile(blockId)
    val channel = new RandomAccessFile(file, "rw").getChannel()
    val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.limit)
    buffer.put(bytes)
    channel.close()
    val finishTime = System.currentTimeMillis
    logDebug("Block %s stored as file of %d bytes on disk in %d ms".format(
      blockId, bytes.limit, (finishTime - startTime)))
  }

  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel)
      : Either[Iterator[Any], ByteBuffer] = {

    logDebug("Attempting to write values for block " + blockId)
    val file = createFile(blockId)
    val fileOut = new FastBufferedOutputStream(new FileOutputStream(file))
    val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
    objOut.writeAll(values)
    objOut.close()

    // Return a byte buffer for the contents of the file
    val channel = new RandomAccessFile(file, "rw").getChannel()
    Right(channel.map(MapMode.READ_WRITE, 0, channel.size()))
  }

  def getBytes(blockId: String): Option[ByteBuffer] = {
    val file = getFile(blockId) 
    val length = file.length().toInt
    val channel = new RandomAccessFile(file, "r").getChannel()
    // The channel is opened read-only, so the mapping must be READ_ONLY;
    // mapping READ_WRITE here would throw a NonWritableChannelException.
    val buffer = channel.map(MapMode.READ_ONLY, 0, length)
    channel.close()
    Some(buffer)
  }

  def getValues(blockId: String): Option[Iterator[Any]] = {
    val file = getFile(blockId) 
    val length = file.length().toInt
    val channel = new RandomAccessFile(file, "r").getChannel()
    val bytes = channel.map(MapMode.READ_ONLY, 0, length)
    val values = dataDeserialize(bytes)
    channel.close()
    Some(values)
  }

  def remove(blockId: String) {
    throw new UnsupportedOperationException("Not implemented") 
  }
  
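  // Create a new file for the block, choosing local dirs round-robin.
  // Fails if a file for this block already exists in any local dir.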
  private def createFile(blockId: String): File = {
    val file = getFile(blockId) 
    if (file == null) {
      lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
      val newFile = new File(localDirs(lastLocalDirUsed), blockId)
      newFile.getParentFile.mkdirs()
      return newFile 
    } else {
      throw new Exception("File for block " + blockId + " already exists on disk, " + file)
    }
  }

  private def getFile(blockId: String): File = {
    logDebug("Getting file for block " + blockId)
    // Search for the file in all the local directories, only one of them should have the file
    val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)  
    if (files.size > 1) {
      throw new Exception("Multiple files for same block " + blockId + " exist: " +
        files.map(_.toString).reduceLeft(_ + ", " + _))
    } else if (files.size == 0) {
      return null 
    } else {
      logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
      return files(0)
    }
  }

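  // Create one spark-local-* subdirectory under each configured root dir
  // (rootDirs is split on ';', ',' or ':'), retrying with fresh UUIDs up to
  // MAX_DIR_CREATION_ATTEMPTS times before giving up and exiting.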
  private def createLocalDirs(): Seq[File] = {
    logDebug("Creating local directories at root dirs '" + rootDirs + "'") 
    rootDirs.split("[;,:]").map(rootDir => {
        var foundLocalDir: Boolean = false
        var localDir: File = null
        var localDirUuid: UUID = null
        var tries = 0
        while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
          tries += 1
          try {
            localDirUuid = UUID.randomUUID()
            localDir = new File(rootDir, "spark-local-" + localDirUuid)
            if (!localDir.exists) {
              localDir.mkdirs()
              foundLocalDir = true
            }
          } catch {
            case e: Exception =>
              logWarning("Attempt " + tries + " to create local dir failed", e)
          }
        }
        if (!foundLocalDir) {
          logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + 
            " attempts to create local dir in " + rootDir)
          System.exit(1)
        }
        logDebug("Created local directory at " + localDir)
        localDir
    })
  }

  private def addShutdownHook() {
    Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
      override def run() {
        logDebug("Shutdown hook called")
        localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
      }
    })
  }
}
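
// A minimal, hypothetical demo (not part of the original file) of the
// memory-mapped I/O pattern DiskStore uses: write a ByteBuffer through a
// READ_WRITE mapping, then map the file READ_ONLY to read it back. The
// mapping stays valid even after its channel is closed.
private object MmapDemo {
  def main(args: Array[String]) {
    val file = File.createTempFile("mmap-demo", ".bin")
    val data = ByteBuffer.wrap("hello".getBytes)

    // Write: map a region as large as the payload and copy the bytes in.
    val out = new RandomAccessFile(file, "rw").getChannel()
    out.map(MapMode.READ_WRITE, 0, data.limit).put(data)
    out.close()

    // Read: a READ_ONLY mapping outlives the closed channel.
    val in = new RandomAccessFile(file, "r").getChannel()
    val mapped = in.map(MapMode.READ_ONLY, 0, in.size())
    in.close()
    val bytes = new Array[Byte](mapped.remaining)
    mapped.get(bytes)
    println(new String(bytes))  // prints: hello
    file.delete()
  }
}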