aboutsummaryrefslogtreecommitdiff
path: root/core-storage/src/main/scala/xyz/driver/core/storage/channelStreams.scala
blob: fc652bed31bf35434535110c655683f2383185fa (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package xyz.driver.core.storage

import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, WritableByteChannel}

import akka.stream._
import akka.stream.scaladsl.{Sink, Source}
import akka.stream.stage._
import akka.util.ByteString
import akka.{Done, NotUsed}

import scala.concurrent.{Future, Promise}
import scala.util.control.NonFatal

/** Graph stage that emits the contents of a [[java.nio.channels.ReadableByteChannel]] as a stream of
  * [[akka.util.ByteString]] chunks.
  *
  * The channel is obtained from `createChannel` each time the stage logic is created. It is closed when
  * the stage stops, whatever the reason (end of input, downstream cancellation, read failure or abrupt
  * termination).
  *
  * @param createChannel factory producing the channel to read from; invoked once per materialization
  * @param chunkSize     capacity, in bytes, of the buffer allocated for every read; each emitted element
  *                      holds at most this many bytes
  */
class ChannelSource(createChannel: () => ReadableByteChannel, chunkSize: Int)
    extends GraphStage[SourceShape[ByteString]] {

  val out   = Outlet[ByteString]("ChannelSource.out")
  val shape = SourceShape(out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
    val channel = createChannel()

    object Handler extends OutHandler {
      override def onPull(): Unit = {
        try {
          // A fresh buffer per pull: the emitted ByteString must not share storage with later reads.
          val buffer = ByteBuffer.allocate(chunkSize)
          if (channel.read(buffer) > 0) {
            buffer.flip()
            push(out, ByteString.fromByteBuffer(buffer))
          } else {
            // Nothing was read (end of stream reported by the channel): finish the stream.
            completeStage()
          }
        } catch {
          case NonFatal(e) =>
            // Propagate the error downstream instead of swallowing it; otherwise the stream would
            // neither complete nor fail and downstream would wait forever.
            failStage(e)
        }
      }
    }

    setHandler(out, Handler)

    // Single cleanup point: runs after completion, failure, downstream cancellation and abrupt
    // termination, so the channel is released on every path.
    override def postStop(): Unit = {
      channel.close()
    }
  }

}

/** Graph stage that writes every incoming [[akka.util.ByteString]] to a
  * [[java.nio.channels.WritableByteChannel]].
  *
  * The materialized future completes with `Done` once upstream has finished and the channel has been
  * closed successfully. It fails if a write fails, if closing the channel fails, if upstream fails, or
  * if the stage is stopped before upstream completed.
  *
  * @param createChannel factory producing the channel to write to; invoked once per materialization
  * @param chunkSize     currently not used by this stage; kept so existing callers keep compiling
  */
class ChannelSink(createChannel: () => WritableByteChannel, chunkSize: Int)
    extends GraphStageWithMaterializedValue[SinkShape[ByteString], Future[Done]] {

  val in    = Inlet[ByteString]("ChannelSink.in")
  val shape = SinkShape(in)

  override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Done]) = {
    val promise = Promise[Done]()
    val logic = new GraphStageLogic(shape) {
      val channel = createChannel()

      // Closes the channel and reports a close failure as a value, so callers can decide how to
      // surface it. Closing an already closed channel has no effect, so repeated calls are safe.
      private def closeChannel(): Option[Throwable] =
        try {
          channel.close()
          None
        } catch {
          case NonFatal(e) => Some(e)
        }

      object Handler extends InHandler {
        override def onPush(): Unit = {
          try {
            val buffer = grab(in).asByteBuffer
            // A single write call may consume only part of the buffer; keep writing until it is
            // drained so no bytes are silently dropped.
            while (buffer.hasRemaining) {
              channel.write(buffer)
            }
            pull(in)
          } catch {
            case NonFatal(e) =>
              closeChannel().foreach(closeError => if (closeError ne e) e.addSuppressed(closeError))
              promise.tryFailure(e)
              // Fail the stage as well so upstream is cancelled instead of being left un-pulled.
              failStage(e)
          }
        }

        override def onUpstreamFinish(): Unit = {
          // Close before completing the future so observers of the future see a closed channel.
          closeChannel() match {
            case None =>
              promise.trySuccess(Done)
              completeStage()
            case Some(closeError) =>
              // A failing close may mean data was not persisted; report it rather than success.
              promise.tryFailure(closeError)
              failStage(closeError)
          }
        }

        override def onUpstreamFailure(ex: Throwable): Unit = {
          closeChannel().foreach(closeError => if (closeError ne ex) ex.addSuppressed(closeError))
          promise.tryFailure(ex)
          failStage(ex)
        }
      }

      setHandler(in, Handler)

      override def preStart(): Unit = {
        pull(in)
      }

      // Safety net for abrupt termination (stage stopped without any of the handlers above having
      // run): release the channel and make sure the materialized future is always completed.
      override def postStop(): Unit = {
        closeChannel().foreach(closeError => promise.tryFailure(closeError))
        promise.tryFailure(new AbruptStageTerminationException(this))
      }
    }
    (logic, promise.future)
  }

}

/** Factory methods wrapping the channel graph stages into ready-to-use `Source` and `Sink` values. */
object ChannelStream {

  /** Builds a source backed by a [[ChannelSource]].
    *
    * The stage is given the `ActorAttributes.IODispatcher` attribute and marked as an asynchronous
    * island.
    *
    * @param channel   factory for the channel to read from
    * @param chunkSize passed through to [[ChannelSource]]; defaults to 8192
    */
  def fromChannel(channel: () => ReadableByteChannel, chunkSize: Int = 8192): Source[ByteString, NotUsed] = {
    val stage        = new ChannelSource(channel, chunkSize)
    val ioAttributes = Attributes(ActorAttributes.IODispatcher)
    Source.fromGraph(stage).withAttributes(ioAttributes).async
  }

  /** Builds a sink backed by a [[ChannelSink]].
    *
    * The stage is given the `ActorAttributes.IODispatcher` attribute and marked as an asynchronous
    * island. The materialized value is the future produced by the [[ChannelSink]] stage.
    *
    * @param channel   factory for the channel to write to
    * @param chunkSize passed through to [[ChannelSink]]; defaults to 8192
    */
  def toChannel(channel: () => WritableByteChannel, chunkSize: Int = 8192): Sink[ByteString, Future[Done]] = {
    val stage        = new ChannelSink(channel, chunkSize)
    val ioAttributes = Attributes(ActorAttributes.IODispatcher)
    Sink.fromGraph(stage).withAttributes(ioAttributes).async
  }

}