diff options
author | Jakob Odersky <jakob@driver.xyz> | 2017-10-01 20:24:02 -0700 |
---|---|---|
committer | Jakob Odersky <jakob@driver.xyz> | 2017-10-01 20:24:29 -0700 |
commit | 2c08b51411be5b0cce57f876377fcd52bee99990 (patch) | |
tree | fee56a21e6a5f3d2dd459b51e5afb355db6c7f02 /src/main/scala/google | |
parent | 5bd947dd08eec1d6c64a9549566f3ce0e91fe74f (diff) | |
download | tracing-2c08b51411be5b0cce57f876377fcd52bee99990.tar.gz tracing-2c08b51411be5b0cce57f876377fcd52bee99990.tar.bz2 tracing-2c08b51411be5b0cce57f876377fcd52bee99990.zip |
Flatten file hierarchy and implement OAUTH2 authentication
Diffstat (limited to 'src/main/scala/google')
-rw-r--r-- | src/main/scala/google/OAuth2.scala | 108 | ||||
-rw-r--r-- | src/main/scala/google/api.scala | 113 |
2 files changed, 221 insertions, 0 deletions
diff --git a/src/main/scala/google/OAuth2.scala b/src/main/scala/google/OAuth2.scala new file mode 100644 index 0000000..43811c4 --- /dev/null +++ b/src/main/scala/google/OAuth2.scala @@ -0,0 +1,108 @@ +package xyz.driver.tracing +package google + +import akka.stream.scaladsl._ +import akka.stream._ +import akka.stream.stage._ +import akka.http.scaladsl._ +import akka.http.scaladsl.unmarshalling._ +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.util.ByteString +import java.time._ +import java.nio.file._ +import spray.json._ +import spray.json.DefaultJsonProtocol._ +import pdi.jwt._ +import scala.concurrent._ +import scala.concurrent.duration._ + +object OAuth2 { + + private case class ServiceAccount(project_id: String, + private_key: String, + client_email: String) + private implicit val serviceAccountFormat = jsonFormat3(ServiceAccount) + + private case class GrantResponse(access_token: String, expires_in: Int) + private implicit val grantResponseFormat = jsonFormat2(GrantResponse) + + /** Request a new access token for the given scopes. + * + * Implements the OAUTH2 workflow as descried here + * https://developers.google.com/identity/protocols/OAuth2ServiceAccount + */ + def requestAccessToken( + http: HttpExt, + serviceAccountFile: Path, + scopes: Seq[String] + )(implicit ec: ExecutionContext, + mat: Materializer): Future[(Instant, String)] = + Future { + val now = Instant.now.toEpochMilli / 1000 + val credentials = + (new String(Files.readAllBytes(serviceAccountFile), "utf-8")).parseJson + .convertTo[ServiceAccount] + + val claim = JwtClaim( + issuer = Some(credentials.client_email), + expiration = Some(now + 60 * 60), + issuedAt = Some(now) + ) + + ("aud", "https://www.googleapis.com/oauth2/v4/token") + + ("scope", scopes.mkString(" ")) + + Jwt.encode(claim, credentials.private_key, JwtAlgorithm.RS256) + } flatMap { assertion => + http.singleRequest( + HttpRequest( + method = HttpMethods.POST, + uri = "https://www.googleapis.com/oauth2/v4/token" + ).withEntity( + FormData( + "grant_type" -> "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion" -> assertion + ).toEntity)) + } flatMap { response => + Unmarshal(response).to[GrantResponse] + } map { grant => + (Instant.now.plusSeconds(grant.expires_in), grant.access_token) + } + + /** Flow that injects access tokens into a stream of HTTP requests. + * + * Re-authentication happens transparently when access tokens expire. Note: + * in case an access token gets revoked, this flow needs to be restarted in + * order to re-authenticate + */ + def authenticatedFlow(http: HttpExt, + serviceAccountFile: Path, + scopes: Seq[String], + graceSeconds: Int = 300)( + implicit ec: ExecutionContext, + mat: Materializer): Flow[HttpRequest, HttpRequest, _] = + Flow[HttpRequest] + .scanAsync[(HttpRequest, Instant, String)]( + (HttpRequest(), Instant.now, "")) { + case ((_, expiration, accessToken), request) => + if (Instant.now isAfter expiration.minusSeconds(graceSeconds)) { + http.system.log.info("tracing access token expired, refreshing") + requestAccessToken(http, serviceAccountFile, scopes).map { + case (newExpiration, newToken) => + http.system.log.debug("new tracing access token otained") + (request, newExpiration, newToken) + } + } else { + Future.successful((request, expiration, accessToken)) + } + } + .drop(1) // drop initial element + .map { + case (request, _, accessToken) => + request.withHeaders( + RawHeader("Authorization", "Bearer " + accessToken) + ) + } + +} diff --git a/src/main/scala/google/api.scala b/src/main/scala/google/api.scala new file mode 100644 index 0000000..122b695 --- /dev/null +++ b/src/main/scala/google/api.scala @@ -0,0 +1,113 @@ +package xyz.driver.tracing +package google + +import spray.json._ +import spray.json.DefaultJsonProtocol._ +import java.util.UUID +import java.nio.ByteBuffer +import java.time._ +import java.time.format._ + +case class TraceSpan( + spanId: Long, + kind: TraceSpan.SpanKind, + name: String, + startTime: Instant, + endTime: Instant, + parentSpanId: Option[Long], + labels: Map[String, String] +) + +object TraceSpan { + + sealed trait SpanKind + // Unspecified + case object Unspecified extends SpanKind + // Indicates that the span covers server-side handling of an RPC or other remote network request. + case object RpcServer extends SpanKind + // Indicates that the span covers the client-side wrapper around an RPC or other remote request. + case object RpcClient extends SpanKind + + object SpanKind { + implicit val format: JsonFormat[SpanKind] = new JsonFormat[SpanKind] { + override def write(x: SpanKind): JsValue = x match { + case Unspecified => JsString("SPAN_KIND_UNSPECIFIED") + case RpcServer => JsString("RPC_SERVER") + case RpcClient => JsString("RPC_CLIENT") + } + override def read(x: JsValue): SpanKind = x match { + case JsString("SPAN_KIND_UNSPECIFIED") => Unspecified + case JsString("RPC_SERVER") => RpcServer + case JsString("RPC_CLIENT") => RpcClient + case other => + spray.json.deserializationError(s"`$other` is not a valid span kind") + } + } + } + + implicit val instantFormat = new JsonFormat[Instant] { + val formatter = DateTimeFormatter + .ofPattern("yyyy-MM-dd'T'HH:mm:ssXXXZ") + .withZone(ZoneId.of("UTC")) + override def write(x: Instant): JsValue = JsString(formatter.format(x)) + override def read(x: JsValue): Instant = x match { + case JsString(x) => Instant.parse(x) + case other => + spray.json.deserializationError(s"`$other` is not a valid instant") + } + } + + implicit val format: JsonFormat[TraceSpan] = jsonFormat7(TraceSpan.apply) + + def fromSpan(span: Span) = TraceSpan( + span.spanId.getLeastSignificantBits, + Unspecified, + span.name, + span.startTime, + span.endTime, + span.parentSpanId.map(_.getLeastSignificantBits), + span.labels + ) + +} + +case class Trace( + traceId: UUID, + projectId: String = "", + spans: Seq[TraceSpan] = Seq.empty +) + +object Trace { + + implicit val uuidFormat = new JsonFormat[UUID] { + override def write(x: UUID) = { + val buffer = ByteBuffer.allocate(16) + buffer.putLong(x.getMostSignificantBits) + buffer.putLong(x.getLeastSignificantBits) + val array = buffer.array() + val string = new StringBuilder + for (i <- 0 until 16) { + string ++= f"${array(i) & 0xff}%02x" + } + JsString(string.result) + } + override def read(x: JsValue): UUID = x match { + case JsString(str) if str.length == 32 => + val (msb, lsb) = str.splitAt(16) + new UUID(java.lang.Long.decode(msb), java.lang.Long.decode(lsb)) + case JsString(str) => + spray.json.deserializationError( + "128-bit id string must be exactly 32 characters long") + case other => + spray.json.deserializationError("expected 32 character hex string") + } + } + + implicit val format: JsonFormat[Trace] = jsonFormat3(Trace.apply) + +} + +case class Traces(traces: Seq[Trace]) +object Traces { + implicit val format: RootJsonFormat[Traces] = jsonFormat1(Traces.apply) +} |