aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/google
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/google')
-rw-r--r--src/main/scala/google/OAuth2.scala108
-rw-r--r--src/main/scala/google/api.scala113
2 files changed, 221 insertions, 0 deletions
diff --git a/src/main/scala/google/OAuth2.scala b/src/main/scala/google/OAuth2.scala
new file mode 100644
index 0000000..43811c4
--- /dev/null
+++ b/src/main/scala/google/OAuth2.scala
@@ -0,0 +1,108 @@
+package xyz.driver.tracing
+package google
+
+import akka.stream.scaladsl._
+import akka.stream._
+import akka.stream.stage._
+import akka.http.scaladsl._
+import akka.http.scaladsl.unmarshalling._
+import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
+import akka.http.scaladsl.model._
+import akka.http.scaladsl.model.headers._
+import akka.util.ByteString
+import java.time._
+import java.nio.file._
+import spray.json._
+import spray.json.DefaultJsonProtocol._
+import pdi.jwt._
+import scala.concurrent._
+import scala.concurrent.duration._
+
+object OAuth2 {
+
+ private case class ServiceAccount(project_id: String,
+ private_key: String,
+ client_email: String)
+ private implicit val serviceAccountFormat = jsonFormat3(ServiceAccount)
+
+ private case class GrantResponse(access_token: String, expires_in: Int)
+ private implicit val grantResponseFormat = jsonFormat2(GrantResponse)
+
+ /** Request a new access token for the given scopes.
+ *
+ * Implements the OAUTH2 workflow as descried here
+ * https://developers.google.com/identity/protocols/OAuth2ServiceAccount
+ */
+ def requestAccessToken(
+ http: HttpExt,
+ serviceAccountFile: Path,
+ scopes: Seq[String]
+ )(implicit ec: ExecutionContext,
+ mat: Materializer): Future[(Instant, String)] =
+ Future {
+ val now = Instant.now.toEpochMilli / 1000
+ val credentials =
+ (new String(Files.readAllBytes(serviceAccountFile), "utf-8")).parseJson
+ .convertTo[ServiceAccount]
+
+ val claim = JwtClaim(
+ issuer = Some(credentials.client_email),
+ expiration = Some(now + 60 * 60),
+ issuedAt = Some(now)
+ ) +
+ ("aud", "https://www.googleapis.com/oauth2/v4/token") +
+ ("scope", scopes.mkString(" "))
+
+ Jwt.encode(claim, credentials.private_key, JwtAlgorithm.RS256)
+ } flatMap { assertion =>
+ http.singleRequest(
+ HttpRequest(
+ method = HttpMethods.POST,
+ uri = "https://www.googleapis.com/oauth2/v4/token"
+ ).withEntity(
+ FormData(
+ "grant_type" -> "urn:ietf:params:oauth:grant-type:jwt-bearer",
+ "assertion" -> assertion
+ ).toEntity))
+ } flatMap { response =>
+ Unmarshal(response).to[GrantResponse]
+ } map { grant =>
+ (Instant.now.plusSeconds(grant.expires_in), grant.access_token)
+ }
+
+ /** Flow that injects access tokens into a stream of HTTP requests.
+ *
+ * Re-authentication happens transparently when access tokens expire. Note:
+ * in case an access token gets revoked, this flow needs to be restarted in
+ * order to re-authenticate
+ */
+ def authenticatedFlow(http: HttpExt,
+ serviceAccountFile: Path,
+ scopes: Seq[String],
+ graceSeconds: Int = 300)(
+ implicit ec: ExecutionContext,
+ mat: Materializer): Flow[HttpRequest, HttpRequest, _] =
+ Flow[HttpRequest]
+ .scanAsync[(HttpRequest, Instant, String)](
+ (HttpRequest(), Instant.now, "")) {
+ case ((_, expiration, accessToken), request) =>
+ if (Instant.now isAfter expiration.minusSeconds(graceSeconds)) {
+ http.system.log.info("tracing access token expired, refreshing")
+ requestAccessToken(http, serviceAccountFile, scopes).map {
+ case (newExpiration, newToken) =>
+ http.system.log.debug("new tracing access token otained")
+ (request, newExpiration, newToken)
+ }
+ } else {
+ Future.successful((request, expiration, accessToken))
+ }
+ }
+ .drop(1) // drop initial element
+ .map {
+ case (request, _, accessToken) =>
+ request.withHeaders(
+ RawHeader("Authorization", "Bearer " + accessToken)
+ )
+ }
+
+}
diff --git a/src/main/scala/google/api.scala b/src/main/scala/google/api.scala
new file mode 100644
index 0000000..122b695
--- /dev/null
+++ b/src/main/scala/google/api.scala
@@ -0,0 +1,113 @@
+package xyz.driver.tracing
+package google
+
+import spray.json._
+import spray.json.DefaultJsonProtocol._
+import java.util.UUID
+import java.nio.ByteBuffer
+import java.time._
+import java.time.format._
+
+case class TraceSpan(
+ spanId: Long,
+ kind: TraceSpan.SpanKind,
+ name: String,
+ startTime: Instant,
+ endTime: Instant,
+ parentSpanId: Option[Long],
+ labels: Map[String, String]
+)
+
+object TraceSpan {
+
+ sealed trait SpanKind
+ // Unspecified
+ case object Unspecified extends SpanKind
+ // Indicates that the span covers server-side handling of an RPC or other remote network request.
+ case object RpcServer extends SpanKind
+ // Indicates that the span covers the client-side wrapper around an RPC or other remote request.
+ case object RpcClient extends SpanKind
+
+ object SpanKind {
+ implicit val format: JsonFormat[SpanKind] = new JsonFormat[SpanKind] {
+ override def write(x: SpanKind): JsValue = x match {
+ case Unspecified => JsString("SPAN_KIND_UNSPECIFIED")
+ case RpcServer => JsString("RPC_SERVER")
+ case RpcClient => JsString("RPC_CLIENT")
+ }
+ override def read(x: JsValue): SpanKind = x match {
+ case JsString("SPAN_KIND_UNSPECIFIED") => Unspecified
+ case JsString("RPC_SERVER") => RpcServer
+ case JsString("RPC_CLIENT") => RpcClient
+ case other =>
+ spray.json.deserializationError(s"`$other` is not a valid span kind")
+ }
+ }
+ }
+
+ implicit val instantFormat = new JsonFormat[Instant] {
+ val formatter = DateTimeFormatter
+ .ofPattern("yyyy-MM-dd'T'HH:mm:ssXXXZ")
+ .withZone(ZoneId.of("UTC"))
+ override def write(x: Instant): JsValue = JsString(formatter.format(x))
+ override def read(x: JsValue): Instant = x match {
+ case JsString(x) => Instant.parse(x)
+ case other =>
+ spray.json.deserializationError(s"`$other` is not a valid instant")
+ }
+ }
+
+ implicit val format: JsonFormat[TraceSpan] = jsonFormat7(TraceSpan.apply)
+
+ def fromSpan(span: Span) = TraceSpan(
+ span.spanId.getLeastSignificantBits,
+ Unspecified,
+ span.name,
+ span.startTime,
+ span.endTime,
+ span.parentSpanId.map(_.getLeastSignificantBits),
+ span.labels
+ )
+
+}
+
+case class Trace(
+ traceId: UUID,
+ projectId: String = "",
+ spans: Seq[TraceSpan] = Seq.empty
+)
+
+object Trace {
+
+ implicit val uuidFormat = new JsonFormat[UUID] {
+ override def write(x: UUID) = {
+ val buffer = ByteBuffer.allocate(16)
+ buffer.putLong(x.getMostSignificantBits)
+ buffer.putLong(x.getLeastSignificantBits)
+ val array = buffer.array()
+ val string = new StringBuilder
+ for (i <- 0 until 16) {
+ string ++= f"${array(i) & 0xff}%02x"
+ }
+ JsString(string.result)
+ }
+ override def read(x: JsValue): UUID = x match {
+ case JsString(str) if str.length == 32 =>
+ val (msb, lsb) = str.splitAt(16)
+ new UUID(java.lang.Long.decode(msb), java.lang.Long.decode(lsb))
+ case JsString(str) =>
+ spray.json.deserializationError(
+ "128-bit id string must be exactly 32 characters long")
+ case other =>
+ spray.json.deserializationError("expected 32 character hex string")
+ }
+ }
+
+ implicit val format: JsonFormat[Trace] = jsonFormat3(Trace.apply)
+
+}
+
+case class Traces(traces: Seq[Trace])
+object Traces {
+ implicit val format: RootJsonFormat[Traces] = jsonFormat1(Traces.apply)
+}