aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/serializer/Serializer.scala
blob: 50b086125a55e2217f5877c54c708fae8f4e8a3b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package spark.serializer

import java.nio.ByteBuffer
import java.io.{EOFException, InputStream, OutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.util.ByteBufferInputStream

/**
 * A serializer. Because some serialization libraries are not thread safe, this class is used to
 * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
 * guaranteed to only be called from one thread at a time.
 */
trait Serializer {
  /**
   * Returns a [[spark.serializer.SerializerInstance]] that performs the actual serialization.
   * Each returned instance is intended for use by one thread at a time.
   */
  def newInstance(): SerializerInstance
}

/**
 * An instance of a serializer, for use by one thread at a time.
 */
trait SerializerInstance {
  /** Serializes the single object `t` and returns the result as a byte buffer. */
  def serialize[T](t: T): ByteBuffer

  /** Deserializes a single object of type `T` from `bytes`. */
  def deserialize[T](bytes: ByteBuffer): T

  /**
   * Deserializes a single object of type `T` from `bytes`, given a class loader.
   * NOTE(review): `loader` is presumably used to resolve the classes of the deserialized
   * object -- confirm against the concrete implementations.
   */
  def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T

  /** Wraps the output stream `s` in a [[spark.serializer.SerializationStream]]. */
  def serializeStream(s: OutputStream): SerializationStream

  /** Wraps the input stream `s` in a [[spark.serializer.DeserializationStream]]. */
  def deserializeStream(s: InputStream): DeserializationStream

  /**
   * Serializes every element produced by `iterator` into a single byte buffer.
   *
   * @param iterator the elements to write; it is fully consumed by this call
   * @return a new buffer, positioned at 0, whose limit is the number of bytes written
   */
  def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
    // Default implementation uses serializeStream
    val stream = new FastByteArrayOutputStream()
    val out = serializeStream(stream)
    out.writeAll(iterator)
    // A serialization stream may buffer internally; flush it so that every byte has reached
    // `stream` before its position is read. Without this the result can be truncated.
    out.flush()
    val length = stream.position.toInt
    val buffer = ByteBuffer.allocate(length)
    buffer.put(stream.array, 0, length)
    buffer.flip()
    buffer
  }

  /**
   * Returns an iterator over the objects serialized in `buffer`.
   *
   * The buffer is rewound first, so its position is modified by this call, and the
   * objects are read from it as the returned iterator is consumed.
   */
  def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
    // Default implementation uses deserializeStream
    buffer.rewind()
    deserializeStream(new ByteBufferInputStream(buffer)).asIterator
  }
}

/**
 * A stream for writing serialized objects.
 */
trait SerializationStream {
  /** Writes the object `t` to this stream and returns a stream on which writing can continue. */
  def writeObject[T](t: T): SerializationStream

  /** Flushes this stream. */
  def flush(): Unit

  /** Closes this stream. */
  def close(): Unit

  /**
   * Writes every remaining element of `iter` with `writeObject`, in iteration order.
   *
   * @param iter the elements to write; it is fully consumed by this call
   * @return this stream
   */
  def writeAll[T](iter: Iterator[T]): SerializationStream = {
    iter.foreach(element => writeObject(element))
    this
  }
}

/**
 * A stream for reading serialized objects.
 */
trait DeserializationStream {
  /**
   * Reads the next object from this stream. `asIterator` treats an `EOFException` thrown
   * by this method as the end of the stream.
   */
  def readObject[T](): T

  /** Closes this stream. */
  def close(): Unit

  /**
   * Read the elements of this stream through an iterator. This can only be called once, as
   * reading each element will consume data from the input source.
   *
   * The stream is closed exactly once, when the end of the input is first detected, whether
   * that happens in `hasNext` or in `next()`. Any exception from `readObject` other than
   * `EOFException` is propagated to the caller.
   */
  def asIterator: Iterator[Any] = new Iterator[Any] {
    // True while nextValue holds an element that was read ahead but not yet returned.
    private var gotNext = false
    // True once readObject has signalled end of stream; no further reads are attempted.
    private var finished = false
    private var nextValue: Any = null

    // Reads one element ahead, unless one is already buffered or the stream has ended.
    private def fetchIfNeeded(): Unit = {
      if (!gotNext && !finished) {
        try {
          nextValue = readObject[Any]()
          gotNext = true
        } catch {
          case eof: EOFException =>
            finished = true
            // Close here, at the single point where the end is detected, so that the stream
            // is closed once rather than on every later hasNext call.
            close()
        }
      }
    }

    override def hasNext: Boolean = {
      fetchIfNeeded()
      !finished
    }

    override def next(): Any = {
      fetchIfNeeded()
      if (finished) {
        throw new NoSuchElementException("End of stream")
      }
      gotNext = false
      nextValue
    }
  }
}