aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/messaging/GoogleBus.scala
blob: b296c501c838a85b6777087386f6e3882f607520 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
package xyz.driver.core
package messaging

import java.nio.ByteBuffer
import java.nio.file.{Files, Path, Paths}
import java.security.Signature
import java.time.Instant
import java.util

import com.google.auth.oauth2.ServiceAccountCredentials
import com.softwaremill.sttp._
import spray.json.DefaultJsonProtocol._
import spray.json._

import scala.async.Async.{async, await}
import scala.concurrent._
import scala.concurrent.duration._

/** A message bus implemented by [[https://cloud.google.com/pubsub/docs/overview Google's Pub/Sub service.]]
  *
  * == Overview ==
  *
  * The Pub/Sub message system is focused around a few concepts: 'topics',
  * 'subscriptions' and 'subscribers'. Messages are sent to ''topics'' which may
  * have multiple ''subscriptions'' associated to them. Every subscription to a
  * topic will receive all messages sent to the topic.  Messages are enqueued in
  * a subscription until they are acknowledged by a ''subscriber''.  Multiple
  * subscribers may be associated to a subscription, in which case messages will
  * get delivered arbitrarily among them.
  *
  * Topics and subscriptions are named resources which can be specified in
  * Pub/Sub's configuration and may be queried. Subscribers on the other hand,
  * are ephemeral processes that query a subscription on a regular basis, handle any
  * messages and acknowledge them.
  *
  * == Delivery semantics ==
  *
  *   - at least once
  *   - no ordering
  *
  * == Retention ==
  *
  *   - configurable retry delay for unacknowledged messages, defaults to 10s
  *   - undeliverable messages are kept for 7 days
  *
  * @param credentials Google cloud credentials, usually the same as used by a
  *                    service. Must have admin access to topics and
  *                    subscriptions.
  * @param namespace The namespace in which this bus is running. Will be used to
  *                  determine the exact name of topics and subscriptions.
  * @param pullTimeout Delay after which a call to fetchMessages() will return an
  *                    empty list, assuming that no messages have been received.
  * @param executionContext Execution context to run any blocking commands.
  * @param backend sttp backend used to query Pub/Sub's HTTP API
  */
class GoogleBus(
    credentials: ServiceAccountCredentials,
    namespace: String,
    pullTimeout: Duration = 90.seconds
)(implicit val executionContext: ExecutionContext, backend: SttpBackend[Future, _])
    extends Bus {
  import GoogleBus.Protocol

  /** Handle required to acknowledge a received message.
    *
    * @param subscription Fully qualified name of the subscription the message was pulled from
    *                     (as built by `rawSubscriptionName`).
    * @param ackId Acknowledgement id that Pub/Sub returned alongside the message.
    */
  case class MessageId(subscription: String, ackId: String)

  /** A message pulled from a subscription.
    *
    * @param id Handle to pass to [[acknowledgeMessages]] once the message has been handled.
    * @param data Payload, already deserialized by the topic it was fetched for.
    * @param publishTime Publication timestamp as reported in the pull response.
    */
  case class PubsubMessage[A](id: MessageId, data: A, publishTime: Instant) extends super.BasicMessage[A]
  type Message[A] = PubsubMessage[A]

  /** Subscription-specific configuration
    *
    * @param subscriptionPrefix An identifier used to uniquely determine the name of Pub/Sub subscriptions.
    *                           All messages sent to a subscription will be dispatched arbitrarily
    *                           among any subscribers. Defaults to the email of the credentials used by this
    *                           bus instance, thereby giving every service a unique subscription to every topic.
    *                           To give every service instance a unique subscription, this must be changed to a
    *                           unique value.
    * @param ackTimeout Duration in which a message must be acknowledged before it is delivered again.
    */
  case class SubscriptionConfig(
      subscriptionPrefix: String = credentials.getClientEmail.split("@")(0),
      ackTimeout: FiniteDuration = 10.seconds
  )
  override val defaultSubscriptionConfig: SubscriptionConfig = SubscriptionConfig()

  /** Obtain an authentication token valid for the given duration
    * https://developers.google.com/identity/protocols/OAuth2ServiceAccount
    *
    * A self-signed JWT is built from the service account's email and private key and
    * exchanged for an access token at Google's OAuth2 token endpoint.
    *
    * @param duration Requested validity of the token; used to compute the JWT's `exp` claim.
    * @return The `access_token` field of the token endpoint's response. The future fails if the
    *         request is unsuccessful or the response cannot be parsed.
    */
  private def freshAuthToken(duration: FiniteDuration): Future[String] = {
    def jwt = {
      val now = Instant.now().getEpochSecond
      // JWS compact serialization (RFC 7515) mandates the URL-safe base64 alphabet without '=' padding
      val base64 = util.Base64.getUrlEncoder.withoutPadding()
      val header = base64.encodeToString("""{"alg":"RS256","typ":"JWT"}""".getBytes("utf-8"))
      val body = base64.encodeToString(
        s"""|{
            | "iss": "${credentials.getClientEmail}",
            | "scope": "https://www.googleapis.com/auth/pubsub",
            | "aud": "https://www.googleapis.com/oauth2/v4/token",
            | "exp": ${now + duration.toSeconds},
            | "iat": $now
            |}""".stripMargin.getBytes("utf-8")
      )
      // the signature covers the encoded header and claim set, joined by a dot
      val signer = Signature.getInstance("SHA256withRSA")
      signer.initSign(credentials.getPrivateKey)
      signer.update(s"$header.$body".getBytes("utf-8"))
      val signature = base64.encodeToString(signer.sign())
      s"$header.$body.$signature"
    }
    sttp
      .post(uri"https://www.googleapis.com/oauth2/v4/token")
      .body(
        "grant_type" -> "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion"  -> jwt
      )
      .mapResponse(s => s.parseJson.asJsObject.fields("access_token").convertTo[String])
      .send()
      .map(_.unsafeBody)
  }

  // the token is cached a few minutes less than its validity to diminish latency of concurrent accesses at renewal time
  private val getToken: () => Future[String] = Refresh.every(55.minutes)(freshAuthToken(60.minutes))

  private val baseUri = uri"https://pubsub.googleapis.com/"

  /** Fully qualified Pub/Sub resource name of a topic, scoped by project and namespace. */
  private def rawTopicName(topic: Topic[_]) =
    s"projects/${credentials.getProjectId}/topics/$namespace.${topic.name}"

  /** Fully qualified Pub/Sub resource name of a subscription, scoped by project, namespace and prefix. */
  private def rawSubscriptionName(config: SubscriptionConfig, topic: Topic[_]) =
    s"projects/${credentials.getProjectId}/subscriptions/$namespace.${config.subscriptionPrefix}.${topic.name}"

  /** Create the given topic in Pub/Sub. An already existing topic (HTTP 409) is not an error.
    *
    * @param topic Topic to create.
    * @return A future that fails with a `NoSuchElementException` on any error response other than 409.
    */
  def createTopic(topic: Topic[_]): Future[Unit] = async {
    val request = sttp
      .put(baseUri.path(s"v1/${rawTopicName(topic)}"))
      .auth
      .bearer(await(getToken()))
    val result = await(request.send())
    result.body match {
      case Left(error) if result.code != 409 => // 409 <=> topic already exists, ignore it
        throw new NoSuchElementException(s"Error creating topic: Status code ${result.code}: $error")
      case _ => ()
    }
  }

  /** Create a subscription to the given topic. An already existing subscription (HTTP 409) is not an error.
    *
    * @param topic Topic to subscribe to.
    * @param config Determines the subscription's name and its acknowledgement deadline.
    * @return A future that fails with a `NoSuchElementException` on any error response other than 409.
    */
  def createSubscription(topic: Topic[_], config: SubscriptionConfig): Future[Unit] = async {
    val request = sttp
      .put(baseUri.path(s"v1/${rawSubscriptionName(config, topic)}"))
      .auth
      .bearer(await(getToken()))
      .body(
        JsObject(
          "topic"              -> rawTopicName(topic).toJson,
          "ackDeadlineSeconds" -> config.ackTimeout.toSeconds.toJson
        ).compactPrint
      )
    val result = await(request.send())
    result.body match {
      case Left(error) if result.code != 409 => // 409 <=> subscription already exists, ignore it
        throw new NoSuchElementException(s"Error creating subscription: Status code ${result.code}: $error")
      case _ => ()
    }
  }

  /** Publish a batch of messages to a topic in a single request.
    *
    * An empty batch completes successfully without contacting Pub/Sub.
    *
    * @param topic Destination topic; also provides the serializer for the payloads.
    * @param messages Payloads to publish.
    * @return A future that fails if Pub/Sub answers with an error status.
    */
  override def publishMessages[A](topic: Topic[A], messages: Seq[A]): Future[Unit] =
    if (messages.isEmpty) {
      // nothing to send, so skip the token lookup and the HTTP round trip
      Future.successful(())
    } else {
      async {
        import Protocol.bufferFormat
        val buffers: Seq[ByteBuffer] = messages.map(topic.serialize)
        val request = sttp
          .post(baseUri.path(s"v1/${rawTopicName(topic)}:publish"))
          .auth
          .bearer(await(getToken()))
          .body(
            JsObject("messages" -> buffers.map(buffer => JsObject("data" -> buffer.toJson)).toJson).compactPrint
          )
        // unsafeBody throws on an error response, which fails the resulting future
        await(request.send()).unsafeBody
        ()
      }
    }

  /** Pull messages from the subscription associated to a topic.
    *
    * The request is sent with `returnImmediately` set to false and a read timeout of `pullTimeout`.
    *
    * @param topic Topic whose subscription is queried; also provides the payload deserializer.
    * @param subscriptionConfig Determines which subscription of the topic is queried.
    * @param maxMessages Upper bound on the number of messages requested from Pub/Sub.
    * @return The received messages, or an empty sequence if the response is an empty JSON object.
    */
  override def fetchMessages[A](
      topic: Topic[A],
      subscriptionConfig: SubscriptionConfig,
      maxMessages: Int): Future[Seq[PubsubMessage[A]]] = async {
    val subscription = rawSubscriptionName(subscriptionConfig, topic)
    val request = sttp
      .post(baseUri.path(s"v1/$subscription:pull"))
      .auth
      .bearer(await(getToken()))
      .body(
        JsObject(
          "returnImmediately" -> JsFalse,
          "maxMessages"       -> JsNumber(maxMessages)
        ).compactPrint
      )
      .readTimeout(pullTimeout)
      .mapResponse(_.parseJson)

    val messages = await(request.send()).unsafeBody match {
      // an empty object carries no `receivedMessages` field at all
      case JsObject(fields) if fields.isEmpty => Seq()
      case obj                                => obj.convertTo[Protocol.SubscriptionPull].receivedMessages
    }

    messages.map { msg =>
      PubsubMessage[A](
        MessageId(subscription, msg.ackId),
        topic.deserialize(msg.message.data),
        msg.message.publishTime
      )
    }
  }

  /** Acknowledge previously fetched messages.
    *
    * Ids are grouped by the subscription they were pulled from and one acknowledge request is
    * sent per subscription. An empty sequence completes successfully without any request.
    *
    * @param messageIds Ids of the messages to acknowledge; they may belong to different subscriptions.
    * @return A future that fails if any of the acknowledge requests receives an error response.
    */
  override def acknowledgeMessages(messageIds: Seq[MessageId]): Future[Unit] = {
    val idsBySubscription = messageIds.groupBy(_.subscription).toList
    Future
      .traverse(idsBySubscription) {
        case (subscription, ids) => acknowledgeOnSubscription(subscription, ids.map(_.ackId))
      }
      .map(_ => ())
  }

  /** Send a single acknowledge request for ack ids that all belong to `subscription`. */
  private def acknowledgeOnSubscription(subscription: String, ackIds: Seq[String]): Future[Unit] = async {
    val request = sttp
      .post(baseUri.path(s"v1/$subscription:acknowledge"))
      .auth
      .bearer(await(getToken()))
      .body(
        JsObject("ackIds" -> JsArray(ackIds.toVector.map(id => JsString(id)))).compactPrint
      )
    // unsafeBody throws on an error response, which fails the resulting future
    await(request.send()).unsafeBody
    ()
  }

}

object GoogleBus {

  /** JSON model and formats for the parts of Pub/Sub's HTTP API that [[GoogleBus]] reads and writes. */
  private object Protocol extends DefaultJsonProtocol {
    case class SubscriptionPull(receivedMessages: Seq[ReceivedMessage])
    case class ReceivedMessage(ackId: String, message: PubsubMessage)
    case class PubsubMessage(data: ByteBuffer, publishTime: Instant)

    /** Instants are exchanged as strings, written with `Instant.toString` and read with `Instant.parse`. */
    implicit val timeFormat: JsonFormat[Instant] = new JsonFormat[Instant] {
      override def write(obj: Instant): JsValue = JsString(obj.toString)
      override def read(json: JsValue): Instant = Instant.parse(json.convertTo[String])
    }

    /** Binary payloads are exchanged as base64-encoded JSON strings. */
    implicit val bufferFormat: JsonFormat[ByteBuffer] = new JsonFormat[ByteBuffer] {
      override def write(obj: ByteBuffer): JsValue = {
        val bytes =
          if (obj.hasArray) {
            // array-backed buffers are encoded from their whole backing array
            obj.array()
          } else {
            // read-only and direct buffers expose no array; copy their remaining bytes through a
            // duplicate so that the caller's position and limit are left untouched
            val view   = obj.duplicate()
            val copied = new Array[Byte](view.remaining())
            view.get(copied)
            copied
          }
        JsString(util.Base64.getEncoder.encodeToString(bytes))
      }

      override def read(json: JsValue): ByteBuffer = {
        val encodedBytes = json.convertTo[String].getBytes("utf-8")
        val decodedBytes = util.Base64.getDecoder.decode(encodedBytes)
        ByteBuffer.wrap(decodedBytes)
      }
    }

    implicit val pubsubMessageFormat: RootJsonFormat[PubsubMessage]      = jsonFormat2(PubsubMessage)
    implicit val receivedMessageFormat: RootJsonFormat[ReceivedMessage]  = jsonFormat2(ReceivedMessage)
    implicit val subscrptionPullFormat: RootJsonFormat[SubscriptionPull] = jsonFormat1(SubscriptionPull)
  }

  /** Create a bus from a service account key file.
    *
    * @param keyfile Path to the service account key file to read the credentials from.
    * @param namespace Namespace passed on to the created [[GoogleBus]].
    * @return A bus using the credentials from the key file and the default pull timeout.
    */
  def fromKeyfile(keyfile: Path, namespace: String)(
      implicit executionContext: ExecutionContext,
      backend: SttpBackend[Future, _]): GoogleBus = {
    val keyStream = Files.newInputStream(keyfile)
    // always release the file handle, even if the key file cannot be parsed
    val creds = try {
      ServiceAccountCredentials.fromStream(keyStream)
    } finally {
      keyStream.close()
    }
    new GoogleBus(creds, namespace)
  }

  /** Create a bus from the environment variables `GOOGLE_APPLICATION_CREDENTIALS` (path to the key
    * file) and `SERVICE_NAMESPACE`.
    *
    * Throws an `IllegalArgumentException` if either variable is not set.
    */
  @deprecated(
    "Reading from the environment adds opaque dependencies and hance leads to extra complexity. Use fromKeyfile instead.",
    "driver-core 1.12.2")
  def fromEnv(implicit executionContext: ExecutionContext, backend: SttpBackend[Future, _]): GoogleBus = {
    def env(key: String) = {
      require(sys.env.contains(key), s"Environment variable $key is not set.")
      sys.env(key)
    }
    val keyfile = Paths.get(env("GOOGLE_APPLICATION_CREDENTIALS"))
    fromKeyfile(keyfile, env("SERVICE_NAMESPACE"))
  }

}