From 5ec270aa98b806f32338fa25357abdf45dd0625b Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 22 Aug 2018 12:51:36 -0700 Subject: Trait-based initialization and other utilities Adds the concept of a 'platform', a centralized place in which environment-specific information will be managed, and provides common initialization logic for most "standard" apps. As part of the common initialization, other parts of core have also been reworked: - HTTP-related unmarshallers and path matchers have been factored out from core.json to a new core.rest.directives package (core.json extends those unmarshallers and matchers for backwards compatibility) - CORS handling has also been moved to a dedicated utility trait - Some custom headers have been moved from raw headers to typed ones in core.rest.headers - The concept of a "reporter" has been introduced. A reporter is a context-aware combination of tracing and logging. It is intended to issue diagnostic messages that can be traced across service boundaries. Closes #192 Closes #195 --- build.sbt | 71 ++++---- src/main/resources/reference.conf | 31 +--- src/main/scala/xyz/driver/core/Platform.scala | 35 ++++ src/main/scala/xyz/driver/core/Refresh.scala | 3 +- src/main/scala/xyz/driver/core/app/DriverApp.scala | 67 ++----- .../driver/core/discovery/CanDiscoverService.scala | 11 ++ .../scala/xyz/driver/core/init/AkkaBootable.scala | 188 ++++++++++++++++++++ .../xyz/driver/core/init/BuildInfoReflection.scala | 37 ++++ .../scala/xyz/driver/core/init/CloudServices.scala | 88 ++++++++++ src/main/scala/xyz/driver/core/init/HttpApi.scala | 89 ++++++++++ .../scala/xyz/driver/core/init/ProtobufApi.scala | 7 + .../scala/xyz/driver/core/init/SimpleHttpApp.scala | 4 + src/main/scala/xyz/driver/core/json.scala | 80 +-------- .../scala/xyz/driver/core/logging/package.scala | 3 +- .../xyz/driver/core/messaging/GoogleBus.scala | 15 +- .../xyz/driver/core/reporting/GoogleReporter.scala | 193 +++++++++++++++++++++ .../driver/core/reporting/NoTraceReporter.scala | 18 ++ .../scala/xyz/driver/core/reporting/Reporter.scala | 164 +++++++++++++++++ .../driver/core/reporting/ScalaLoggerLike.scala | 31 ++++ .../xyz/driver/core/reporting/SpanContext.scala | 11 ++ .../core/rest/HttpRestServiceTransport.scala | 10 +- src/main/scala/xyz/driver/core/rest/Swagger.scala | 34 ++-- .../core/rest/directives/AuthDirectives.scala | 19 ++ .../core/rest/directives/CorsDirectives.scala | 72 ++++++++ .../driver/core/rest/directives/Directives.scala | 6 + .../driver/core/rest/directives/PathMatchers.scala | 73 ++++++++ .../core/rest/directives/Unmarshallers.scala | 40 +++++ .../xyz/driver/core/rest/headers/Traceparent.scala | 33 ++++ src/main/scala/xyz/driver/core/rest/package.scala | 5 +- .../scala/xyz/driver/core/rest/DriverAppTest.scala | 18 +- 30 files changed, 1223 insertions(+), 233 deletions(-) create mode 100644 src/main/scala/xyz/driver/core/Platform.scala create mode 100644 src/main/scala/xyz/driver/core/discovery/CanDiscoverService.scala create mode 100644 src/main/scala/xyz/driver/core/init/AkkaBootable.scala create mode 100644 src/main/scala/xyz/driver/core/init/BuildInfoReflection.scala create mode 100644 src/main/scala/xyz/driver/core/init/CloudServices.scala create mode 100644 src/main/scala/xyz/driver/core/init/HttpApi.scala create mode 100644 src/main/scala/xyz/driver/core/init/ProtobufApi.scala create mode 100644 src/main/scala/xyz/driver/core/init/SimpleHttpApp.scala create mode 100644 src/main/scala/xyz/driver/core/reporting/GoogleReporter.scala create mode 100644 src/main/scala/xyz/driver/core/reporting/NoTraceReporter.scala create mode 100644 src/main/scala/xyz/driver/core/reporting/Reporter.scala create mode 100644 src/main/scala/xyz/driver/core/reporting/ScalaLoggerLike.scala create mode 100644 src/main/scala/xyz/driver/core/reporting/SpanContext.scala create mode 100644 src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala create mode 100644 src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala create mode 100644 src/main/scala/xyz/driver/core/rest/directives/Directives.scala create mode 100644 src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala create mode 100644 src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala create mode 100644 src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala diff --git a/build.sbt b/build.sbt index 6d6da9e..030b12f 100644 --- a/build.sbt +++ b/build.sbt @@ -7,41 +7,42 @@ lazy val core = project .settings( libraryDependencies ++= Seq( // please keep these sorted alphabetically - "ch.qos.logback" % "logback-classic" % "1.2.3", - "ch.qos.logback.contrib" % "logback-jackson" % "0.1.5", - "ch.qos.logback.contrib" % "logback-json-classic" % "0.1.5", - "com.aliyun.mns" % "aliyun-sdk-mns" % "1.1.8", - "com.aliyun.oss" % "aliyun-sdk-oss" % "2.8.2", - "com.amazonaws" % "aws-java-sdk-s3" % "1.11.342", - "com.beachape" %% "enumeratum" % "1.5.13", - "com.github.swagger-akka-http" %% "swagger-akka-http" % "0.14.1", - "com.google.cloud" % "google-cloud-pubsub" % "1.31.0", - "com.google.cloud" % "google-cloud-storage" % "1.31.0", - "com.googlecode.libphonenumber" % "libphonenumber" % "8.9.7", - "com.neovisionaries" % "nv-i18n" % "1.23", - "com.pauldijou" %% "jwt-core" % "0.16.0", - "com.softwaremill.sttp" %% "akka-http-backend" % "1.2.2", - "com.softwaremill.sttp" %% "core" % "1.2.2", - "com.typesafe" % "config" % "1.3.3", - "com.typesafe.akka" %% "akka-actor" % "2.5.14", - "com.typesafe.akka" %% "akka-http-core" % "10.1.4", - "com.typesafe.akka" %% "akka-http-spray-json" % "10.1.4", - "com.typesafe.akka" %% "akka-http-testkit" % "10.1.4", - "com.typesafe.akka" %% "akka-stream" % "2.5.14", - "com.typesafe.scala-logging" %% "scala-logging" % "3.9.0", - "com.typesafe.slick" %% "slick" % "3.2.3", - "eu.timepit" %% "refined" % "0.9.0", - "io.kamon" %% "kamon-akka-2.5" % "1.0.0", - "io.kamon" %% "kamon-core" % "1.1.3", - "io.kamon" %% "kamon-statsd" % "1.0.0", - "io.kamon" %% "kamon-system-metrics" % "1.0.0", - "javax.xml.bind" % "jaxb-api" % "2.2.8", - "org.mockito" % "mockito-core" % "1.9.5" % "test", - "org.scala-lang.modules" %% "scala-async" % "0.9.7", - "org.scalacheck" %% "scalacheck" % "1.14.0" % "test", - "org.scalatest" %% "scalatest" % "3.0.5" % "test", - "org.scalaz" %% "scalaz-core" % "7.2.24", - "xyz.driver" %% "tracing" % "0.1.2" + "ch.qos.logback" % "logback-classic" % "1.2.3", + "ch.qos.logback.contrib" % "logback-jackson" % "0.1.5", + "ch.qos.logback.contrib" % "logback-json-classic" % "0.1.5", + "com.aliyun.mns" % "aliyun-sdk-mns" % "1.1.8", + "com.aliyun.oss" % "aliyun-sdk-oss" % "2.8.2", + "com.amazonaws" % "aws-java-sdk-s3" % "1.11.342", + "com.beachape" %% "enumeratum" % "1.5.13", + "com.github.swagger-akka-http" %% "swagger-akka-http" % "1.0.0", + "com.google.cloud" % "google-cloud-pubsub" % "1.31.0", + "com.google.cloud" % "google-cloud-storage" % "1.31.0", + "com.googlecode.libphonenumber" % "libphonenumber" % "8.9.7", + "com.neovisionaries" % "nv-i18n" % "1.23", + "com.pauldijou" %% "jwt-core" % "0.16.0", + "com.softwaremill.sttp" %% "akka-http-backend" % "1.2.2", + "com.softwaremill.sttp" %% "core" % "1.2.2", + "com.typesafe" % "config" % "1.3.3", + "com.typesafe.akka" %% "akka-actor" % "2.5.14", + "com.typesafe.akka" %% "akka-http-core" % "10.1.4", + "com.typesafe.akka" %% "akka-http-spray-json" % "10.1.4", + "com.typesafe.akka" %% "akka-http-testkit" % "10.1.4", + "com.typesafe.akka" %% "akka-stream" % "2.5.14", + "com.typesafe.scala-logging" %% "scala-logging" % "3.9.0", + "com.typesafe.slick" %% "slick" % "3.2.3", + "eu.timepit" %% "refined" % "0.9.0", + "io.kamon" %% "kamon-akka-2.5" % "1.0.0", + "io.kamon" %% "kamon-core" % "1.1.3", + "io.kamon" %% "kamon-statsd" % "1.0.0", + "io.kamon" %% "kamon-system-metrics" % "1.0.0", + "javax.xml.bind" % "jaxb-api" % "2.2.8", + "org.mockito" % "mockito-core" % "1.9.5" % "test", + "org.scala-lang.modules" %% "scala-async" % "0.9.7", + "org.scalacheck" %% "scalacheck" % "1.14.0" % "test", + "org.scalatest" %% "scalatest" % "3.0.5" % "test", + "org.scalaz" %% "scalaz-core" % "7.2.24", + "xyz.driver" %% "spray-json-derivation" % "0.6.0", + "xyz.driver" %% "tracing" % "0.1.2" ), scalacOptions in (Compile, doc) ++= Seq( "-groups", // group similar methods together based on the @group annotation. diff --git a/src/main/resources/reference.conf b/src/main/resources/reference.conf index 795dbb4..608262a 100644 --- a/src/main/resources/reference.conf +++ b/src/main/resources/reference.conf @@ -11,31 +11,12 @@ application { baseUrl: "localhost:8080" environment: "local_testing" - cors { - allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"] - allowedOrigins: [ - { - scheme: http - hostSuffix: localhost - }, - { - scheme: https - hostSuffix: driver.xyz - }, - { - scheme: http - hostSuffix: dev.cndriver.xyz - }, - { - scheme: https - hostSuffix: driver.network - }, - { - scheme: https - hostSuffix: cndriver.xyz - } - ] - } + cors.allowedOrigins: [ + "localhost", + "driver.xyz", + "driver.network", + "cndriver.xyz" + ] } # Settings about the auto-generated REST API documentation. diff --git a/src/main/scala/xyz/driver/core/Platform.scala b/src/main/scala/xyz/driver/core/Platform.scala new file mode 100644 index 0000000..aa7e711 --- /dev/null +++ b/src/main/scala/xyz/driver/core/Platform.scala @@ -0,0 +1,35 @@ +package xyz.driver.core +import java.nio.file.{Files, Path, Paths} + +import com.google.auth.oauth2.ServiceAccountCredentials + +sealed trait Platform { + def isKubernetes: Boolean +} + +object Platform { + case class GoogleCloud(keyfile: Path, namespace: String) extends Platform { + def credentials: ServiceAccountCredentials = ServiceAccountCredentials.fromStream( + Files.newInputStream(keyfile) + ) + def project: String = credentials.getProjectId + override def isKubernetes = true + } + // case object AliCloud extends Platform + case object Dev extends Platform { + override def isKubernetes: Boolean = false + } + + lazy val fromEnv: Platform = { + def isGoogle = sys.env.get("GOOGLE_APPLICATION_CREDENTIALS").map { value => + val keyfile = Paths.get(value) + require(Files.isReadable(keyfile), s"Google credentials file $value is not readable.") + val namespace = sys.env.getOrElse("SERVICE_NAMESPACE", sys.error("Namespace not set")) + GoogleCloud(keyfile, namespace) + } + isGoogle.getOrElse(Dev) + } + + def current: Platform = fromEnv + +} diff --git a/src/main/scala/xyz/driver/core/Refresh.scala b/src/main/scala/xyz/driver/core/Refresh.scala index e66b22f..6db9c26 100644 --- a/src/main/scala/xyz/driver/core/Refresh.scala +++ b/src/main/scala/xyz/driver/core/Refresh.scala @@ -9,7 +9,8 @@ import scala.concurrent.duration.Duration /** A single-value asynchronous cache with TTL. * * Slightly adapted from - * [[https://github.com/twitter/util/blob/ae0ab09134414438af9dfaa88a4613cecbff4741/util-cache/src/main/scala/com/twitter/cache/Refresh.scala Twitter's "util" library]] + * [[https://github.com/twitter/util/blob/ae0ab09134414438af9dfaa88a4613cecbff4741/util-cache/src/main/scala/com/twitter/cache/Refresh.scala + * Twitter's "util" library]] * * Released under the Apache License 2.0. */ diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 9214cea..50e471c 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -11,7 +11,6 @@ import akka.http.scaladsl.{Http, HttpExt} import akka.stream.ActorMaterializer import com.typesafe.config.Config import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme import kamon.Kamon import kamon.statsd.StatsDReporter import kamon.system.SystemMetrics @@ -24,11 +23,13 @@ import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} import xyz.driver.tracing.TracingDirectives._ import xyz.driver.tracing._ +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext} import scala.util.Try import scalaz.Scalaz.stringInstance import scalaz.syntax.equal._ +import xyz.driver.core.rest.directives.CorsDirectives class DriverApp( appName: String, @@ -69,62 +70,15 @@ class DriverApp( } } - protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { - import scala.collection.JavaConverters._ - config - .getConfigList("application.cors.allowedOrigins") - .asScala - .map { c => - HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) - }(scala.collection.breakOut) - } - - protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { - import scala.collection.JavaConverters._ - config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) - } - - protected lazy val defaultCorsAllowedOrigin: Origin = { - Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) - } - - protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { - val allowedOrigin = - origin - .filter { requestOrigin => - allowedCorsDomainSuffixes.exists { allowedOriginSuffix => - requestOrigin.origins.exists(o => - o.scheme == allowedOriginSuffix.scheme && - o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) - } - } - .getOrElse(defaultCorsAllowedOrigin) - - `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) - } - - protected def respondWithAllCorsHeaders: Directive0 = { - respondWithCorsAllowedHeaders tflatMap { _ => - respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => - optionalHeaderValueByType[Origin](()) flatMap { origin => - respondWithHeader(corsAllowedOriginHeader(origin)) - } - } - } - } - private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - protected def defaultOptionsRoute: Route = options { - respondWithAllCorsHeaders { - complete("OK") - } - } - def appRoute: Route = { - val serviceTypes = modules.flatMap(_.routeTypes) - val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log) + val serviceTypes = modules.flatMap(_.routeTypes) + val serviceClasses = serviceTypes.map { tpe => + scala.reflect.runtime.currentMirror.runtimeClass(tpe.typeSymbol.asClass) + }.toSet + val swaggerService = new Swagger(baseUrl, scheme :: Nil, version, serviceClasses, config, log) val swaggerRoute = new DriverRoute { def log: Logger = self.log def route: Route = handleExceptions(ExceptionHandler(exceptionHandler)) { @@ -139,7 +93,6 @@ class DriverApp( val combinedRoute = Route.seal( modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ swaggerRoute.route ~ defaultOptionsRoute) - (extractHost & extractClientIP & trace(tracer) & handleRejections(authenticationRejectionHandler)) { case (origin, ip) => ctx => @@ -173,7 +126,11 @@ class DriverApp( r } - respondWithAllCorsHeaders(logResponses(combinedRoute))(contextWithTrackingId) + val innerRoute = CorsDirectives.cors( + config.getStringList("application.cors.allowedOrigins").asScala.toSet, + AllowedHeaders + )(logResponses(combinedRoute)) + innerRoute(contextWithTrackingId) } } diff --git a/src/main/scala/xyz/driver/core/discovery/CanDiscoverService.scala b/src/main/scala/xyz/driver/core/discovery/CanDiscoverService.scala new file mode 100644 index 0000000..8711332 --- /dev/null +++ b/src/main/scala/xyz/driver/core/discovery/CanDiscoverService.scala @@ -0,0 +1,11 @@ +package xyz.driver.core +package discovery + +import scala.annotation.implicitNotFound + +@implicitNotFound( + "Don't know how to communicate with service ${Service}. Make sure an implicit CanDiscoverService is" + + "available. A good place to put one is in the service's companion object.") +trait CanDiscoverService[Service] { + def discover(platform: Platform): Service +} diff --git a/src/main/scala/xyz/driver/core/init/AkkaBootable.scala b/src/main/scala/xyz/driver/core/init/AkkaBootable.scala new file mode 100644 index 0000000..6a28fe8 --- /dev/null +++ b/src/main/scala/xyz/driver/core/init/AkkaBootable.scala @@ -0,0 +1,188 @@ +package xyz.driver.core +package init + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.{RequestContext, Route} +import akka.stream.scaladsl.Source +import akka.stream.{ActorMaterializer, Materializer} +import akka.util.ByteString +import com.softwaremill.sttp.SttpBackend +import com.softwaremill.sttp.akkahttp.AkkaHttpBackend +import com.typesafe.config.Config +import kamon.Kamon +import kamon.statsd.StatsDReporter +import kamon.system.SystemMetrics +import xyz.driver.core.reporting.{NoTraceReporter, Reporter, ScalaLoggerLike, SpanContext} +import xyz.driver.core.rest.HttpRestServiceTransport + +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} + +/** Provides standard scaffolding for applications that use Akka HTTP. + * + * Among the features provided are: + * + * - execution contexts of various kinds + * - basic JVM metrics collection via Kamon + * - startup and shutdown hooks + * + * This trait provides a minimal, runnable application. It is designed to be extended by various mixins (see + * Known Subclasses) in this package. + * + * By implementing a "main" method, mixing this trait into a singleton object will result in a runnable + * application. + * I.e. + * {{{ + * object Main extends AkkaBootable // this is a runnable application + * }}} + * In case this trait isn't mixed into a top-level singleton object, the [[AkkaBootable#main main]] method should + * be called explicitly, in order to initialize and start this application. + * I.e. + * {{{ + * object Main { + * val bootable = new AkkaBootable {} + * def main(args: Array[String]): Unit = { + * bootable.main(args) + * } + * } + * }}} + * + * @groupname config Configuration + * @groupname contexts Contexts + * @groupname utilities Utilities + * @groupname hooks Overrideable Hooks + */ +trait AkkaBootable { + + /** The application's name. This value is extracted from the build configuration. + * @group config + */ + def name: String = BuildInfoReflection.name + + /** The application's version (or git sha). This value is extracted from the build configuration. + * @group config + */ + def version: Option[String] = BuildInfoReflection.version + + /** TCP port that this application will listen on. + * @group config + */ + def port: Int = 8080 + + // contexts + /** General-purpose actor system for this application. + * @group contexts + */ + implicit lazy val system: ActorSystem = ActorSystem("app") + + /** General-purpose stream materializer for this application. + * @group contexts + */ + implicit lazy val materializer: Materializer = ActorMaterializer() + + /** General-purpose execution context for this application. + * + * Note that no thread-blocking tasks should be submitted to this context. In cases that do require blocking, + * a custom execution context should be defined and used. See + * [[https://doc.akka.io/docs/akka-http/current/handling-blocking-operations-in-akka-http-routes.html this guide]] + * on how to configure custom execution contexts in Akka. + * + * @group contexts + */ + implicit lazy val executionContext: ExecutionContext = system.dispatcher + + /** Default HTTP client, backed by this application's actor system. + * @group contexts + */ + implicit lazy val httpClient: SttpBackend[Future, Source[ByteString, Any]] = AkkaHttpBackend.usingActorSystem(system) + + /** Old HTTP client system. Prefer using an sttp backend for new service clients. + * @group contexts + * @see httpClient + */ + implicit lazy val clientTransport: HttpRestServiceTransport = new HttpRestServiceTransport( + applicationName = Name(name), + applicationVersion = version.getOrElse(""), + actorSystem = system, + executionContext = executionContext, + log = reporter + ) + + // utilities + /** Default reporter instance. + * + * Note that this is currently defined to be a ScalaLoggerLike, so that it can be implicitly converted to a + * [[com.typesafe.scalalogging.Logger]] when necessary. This conversion is provided to ensure backwards + * compatibility with code that requires such a logger. Warning: using a logger instead of a reporter will + * not include tracing information in any messages! + * + * @group utilities + */ + def reporter: Reporter with ScalaLoggerLike = new NoTraceReporter(ScalaLoggerLike.defaultScalaLogger(json = false)) + + /** Top-level application configuration. + * + * TODO: should we expose some config wrapper rather than the typesafe config library? + * (Author's note: I'm a fan of TOML since it's so simple. There's already an implementation for Scala + * [[https://github.com/jvican/stoml]].) + * + * @group utilities + */ + def config: Config = system.settings.config + + /** Overridable startup hook. + * + * Invoked by [[main]] during application startup. + * + * @group hooks + */ + def startup(): Unit = () + + /** Overridable shutdown hook. + * + * Invoked on an arbitrary thread when a shutdown signal is caught. + * + * @group hooks + */ + def shutdown(): Unit = () + + /** Overridable HTTP route. + * + * Any services that present an HTTP interface should implement this method. + * + * @group hooks + * @see [[HttpApi]] + */ + def route: Route = (ctx: RequestContext) => ctx.complete(StatusCodes.NotFound) + + private def syslog(message: String)(implicit ctx: SpanContext) = reporter.info(s"application: " + message) + + /** This application's entry point. */ + def main(args: Array[String]): Unit = { + implicit val ctx = SpanContext.fresh() + syslog("initializing metrics collection") + Kamon.addReporter(new StatsDReporter()) + SystemMetrics.startCollecting() + + system.registerOnTermination { + syslog("running shutdown hooks") + shutdown() + syslog("bye!") + } + + syslog("running startup hooks") + startup() + + syslog("binding to network interface") + val binding = Await.result( + Http().bindAndHandle(route, "::", port), + 2.seconds + ) + syslog(s"listening to ${binding.localAddress}") + + syslog("startup complete") + } + +} diff --git a/src/main/scala/xyz/driver/core/init/BuildInfoReflection.scala b/src/main/scala/xyz/driver/core/init/BuildInfoReflection.scala new file mode 100644 index 0000000..0e53085 --- /dev/null +++ b/src/main/scala/xyz/driver/core/init/BuildInfoReflection.scala @@ -0,0 +1,37 @@ +package xyz.driver.core +package init + +import scala.reflect.runtime +import scala.util.Try +import scala.util.control.NonFatal + +/** Utility object to retrieve fields from static build configuration objects. */ +private[init] object BuildInfoReflection { + + final val BuildInfoName = "xyz.driver.BuildInfo" + + lazy val name: String = get[String]("name") + lazy val version: Option[String] = find[String]("version") + + /** Lookup a given field in the build configuration. This field is required to exist. */ + private def get[A](fieldName: String): A = + try { + val mirror = runtime.currentMirror + val module = mirror.staticModule(BuildInfoName) + val instance = mirror.reflectModule(module).instance + val accessor = module.info.decl(mirror.universe.TermName(fieldName)).asMethod + mirror.reflect(instance).reflectMethod(accessor).apply().asInstanceOf[A] + } catch { + case NonFatal(err) => + throw new RuntimeException( + s"Cannot find field name '$fieldName' in $BuildInfoName. Please define (or generate) a singleton " + + s"object with that field. Alternatively, in order to avoid runtime reflection, you may override the " + + s"caller with a static value.", + err + ) + } + + /** Try finding a given field in the build configuration. If the field does not exist, None is returned. */ + private def find[A](fieldName: String): Option[A] = Try { get[A](fieldName) }.toOption + +} diff --git a/src/main/scala/xyz/driver/core/init/CloudServices.scala b/src/main/scala/xyz/driver/core/init/CloudServices.scala new file mode 100644 index 0000000..9f4ab5c --- /dev/null +++ b/src/main/scala/xyz/driver/core/init/CloudServices.scala @@ -0,0 +1,88 @@ +package xyz.driver.core +package init + +import java.nio.file.Paths + +import xyz.driver.core.discovery.CanDiscoverService +import xyz.driver.core.messaging.{GoogleBus, QueueBus, StreamBus} +import xyz.driver.core.reporting._ +import xyz.driver.core.reporting.ScalaLoggerLike.defaultScalaLogger +import xyz.driver.core.storage.{BlobStorage, FileSystemBlobStorage, GcsBlobStorage} + +import scala.concurrent.ExecutionContext + +/** Mixin trait that provides essential cloud utilities. */ +trait CloudServices extends AkkaBootable { self => + + /** The platform that this application is running on. + * @group config + */ + def platform: Platform = Platform.current + + /** Service discovery for the current platform. + * + * Define a service trait and companion object: + * {{{ + * trait MyService { + * def call(): Int + * } + * object MyService { + * implicit val isDiscoverable = new xyz.driver.core.discovery.CanDiscoverService[MyService] { + * def discover(p: xyz.driver.core.Platform): MyService = new MyService { + * def call() = 42 + * } + * } + * } + * }}} + * + * Then discover and use it: + * {{{ + * discover[MyService].call() + * }}} + * + * @group utilities + */ + def discover[A](implicit cds: CanDiscoverService[A]): A = cds.discover(platform) + + /* TODO: this reporter uses the platform to determine if JSON logging should be enabled. + * Since the default logger uses slf4j, its settings must be specified before a logger + * is first accessed. This in turn leads to somewhat convoluted code, + * since we can't log when the platform being is determined. + * A potential fix would be to make the log format independent of the platform, and always log + * as JSON for example. + */ + override lazy val reporter: Reporter with ScalaLoggerLike = { + Console.println("determining platform") // scalastyle:ignore + val r = platform match { + case p @ Platform.GoogleCloud(_, _) => + new GoogleReporter(p.credentials, p.namespace, defaultScalaLogger(true)) + case Platform.Dev => + new NoTraceReporter(defaultScalaLogger(false)) + } + r.info(s"application started on platform '${platform}'")(SpanContext.fresh()) + r + } + + /** Object storage. + * @group utilities + */ + def storage(bucketId: String): BlobStorage = + platform match { + case Platform.GoogleCloud(keyfile, _) => + GcsBlobStorage.fromKeyfile(keyfile, bucketId) + case Platform.Dev => + new FileSystemBlobStorage(Paths.get(s".data-$bucketId")) + } + + /** Message bus. + * @group utilities + */ + def messageBus: StreamBus = platform match { + case Platform.GoogleCloud(keyfile, namespace) => GoogleBus.fromKeyfile(keyfile, namespace) + case Platform.Dev => + new QueueBus()(self.system) with StreamBus { + override def executionContext: ExecutionContext = self.executionContext + } + } + +} diff --git a/src/main/scala/xyz/driver/core/init/HttpApi.scala b/src/main/scala/xyz/driver/core/init/HttpApi.scala new file mode 100644 index 0000000..6ea3d51 --- /dev/null +++ b/src/main/scala/xyz/driver/core/init/HttpApi.scala @@ -0,0 +1,89 @@ +package xyz.driver.core +package init + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.server.{RequestContext, Route, RouteConcatenation} +import spray.json.DefaultJsonProtocol._ +import spray.json._ +import xyz.driver.core.rest.Swagger +import xyz.driver.core.rest.directives.Directives +import akka.http.scaladsl.model.headers._ +import xyz.driver.core.reporting.Reporter.CausalRelation +import xyz.driver.core.reporting.SpanContext +import xyz.driver.core.rest.headers.Traceparent + +import scala.collection.JavaConverters._ + +/** Mixin trait that provides some well-known HTTP endpoints, diagnostic header injection and forwarding, + * and exposes an application-specific route that must be implemented by services. + * @see ProtobufApi + */ +trait HttpApi extends CloudServices with Directives with SprayJsonSupport { self => + + /** Route that handles the application's business logic. + * @group hooks + */ + def applicationRoute: Route + + /** Classes with Swagger annotations. + * @group hooks + */ + def swaggerRouteClasses: Set[Class[_]] + + private val healthRoute = path("health") { + complete(Map("status" -> "good").toJson) + } + + private val versionRoute = path("version") { + complete(Map("name" -> self.name.toJson, "version" -> self.version.toJson).toJson) + } + + private lazy val swaggerRoute = { + val generator = new Swagger( + "", + "https" :: "http" :: Nil, + self.version.getOrElse(""), + swaggerRouteClasses, + config, + reporter + ) + generator.routes ~ generator.swaggerUI + } + + private def cors(inner: Route): Route = + cors( + config.getStringList("application.cors.allowedOrigins").asScala.toSet, + xyz.driver.core.rest.AllowedHeaders + )(inner) + + private def traced(inner: Route): Route = (ctx: RequestContext) => { + val tags = Map( + "service_name" -> name, + "service_version" -> version.getOrElse(""), + "http_user_agent" -> ctx.request.header[`User-Agent`].map(_.value).getOrElse(""), + "http_uri" -> ctx.request.uri.toString, + "http_path" -> ctx.request.uri.path.toString + ) + val parent = ctx.request.header[Traceparent].map { p => + SpanContext(p.traceId, p.spanId) -> CausalRelation.Child + } + reporter.traceWithOptionalParentAsync("handle_service_request", tags, parent) { sctx => + val header = Traceparent(sctx.traceId, sctx.spanId) + val withHeader = ctx.withRequest(ctx.request.withHeaders(header)) + inner(withHeader) + } + } + + /** Extended route. */ + override lazy val route: Route = traced( + cors( + RouteConcatenation.concat( + healthRoute, + versionRoute, + swaggerRoute, + applicationRoute + ) + ) + ) + +} diff --git a/src/main/scala/xyz/driver/core/init/ProtobufApi.scala b/src/main/scala/xyz/driver/core/init/ProtobufApi.scala new file mode 100644 index 0000000..284ac67 --- /dev/null +++ b/src/main/scala/xyz/driver/core/init/ProtobufApi.scala @@ -0,0 +1,7 @@ +package xyz.driver.core +package init + +/** Mixin trait for services that implement an API based on Protocol Buffers and gRPC. + * TODO: implement + */ +trait ProtobufApi extends AkkaBootable diff --git a/src/main/scala/xyz/driver/core/init/SimpleHttpApp.scala b/src/main/scala/xyz/driver/core/init/SimpleHttpApp.scala new file mode 100644 index 0000000..61ca363 --- /dev/null +++ b/src/main/scala/xyz/driver/core/init/SimpleHttpApp.scala @@ -0,0 +1,4 @@ +package xyz.driver.core +package init + +trait SimpleHttpApp extends AkkaBootable with HttpApi with CloudServices diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala index d9319e9..4daf127 100644 --- a/src/main/scala/xyz/driver/core/json.scala +++ b/src/main/scala/xyz/driver/core/json.scala @@ -5,10 +5,6 @@ import java.time.format.DateTimeFormatter import java.time.{Instant, LocalDate} import java.util.{TimeZone, UUID} -import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} -import akka.http.scaladsl.model.Uri.Path -import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} -import akka.http.scaladsl.server._ import akka.http.scaladsl.unmarshalling.Unmarshaller import com.neovisionaries.i18n.{CountryCode, CurrencyCode} import enumeratum._ @@ -19,6 +15,7 @@ import spray.json._ import xyz.driver.core.auth.AuthCredentials import xyz.driver.core.date.{Date, DayOfWeek, Month} import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.rest.directives.{PathMatchers, Unmarshallers} import xyz.driver.core.rest.errors._ import xyz.driver.core.time.{Time, TimeOfDay} @@ -27,25 +24,9 @@ import scala.reflect.{ClassTag, classTag} import scala.util.Try import scala.util.control.NonFatal -object json { +object json extends PathMatchers with Unmarshallers { import DefaultJsonProtocol._ - private def UuidInPath[T]: PathMatcher1[Id[T]] = - PathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) - - def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) - case _ => Unmatched - } - } - - implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = - Unmarshaller.firstOf( - Unmarshaller.strict((JsString(_: String)) andThen reader.read), - stringToValueUnmarshaller[T] - ) - implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { def write(id: Id[T]) = JsString(id.value) @@ -67,13 +48,6 @@ object json { override def read(json: JsValue): F @@ T = transformReadValue(underlying.read(json)) } - def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) - case _ => Unmatched - } - } - implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { def write(name: Name[T]) = JsString(name.value) @@ -83,26 +57,6 @@ object json { } } - def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) - - private def timestampInPath: PathMatcher1[Long] = - PathMatcher("""[+-]?\d*""".r) flatMap { string => - try Some(string.toLong) - catch { case _: IllegalArgumentException => None } - } - - def InstantInPath: PathMatcher1[Instant] = - new PathMatcher1[Instant] { - def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { - case Path.Segment(head, tail) => - try Matched(tail, Tuple1(Instant.parse(head))) - catch { - case NonFatal(_) => Unmatched - } - case _ => Unmatched - } - } | timestampInPath.map(Instant.ofEpochMilli) - implicit val timeFormat: RootJsonFormat[Time] = new RootJsonFormat[Time] { def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) @@ -194,14 +148,6 @@ object json { } } - def RevisionInPath[T]: PathMatcher1[Revision[T]] = - PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => - Some(Revision[T](string)) - } - - implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = - Unmarshaller.strict[String, Revision[T]](Revision[T]) - implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { def write(revision: Revision[T]) = JsString(revision.id.toString) @@ -413,17 +359,6 @@ object json { } } - def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - refineV[NonEmpty](segment) match { - case Left(_) => Unmatched - case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) - } - case _ => Unmatched - } - } - implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = new RootJsonFormat[NonEmptyName[T]] { def write(name: NonEmptyName[T]) = JsString(name.value.value) @@ -452,15 +387,4 @@ object json { case "DatabaseException" => jsonFormat(DatabaseException, "message") } - val jsValueToStringMarshaller: Marshaller[JsValue, String] = - Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) - - def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = - jsValueToStringMarshaller.compose[T](jsonFormat.write) - - val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = - Unmarshaller.strict[String, JsValue](value => value.parseJson) - - def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = - stringToJsValueUnmarshaller.map[T](jsonFormat.read) } diff --git a/src/main/scala/xyz/driver/core/logging/package.scala b/src/main/scala/xyz/driver/core/logging/package.scala index 2b6fc11..a4d01fa 100644 --- a/src/main/scala/xyz/driver/core/logging/package.scala +++ b/src/main/scala/xyz/driver/core/logging/package.scala @@ -1,7 +1,8 @@ package xyz.driver.core +import com.typesafe.scalalogging.Logger import org.slf4j.helpers.NOPLogger package object logging { - val NoLogger = com.typesafe.scalalogging.Logger(NOPLogger.NOP_LOGGER) + val NoLogger: Logger = Logger.apply(NOPLogger.NOP_LOGGER) } diff --git a/src/main/scala/xyz/driver/core/messaging/GoogleBus.scala b/src/main/scala/xyz/driver/core/messaging/GoogleBus.scala index 9895708..b296c50 100644 --- a/src/main/scala/xyz/driver/core/messaging/GoogleBus.scala +++ b/src/main/scala/xyz/driver/core/messaging/GoogleBus.scala @@ -2,7 +2,7 @@ package xyz.driver.core package messaging import java.nio.ByteBuffer -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} import java.security.Signature import java.time.Instant import java.util @@ -245,14 +245,23 @@ object GoogleBus { implicit val subscrptionPullFormat: RootJsonFormat[SubscriptionPull] = jsonFormat1(SubscriptionPull) } + def fromKeyfile(keyfile: Path, namespace: String)( + implicit executionContext: ExecutionContext, + backend: SttpBackend[Future, _]): GoogleBus = { + val creds = ServiceAccountCredentials.fromStream(Files.newInputStream(keyfile)) + new GoogleBus(creds, namespace) + } + + @deprecated( + "Reading from the environment adds opaque dependencies and hance leads to extra complexity. Use fromKeyfile instead.", + "driver-core 1.12.2") def fromEnv(implicit executionContext: ExecutionContext, backend: SttpBackend[Future, _]): GoogleBus = { def env(key: String) = { require(sys.env.contains(key), s"Environment variable $key is not set.") sys.env(key) } val keyfile = Paths.get(env("GOOGLE_APPLICATION_CREDENTIALS")) - val creds = ServiceAccountCredentials.fromStream(Files.newInputStream(keyfile)) - new GoogleBus(creds, env("SERVICE_NAMESPACE")) + fromKeyfile(keyfile, env("SERVICE_NAMESPACE")) } } diff --git a/src/main/scala/xyz/driver/core/reporting/GoogleReporter.scala b/src/main/scala/xyz/driver/core/reporting/GoogleReporter.scala new file mode 100644 index 0000000..40cb1e5 --- /dev/null +++ b/src/main/scala/xyz/driver/core/reporting/GoogleReporter.scala @@ -0,0 +1,193 @@ +package xyz.driver.core +package reporting +import java.security.Signature +import java.time.Instant +import java.util + +import akka.NotUsed +import akka.stream.scaladsl.{Flow, RestartSink, Sink, Source, SourceQueueWithComplete} +import akka.stream.{Materializer, OverflowStrategy} +import com.google.auth.oauth2.ServiceAccountCredentials +import com.softwaremill.sttp._ +import com.typesafe.scalalogging.Logger +import spray.json.DerivedJsonProtocol._ +import spray.json._ +import xyz.driver.core.reporting.Reporter.CausalRelation + +import scala.async.Async._ +import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} +import scala.util.Random +import scala.util.control.NonFatal + +/** A reporter that collects traces and submits them to + * [[https://cloud.google.com/trace/docs/reference/v2/rest/ Google's Stackdriver Trace API]]. + */ +class GoogleReporter( + credentials: ServiceAccountCredentials, + namespace: String, + val logger: Logger, + buffer: Int = GoogleReporter.DefaultBufferSize, + interval: FiniteDuration = GoogleReporter.DefaultInterval)( + implicit client: SttpBackend[Future, _], + mat: Materializer, + ec: ExecutionContext +) extends Reporter with ScalaLoggerLike { + import GoogleReporter._ + + private val getToken: () => Future[String] = Refresh.every(55.minutes) { + def jwt = { + val now = Instant.now().getEpochSecond + val base64 = util.Base64.getEncoder + val header = base64.encodeToString("""{"alg":"RS256","typ":"JWT"}""".getBytes("utf-8")) + val body = base64.encodeToString( + s"""|{ + | "iss": "${credentials.getClientEmail}", + | "scope": "https://www.googleapis.com/auth/trace.append", + | "aud": "https://www.googleapis.com/oauth2/v4/token", + | "exp": ${now + 60.minutes.toSeconds}, + | "iat": $now + |}""".stripMargin.getBytes("utf-8") + ) + val signer = Signature.getInstance("SHA256withRSA") + signer.initSign(credentials.getPrivateKey) + signer.update(s"$header.$body".getBytes("utf-8")) + val signature = base64.encodeToString(signer.sign()) + s"$header.$body.$signature" + } + sttp + .post(uri"https://www.googleapis.com/oauth2/v4/token") + .body( + "grant_type" -> "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion" -> jwt + ) + .mapResponse(s => s.parseJson.asJsObject.fields("access_token").convertTo[String]) + .send() + .map(_.unsafeBody) + } + + private val sendToGoogle: Sink[Span, NotUsed] = RestartSink.withBackoff( + minBackoff = 3.seconds, + maxBackoff = 30.seconds, + randomFactor = 0.2 // adds 20% "noise" to vary the intervals slightly + ) { () => + Flow[Span] + .groupedWithin(buffer, interval) + .mapAsync(1) { spans => + async { + val token = await(getToken()) + val res = await( + sttp + .post(uri"https://cloudtrace.googleapis.com/v2/projects/${credentials.getProjectId}/traces:batchWrite") + .auth + .bearer(token) + .body( + Spans(spans).toJson.compactPrint + ) + .send() + .map(_.unsafeBody)) + res + } + } + .recover { + case NonFatal(e) => + System.err.println(s"Error submitting trace spans: $e") // scalastyle: ignore + throw e + } + .to(Sink.ignore) + } + private val queue: SourceQueueWithComplete[Span] = Source + .queue[Span](buffer, OverflowStrategy.dropHead) + .to(sendToGoogle) + .run() + + private def submit(span: Span): Unit = queue.offer(span).failed.map { e => + System.err.println(s"Error adding span to submission queue: $e") + } + + private def startSpan( + traceId: String, + spanId: String, + parentSpanId: Option[String], + displayName: String, + attributes: Map[String, String]) = Span( + s"project/${credentials.getProjectId}/traces/$traceId/spans/$spanId", + spanId, + parentSpanId, + TruncatableString(displayName), + Instant.now(), + Instant.now(), + Attributes(attributes ++ Map("namespace" -> namespace)) + ) + + def traceWithOptionalParent[A]( + operationName: String, + tags: Map[String, String], + parent: Option[(SpanContext, CausalRelation)])(operation: SpanContext => A): A = { + val child = parent match { + case Some((p, _)) => SpanContext(p.traceId, f"${Random.nextLong()}%02x") + case None => SpanContext.fresh() + } + val span = startSpan(child.traceId, child.spanId, parent.map(_._1.spanId), operationName, tags) + val result = operation(child) + span.endTime = Instant.now() + submit(span) + result + } + + def traceWithOptionalParentAsync[A]( + operationName: String, + tags: Map[String, String], + parent: Option[(SpanContext, CausalRelation)])(operation: SpanContext => Future[A]): Future[A] = { + val child = parent match { + case Some((p, _)) => SpanContext(p.traceId, f"${Random.nextLong()}%02x") + case None => SpanContext.fresh() + } + val span = startSpan(child.traceId, child.spanId, parent.map(_._1.spanId), operationName, tags) + val result = operation(child) + result.onComplete { _ => + span.endTime = Instant.now() + submit(span) + } + result + } +} + +object GoogleReporter { + + val DefaultBufferSize: Int = 10000 + val DefaultInterval: FiniteDuration = 5.seconds + + private case class Attributes(attributeMap: Map[String, String]) + private case class TruncatableString(value: String) + private case class Span( + name: String, + spanId: String, + parentSpanId: Option[String], + displayName: TruncatableString, + startTime: Instant, + var endTime: Instant, + attributes: Attributes + ) + + private case class Spans(spans: Seq[Span]) + + private implicit val instantFormat: RootJsonFormat[Instant] = new RootJsonFormat[Instant] { + override def write(obj: Instant): JsValue = obj.toString.toJson + override def read(json: JsValue): Instant = Instant.parse(json.convertTo[String]) + } + + private implicit val mapFormat = new RootJsonFormat[Map[String, String]] { + override def read(json: JsValue): Map[String, String] = sys.error("unimplemented") + override def write(obj: Map[String, String]): JsValue = { + val withValueObjects = obj.mapValues(value => JsObject("stringValue" -> JsObject("value" -> value.toJson))) + JsObject(withValueObjects) + } + } + + private implicit val attributeFormat: RootJsonFormat[Attributes] = jsonFormat1(Attributes) + private implicit val truncatableStringFormat: RootJsonFormat[TruncatableString] = jsonFormat1(TruncatableString) + private implicit val spanFormat: RootJsonFormat[Span] = jsonFormat7(Span) + private implicit val spansFormat: RootJsonFormat[Spans] = jsonFormat1(Spans) + +} diff --git a/src/main/scala/xyz/driver/core/reporting/NoTraceReporter.scala b/src/main/scala/xyz/driver/core/reporting/NoTraceReporter.scala new file mode 100644 index 0000000..9179f42 --- /dev/null +++ b/src/main/scala/xyz/driver/core/reporting/NoTraceReporter.scala @@ -0,0 +1,18 @@ +package xyz.driver.core +package reporting + +import com.typesafe.scalalogging.Logger + +import scala.concurrent.Future + +class NoTraceReporter(val logger: Logger) extends Reporter with ScalaLoggerLike { + override def traceWithOptionalParent[A]( + name: String, + tags: Map[String, String], + parent: Option[(SpanContext, Reporter.CausalRelation)])(op: SpanContext => A): A = op(SpanContext.fresh()) + override def traceWithOptionalParentAsync[A]( + name: String, + tags: Map[String, String], + parent: Option[(SpanContext, Reporter.CausalRelation)])(op: SpanContext => Future[A]): Future[A] = + op(SpanContext.fresh()) +} diff --git a/src/main/scala/xyz/driver/core/reporting/Reporter.scala b/src/main/scala/xyz/driver/core/reporting/Reporter.scala new file mode 100644 index 0000000..2425044 --- /dev/null +++ b/src/main/scala/xyz/driver/core/reporting/Reporter.scala @@ -0,0 +1,164 @@ +package xyz.driver.core.reporting + +import com.typesafe.scalalogging.Logger +import org.slf4j.helpers.NOPLogger +import xyz.driver.core.reporting.Reporter.CausalRelation + +import scala.concurrent.Future + +/** Context-aware diagnostic utility for distributed systems, combining logging and tracing. + * + * Diagnostic messages (i.e. logs) are a vital tool for monitoring applications. Tying such messages to an + * execution context, such as a stack trace, simplifies debugging greatly by giving insight to the chains of events + * that led to a particular message. In synchronous systems, execution contexts can easily be determined by an + * external observer, and, as such, do not need to be propagated explicitly to sub-components (e.g. a stack trace on + * the JVM shows all relevant information). In asynchronous systems and especially distributed systems however, + * execution contexts are not easily determined by an external observer and hence need to be explcictly passed across + * service boundaries. + * + * This reporter provides tracing and logging utilities that explicitly require references to execution contexts + * (called [[SpanContext]]s here) intended to be passed across service boundaries. It embraces Scala's + * implicit-parameter-as-a-context paradigm. + * + * Tracing is intended to be compatible with the + * [[https://github.com/opentracing/specification/blob/master/specification.md OpenTrace specification]], and hence its + * guidelines on naming and tagging apply to methods provided by this Reporter as well. + * + * Usage example: + * {{{ + * val reporter: Reporter = ??? + * object Repo { + * def getUserQuery(userId: String)(implicit ctx: SpanContext) = reporter.trace("query"){ implicit ctx => + * reporter.debug("Running query") + * // run query + * } + * } + * object Service { + * def getUser(userId: String)(implicit ctx: SpanContext) = reporter.trace("get_user"){ implicit ctx => + * reporter.debug("Getting user") + * Repo.getUserQuery(userId) + * } + * } + * reporter.traceRoot("static_get", Map("user" -> "john")) { implicit ctx => + * Service.getUser("john") + * } + * }}} + * + * Note that computing traces may be a more expensive operation than traditional logging frameworks provide (in terms + * of memory and processing). It should be used in interesting and actionable code paths. + * + * @define rootWarning Note: the idea of the reporting framework is to pass along references to traces as + * implicit parameters. This method should only be used for top-level traces when no parent + * traces are available. + */ +trait Reporter { + + def traceWithOptionalParent[A]( + name: String, + tags: Map[String, String], + parent: Option[(SpanContext, CausalRelation)])(op: SpanContext => A): A + def traceWithOptionalParentAsync[A]( + name: String, + tags: Map[String, String], + parent: Option[(SpanContext, CausalRelation)])(op: SpanContext => Future[A]): Future[A] + + /** Trace the execution of an operation, if no parent trace is available. + * + * $rootWarning + */ + def traceRoot[A](name: String, tags: Map[String, String])(op: SpanContext => A): A = + traceWithOptionalParent( + name, + tags, + None + )(op) + + /** Trace the execution of an asynchronous operation, if no parent trace is available. + * + * $rootWarning + * + * @see traceRoot + */ + def traceRootAsync[A](name: String, tags: Map[String, String])(op: SpanContext => Future[A]): Future[A] = + traceWithOptionalParentAsync( + name, + tags, + None + )(op) + + /** Trace the execution of an operation, in relation to a parent context. + * + * @param name The name of the operation. Note that this name should not be too specific. According to the + * OpenTrace RFC: "An operation name, a human-readable string which concisely represents the work done + * by the Span (for example, an RPC method name, a function name, or the name of a subtask or stage + * within a larger computation). The operation name should be the most general string that identifies a + * (statistically) interesting class of Span instances. That is, `"get_user"` is better than + * `"get_user/314159"`". + * @param tags Attributes associated with the traced event. Following the above example, if `"get_user"` is an + * operation name, a good tag would be `("account_id" -> 314159)`. + * @param relation Relation of the operation to its parent context. + * @param op The operation to be traced. The trace will complete once the operation returns. + * @param ctx Context of the parent trace. + * @tparam A Return type of the operation. + * @return The value of the child operation. + */ + def trace[A](name: String, tags: Map[String, String], relation: CausalRelation = CausalRelation.Child)( + op: /* implicit (gotta wait for Scala 3) */ SpanContext => A)(implicit ctx: SpanContext): A = + traceWithOptionalParent( + name, + tags, + Some(ctx -> relation) + )(op) + + /** Trace the operation of an asynchronous operation. + * + * Contrary to the synchronous version of this method, a trace is completed once the child operation completes + * (rather than returns). + * + * @see trace + */ + def traceAsync[A](name: String, tags: Map[String, String], relation: CausalRelation = CausalRelation.Child)( + op: /* implicit (gotta wait for Scala 3) */ SpanContext => Future[A])(implicit ctx: SpanContext): Future[A] = + traceWithOptionalParentAsync( + name, + tags, + Some(ctx -> relation) + )(op) + + /** Log a debug message. */ + def debug(message: String)(implicit ctx: SpanContext): Unit + + /** Log an informational message. */ + def info(message: String)(implicit ctx: SpanContext): Unit + + /** Log a warning message. */ + def warn(message: String)(implicit ctx: SpanContext): Unit + + /** Log a error message. */ + def error(message: String)(implicit ctx: SpanContext): Unit + + /** Log a error message with an associated throwable that caused the error condition. */ + def error(message: String, reason: Throwable)(implicit ctx: SpanContext): Unit + +} + +object Reporter { + + val NoReporter: Reporter = new NoTraceReporter(Logger.apply(NOPLogger.NOP_LOGGER)) + + /** A relation in cause. + * + * Corresponds to + * [[https://github.com/opentracing/specification/blob/master/specification.md#references-between-spans OpenTrace references between spans]] + */ + sealed trait CausalRelation + object CausalRelation { + + /** One event is the child of another. The parent completes once the child is complete. */ + case object Child extends CausalRelation + + /** One event follows from another, not necessarily being the parent. */ + case object Follows extends CausalRelation + } + +} diff --git a/src/main/scala/xyz/driver/core/reporting/ScalaLoggerLike.scala b/src/main/scala/xyz/driver/core/reporting/ScalaLoggerLike.scala new file mode 100644 index 0000000..c1131fb --- /dev/null +++ b/src/main/scala/xyz/driver/core/reporting/ScalaLoggerLike.scala @@ -0,0 +1,31 @@ +package xyz.driver.core.reporting +import com.typesafe.scalalogging.Logger + +trait ScalaLoggerLike extends Reporter { + + def logger: Logger + + override def debug(message: String)(implicit ctx: SpanContext): Unit = logger.debug(message) + override def info(message: String)(implicit ctx: SpanContext): Unit = logger.info(message) + override def warn(message: String)(implicit ctx: SpanContext): Unit = logger.warn(message) + override def error(message: String)(implicit ctx: SpanContext): Unit = logger.error(message) + override def error(message: String, reason: Throwable)(implicit ctx: SpanContext): Unit = + logger.error(message, reason) + +} + +object ScalaLoggerLike { + import scala.language.implicitConversions + + def defaultScalaLogger(json: Boolean = false): Logger = { + if (json) { + System.setProperty("logback.configurationFile", "deployed-logback.xml") + } else { + System.setProperty("logback.configurationFile", "logback.xml") + } + Logger.apply("application") + } + + implicit def toScalaLogger(reporter: ScalaLoggerLike): Logger = reporter.logger + +} diff --git a/src/main/scala/xyz/driver/core/reporting/SpanContext.scala b/src/main/scala/xyz/driver/core/reporting/SpanContext.scala new file mode 100644 index 0000000..58ab973 --- /dev/null +++ b/src/main/scala/xyz/driver/core/reporting/SpanContext.scala @@ -0,0 +1,11 @@ +package xyz.driver.core +package reporting +import scala.util.Random + +case class SpanContext private[core] (traceId: String, spanId: String) +object SpanContext { + def fresh() = SpanContext( + f"${Random.nextLong()}%02x${Random.nextLong()}%02x", + f"${Random.nextLong()}%02x" + ) +} diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala index 788729a..c3b6bff 100644 --- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala +++ b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -10,7 +10,6 @@ import com.typesafe.scalalogging.Logger import org.slf4j.MDC import xyz.driver.core.Name import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} -import xyz.driver.core.time.provider.TimeProvider import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -20,8 +19,7 @@ class HttpRestServiceTransport( applicationVersion: String, actorSystem: ActorSystem, executionContext: ExecutionContext, - log: Logger, - time: TimeProvider) + log: Logger) extends ServiceTransport { protected implicit val execution: ExecutionContext = executionContext @@ -30,7 +28,7 @@ class HttpRestServiceTransport( def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { - val requestTime = time.currentTime() + val requestTime = System.currentTimeMillis() val request = requestStub .withHeaders(context.contextHeaders.toSeq.map { @@ -51,11 +49,11 @@ class HttpRestServiceTransport( response.onComplete { case Success(r) => - val responseLatency = requestTime.durationTo(time.currentTime()) + val responseLatency = System.currentTimeMillis() - requestTime log.debug(s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") case Failure(t: Throwable) => - val responseLatency = requestTime.durationTo(time.currentTime()) + val responseLatency = System.currentTimeMillis() - requestTime log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) }(executionContext) diff --git a/src/main/scala/xyz/driver/core/rest/Swagger.scala b/src/main/scala/xyz/driver/core/rest/Swagger.scala index a3d942c..b598b33 100644 --- a/src/main/scala/xyz/driver/core/rest/Swagger.scala +++ b/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -13,24 +13,20 @@ import com.typesafe.scalalogging.Logger import io.swagger.models.Scheme import io.swagger.util.Json -import scala.reflect.runtime.universe -import scala.reflect.runtime.universe.Type import scala.util.control.NonFatal class Swagger( override val host: String, - override val schemes: List[Scheme], + accessSchemes: List[String], version: String, - val apiTypes: Seq[Type], + override val apiClasses: Set[Class[_]], val config: Config, val logger: Logger) extends SwaggerHttpService { - lazy val mirror = universe.runtimeMirror(getClass.getClassLoader) - - override val apiClasses = apiTypes.map { tpe => - mirror.runtimeClass(tpe.typeSymbol.asClass) - }.toSet + override val schemes = accessSchemes.map { s => + Scheme.forValue(s) + } // Note that the reason for overriding this is a subtle chain of causality: // @@ -52,15 +48,19 @@ class Swagger( try { val swagger: JSwagger = reader.read(apiClasses.asJava) - // Removing trailing spaces - swagger.setPaths( + val paths = if (swagger.getPaths == null) { + Map.empty + } else { swagger.getPaths.asScala - .map { - case (key, path) => - key.trim -> path - } - .toMap - .asJava) + } + + // Removing trailing spaces + val fixedPaths = paths.map { + case (key, path) => + key.trim -> path + } + + swagger.setPaths(fixedPaths.asJava) Json.pretty().writeValueAsString(swagger) } catch { diff --git a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala new file mode 100644 index 0000000..ff3424d --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala @@ -0,0 +1,19 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.auth.AuthProvider + +/** Authentication and authorization directives. */ +trait AuthDirectives extends AkkaDirectives { + + /** Authenticate a user based on service request headers and check if they have all given permissions. */ + def authenticateAndAuthorize[U <: User]( + authProvider: AuthProvider[U], + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + authProvider.authorize(permissions: _*) + } + +} diff --git a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala new file mode 100644 index 0000000..5a6bbfd --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala @@ -0,0 +1,72 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.model.HttpMethods._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpResponse, StatusCodes} +import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} + +/** Directives to handle Cross-Origin Resource Sharing (CORS). */ +trait CorsDirectives extends AkkaDirectives { + + /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request + * origin. + * + * In a microservice environment, it can be difficult to know in advance the exact origin + * from which requests may be issued [1]. For example, the request may come from a web page served from + * any of the services, on any namespace or from other documentation sites. In general, only a set + * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that + * implement CORS require exact specification of allowed origins, including full host name and scheme, + * in order to send credentials and headers with requests to other origins. + * + * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by + * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. + * + * Note that the simplicity of this implementation comes with two notable drawbacks: + * + * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. + * + * - Allowed methods and headers can not be customized on a per-request basis. All standard + * HTTP methods are allowed, and allowed headers are specified for all inner routes. + * + * This handler is not suited for cases where more fine-grained control of responses is required. + * + * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through + * a common gateway. + * + * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed + * origins. + * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. + * @param inner Route into which CORS headers will be injected. + */ + def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { + optionalHeaderValueByType[Origin](()) { maybeOrigin => + val allowedOrigins: HttpOriginRange = maybeOrigin match { + // Note that this is not a security issue: the client will never send credentials if the allowed + // origin is set to *. This case allows us to deal with clients that do not send an origin header. + case None => HttpOriginRange.* + case Some(requestOrigin) => + val allowedOrigin = requestOrigin.origins.find(origin => + allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) + allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) + } + + respondWithHeaders( + `Access-Control-Allow-Origin`.forRange(allowedOrigins), + `Access-Control-Allow-Credentials`(true), + `Access-Control-Allow-Headers`(allowedHeaders: _*), + `Access-Control-Expose-Headers`(allowedHeaders: _*) + ) { + options { // options is used during preflight check + complete( + HttpResponse(StatusCodes.OK) + .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) + } ~ inner // in case of non-preflight check we don't do anything special + } + } + } + +} + +object CorsDirectives extends CorsDirectives diff --git a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala new file mode 100644 index 0000000..0cd4ef1 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala @@ -0,0 +1,6 @@ +package xyz.driver.core +package rest +package directives + +trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers +object Directives extends Directives diff --git a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala new file mode 100644 index 0000000..07e32b0 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala @@ -0,0 +1,73 @@ +package xyz.driver.core +package rest +package directives + +import java.time.Instant +import java.util.UUID + +import akka.http.scaladsl.model.Uri.Path +import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} +import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import xyz.driver.core.time.Time + +import scala.util.control.NonFatal + +/** Akka-HTTP path matchers for suctom core types */ +trait PathMatchers { + + private def UuidInPath[T]: PathMatcher1[Id[T]] = + AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) + + def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) + case _ => Unmatched + } + } + + def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) + case _ => Unmatched + } + } + + private def timestampInPath: PathMatcher1[Long] = + PathMatcher("""[+-]?\d*""".r) flatMap { string => + try Some(string.toLong) + catch { case _: IllegalArgumentException => None } + } + + def InstantInPath: PathMatcher1[Instant] = + new PathMatcher1[Instant] { + def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { + case Path.Segment(head, tail) => + try Matched(tail, Tuple1(Instant.parse(head))) + catch { + case NonFatal(_) => Unmatched + } + case _ => Unmatched + } + } | timestampInPath.map(Instant.ofEpochMilli) + + def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) + + def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + refineV[NonEmpty](segment) match { + case Left(_) => Unmatched + case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) + } + case _ => Unmatched + } + } + + def RevisionInPath[T]: PathMatcher1[Revision[T]] = + PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => + Some(Revision[T](string)) + } + +} diff --git a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala new file mode 100644 index 0000000..6c45d15 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala @@ -0,0 +1,40 @@ +package xyz.driver.core +package rest +package directives + +import java.util.UUID + +import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} +import akka.http.scaladsl.unmarshalling.Unmarshaller +import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} + +/** Akka-HTTP unmarshallers for custom core types. */ +trait Unmarshallers { + + implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = + Unmarshaller.strict[String, Id[A]] { str => + Id[A](UUID.fromString(str).toString) + } + + implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = + Unmarshaller.firstOf( + Unmarshaller.strict((JsString(_: String)) andThen reader.read), + stringToValueUnmarshaller[T] + ) + + implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = + Unmarshaller.strict[String, Revision[T]](Revision[T]) + + val jsValueToStringMarshaller: Marshaller[JsValue, String] = + Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) + + def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = + jsValueToStringMarshaller.compose[T](jsonFormat.write) + + val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = + Unmarshaller.strict[String, JsValue](value => JsonParser(value)) + + def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = + stringToJsValueUnmarshaller.map[T](jsonFormat.read) + +} diff --git a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala new file mode 100644 index 0000000..9d470ad --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala @@ -0,0 +1,33 @@ +package xyz.driver.core +package rest +package headers + +import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} + +import scala.util.Try + +/** Encapsulates trace context in an HTTP header for propagation across services. + * + * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) + * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is + * ignored. + */ +final case class Traceparent(traceId: String, spanId: String) extends ModeledCustomHeader[Traceparent] { + override def renderInRequests = true + override def renderInResponses = true + override val companion: Traceparent.type = Traceparent + override def value: String = f"01-$traceId-$spanId-00" +} +object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { + override val name = "traceparent" + override def parse(value: String) = Try { + val Array(version, traceId, spanId, _) = value.split("-") + require( + version == "01", + s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") + new Traceparent( + traceId, + spanId + ) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index c778b62..104261a 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -22,6 +22,8 @@ import scala.util.Try trait Service +object Service + trait HttpClient { def makeRequest(request: HttpRequest): Future[HttpResponse] } @@ -117,7 +119,8 @@ object `package` { "X-Content-Type-Options", "Strict-Transport-Security", AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader + AuthProvider.SetPermissionsTokenHeader, + "Traceparent" ) def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala index 118024a..324c8d8 100644 --- a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala +++ b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -12,16 +12,16 @@ class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers val config = ConfigFactory.parseString(""" |application { | cors { - | allowedMethods: ["GET", "PUT", "POST", "PATCH", "DELETE", "OPTIONS"] - | allowedOrigins: [{scheme: https, hostSuffix: example.com}] + | allowedOrigins: ["example.com"] | } |} """.stripMargin).withFallback(ConfigFactory.load) + val origin = Origin(HttpOrigin("https", Host("example.com"))) val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) val allowedMethods: collection.immutable.Seq[HttpMethod] = { import akka.http.scaladsl.model.HttpMethods._ - collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS) + collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE) } import scala.reflect.runtime.universe.typeOf @@ -37,7 +37,7 @@ class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers it should "respond with the correct CORS headers for the swagger OPTIONS route" in { val route = new TestApp(get(complete(StatusCodes.OK))) - Options(s"/api-docs/swagger.json") ~> route.appRoute ~> check { + Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check { status shouldBe StatusCodes.OK headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods @@ -46,19 +46,17 @@ class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers it should "respond with the correct CORS headers for the test route" in { val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") ~> route.appRoute ~> check { + Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { status shouldBe StatusCodes.OK headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods } } it should "respond with the correct CORS headers for a concatenated route" in { val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) - Post(s"/api/v1/test") ~> route.appRoute ~> check { + Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { status shouldBe StatusCodes.OK headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods } } @@ -68,7 +66,6 @@ class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { status shouldBe StatusCodes.OK headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods } } @@ -77,8 +74,7 @@ class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers Get(s"/api/v1/test") .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) } } -- cgit v1.2.3