aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/google/OAuth2.scala
blob: 43811c43b391479b70de7817623da64a309b8889 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package xyz.driver.tracing
package google

import akka.stream.scaladsl._
import akka.stream._
import akka.stream.stage._
import akka.http.scaladsl._
import akka.http.scaladsl.unmarshalling._
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.util.ByteString
import java.time._
import java.nio.file._
import spray.json._
import spray.json.DefaultJsonProtocol._
import pdi.jwt._
import scala.concurrent._
import scala.concurrent.duration._

object OAuth2 {

  /** Google OAuth2 token endpoint; also used as the JWT `aud` claim. */
  private val TokenEndpoint = "https://www.googleapis.com/oauth2/v4/token"

  /** The fields of a service account key file that are decoded here. */
  private case class ServiceAccount(project_id: String,
                                    private_key: String,
                                    client_email: String)
  private implicit val serviceAccountFormat: RootJsonFormat[ServiceAccount] =
    jsonFormat3(ServiceAccount)

  /** Body of a successful token endpoint response. */
  private case class GrantResponse(access_token: String, expires_in: Int)
  private implicit val grantResponseFormat: RootJsonFormat[GrantResponse] =
    jsonFormat2(GrantResponse)

  /** Reads the service account key file and signs an RS256 JWT assertion.
    *
    * The assertion is issued at `issuedAt`, expires one hour later, and
    * requests `scopes` (space-separated in the `scope` claim).
    */
  private def signAssertion(serviceAccountFile: Path,
                            scopes: Seq[String],
                            issuedAt: Instant): String = {
    // File I/O runs on the caller's execution context; flag it as blocking.
    val json = blocking {
      new String(Files.readAllBytes(serviceAccountFile), "utf-8")
    }
    val credentials = json.parseJson.convertTo[ServiceAccount]
    val now = issuedAt.getEpochSecond

    val claim = JwtClaim(
      issuer = Some(credentials.client_email),
      expiration = Some(now + 60 * 60),
      issuedAt = Some(now)
    ) +
      ("aud", TokenEndpoint) +
      ("scope", scopes.mkString(" "))

    Jwt.encode(claim, credentials.private_key, JwtAlgorithm.RS256)
  }

  /** Decodes a token endpoint response.
    *
    * Non-success responses are turned into a failed future carrying the
    * status and response body, instead of an opaque JSON decoding error.
    * The entity is consumed in both cases.
    */
  private def parseGrant(response: HttpResponse)(
      implicit ec: ExecutionContext,
      mat: Materializer): Future[GrantResponse] =
    if (response.status.isSuccess) {
      Unmarshal(response).to[GrantResponse]
    } else {
      Unmarshal(response.entity).to[String].flatMap { body =>
        Future.failed(
          new RuntimeException(
            s"access token request failed with status ${response.status}: $body"))
      }
    }

  /** Request a new access token for the given scopes.
    *
    * Implements the OAuth2 service-account workflow described here:
    * https://developers.google.com/identity/protocols/OAuth2ServiceAccount
    *
    * @param http               HTTP extension used to call the token endpoint
    * @param serviceAccountFile path to the service account JSON key file
    * @param scopes             OAuth2 scopes to request
    * @return a future of (expiration instant, access token). It fails if the
    *         key file cannot be read or parsed, or if signing fails. It also
    *         fails if the token endpoint returns a non-success status.
    */
  def requestAccessToken(
      http: HttpExt,
      serviceAccountFile: Path,
      scopes: Seq[String]
  )(implicit ec: ExecutionContext,
    mat: Materializer): Future[(Instant, String)] = {
    // Captured before the round trip, so the computed expiry errs on the
    // early side rather than overshooting by the request latency.
    val issuedAt = Instant.now
    for {
      assertion <- Future(signAssertion(serviceAccountFile, scopes, issuedAt))
      response <- http.singleRequest(
        HttpRequest(
          method = HttpMethods.POST,
          uri = TokenEndpoint
        ).withEntity(
          FormData(
            "grant_type" -> "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "assertion" -> assertion
          ).toEntity))
      grant <- parseGrant(response)
    } yield (issuedAt.plusSeconds(grant.expires_in), grant.access_token)
  }

  /** Flow that injects access tokens into a stream of HTTP requests.
    *
    * Re-authentication happens transparently when access tokens expire, or
    * come within `graceSeconds` of expiring. Existing request headers are
    * preserved; any prior `Authorization` header is replaced by a bearer
    * token header. Note: in case an access token gets revoked, this flow
    * needs to be restarted in order to re-authenticate.
    */
  def authenticatedFlow(http: HttpExt,
                        serviceAccountFile: Path,
                        scopes: Seq[String],
                        graceSeconds: Int = 300)(
      implicit ec: ExecutionContext,
      mat: Materializer): Flow[HttpRequest, HttpRequest, _] =
    Flow[HttpRequest]
      .scanAsync[(HttpRequest, Instant, String)](
        // Seed with an already-expired token so the first request refreshes.
        (HttpRequest(), Instant.EPOCH, "")) {
        case ((_, expiration, accessToken), request) =>
          if (Instant.now isAfter expiration.minusSeconds(graceSeconds)) {
            http.system.log.info("tracing access token expired, refreshing")
            requestAccessToken(http, serviceAccountFile, scopes).map {
              case (newExpiration, newToken) =>
                http.system.log.debug("new tracing access token obtained")
                (request, newExpiration, newToken)
            }
          } else {
            Future.successful((request, expiration, accessToken))
          }
      }
      .drop(1) // scanAsync emits the seed first; drop it
      .map {
        case (request, _, accessToken) =>
          // Keep caller-supplied headers; only swap out Authorization.
          request.withHeaders(
            request.headers.filterNot(_.is("authorization")) :+
              Authorization(OAuth2BearerToken(accessToken))
          )
      }

}