diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest')
9 files changed, 268 insertions, 24 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala index 788729a..c3b6bff 100644 --- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala +++ b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -10,7 +10,6 @@ import com.typesafe.scalalogging.Logger import org.slf4j.MDC import xyz.driver.core.Name import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} -import xyz.driver.core.time.provider.TimeProvider import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -20,8 +19,7 @@ class HttpRestServiceTransport( applicationVersion: String, actorSystem: ActorSystem, executionContext: ExecutionContext, - log: Logger, - time: TimeProvider) + log: Logger) extends ServiceTransport { protected implicit val execution: ExecutionContext = executionContext @@ -30,7 +28,7 @@ class HttpRestServiceTransport( def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { - val requestTime = time.currentTime() + val requestTime = System.currentTimeMillis() val request = requestStub .withHeaders(context.contextHeaders.toSeq.map { @@ -51,11 +49,11 @@ class HttpRestServiceTransport( response.onComplete { case Success(r) => - val responseLatency = requestTime.durationTo(time.currentTime()) + val responseLatency = System.currentTimeMillis() - requestTime log.debug(s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") case Failure(t: Throwable) => - val responseLatency = requestTime.durationTo(time.currentTime()) + val responseLatency = System.currentTimeMillis() - requestTime log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) }(executionContext) diff --git a/src/main/scala/xyz/driver/core/rest/Swagger.scala b/src/main/scala/xyz/driver/core/rest/Swagger.scala index a3d942c..b598b33 100644 --- a/src/main/scala/xyz/driver/core/rest/Swagger.scala +++ b/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -13,24 +13,20 @@ import com.typesafe.scalalogging.Logger import io.swagger.models.Scheme import io.swagger.util.Json -import scala.reflect.runtime.universe -import scala.reflect.runtime.universe.Type import scala.util.control.NonFatal class Swagger( override val host: String, - override val schemes: List[Scheme], + accessSchemes: List[String], version: String, - val apiTypes: Seq[Type], + override val apiClasses: Set[Class[_]], val config: Config, val logger: Logger) extends SwaggerHttpService { - lazy val mirror = universe.runtimeMirror(getClass.getClassLoader) - - override val apiClasses = apiTypes.map { tpe => - mirror.runtimeClass(tpe.typeSymbol.asClass) - }.toSet + override val schemes = accessSchemes.map { s => + Scheme.forValue(s) + } // Note that the reason for overriding this is a subtle chain of causality: // @@ -52,15 +48,19 @@ class Swagger( try { val swagger: JSwagger = reader.read(apiClasses.asJava) - // Removing trailing spaces - swagger.setPaths( + val paths = if (swagger.getPaths == null) { + Map.empty + } else { swagger.getPaths.asScala - .map { - case (key, path) => - key.trim -> path - } - .toMap - .asJava) + } + + // Removing trailing spaces + val fixedPaths = paths.map { + case (key, path) => + key.trim -> path + } + + swagger.setPaths(fixedPaths.asJava) Json.pretty().writeValueAsString(swagger) } catch { diff --git a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala new file mode 100644 index 0000000..ff3424d --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala @@ -0,0 +1,19 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.auth.AuthProvider + +/** Authentication and authorization directives. */ +trait AuthDirectives extends AkkaDirectives { + + /** Authenticate a user based on service request headers and check if they have all given permissions. */ + def authenticateAndAuthorize[U <: User]( + authProvider: AuthProvider[U], + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + authProvider.authorize(permissions: _*) + } + +} diff --git a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala new file mode 100644 index 0000000..5a6bbfd --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala @@ -0,0 +1,72 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.model.HttpMethods._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpResponse, StatusCodes} +import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} + +/** Directives to handle Cross-Origin Resource Sharing (CORS). */ +trait CorsDirectives extends AkkaDirectives { + + /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request + * origin. + * + * In a microservice environment, it can be difficult to know in advance the exact origin + * from which requests may be issued [1]. For example, the request may come from a web page served from + * any of the services, on any namespace or from other documentation sites. In general, only a set + * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that + * implement CORS require exact specification of allowed origins, including full host name and scheme, + * in order to send credentials and headers with requests to other origins. + * + * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by + * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. + * + * Note that the simplicity of this implementation comes with two notable drawbacks: + * + * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. + * + * - Allowed methods and headers can not be customized on a per-request basis. All standard + * HTTP methods are allowed, and allowed headers are specified for all inner routes. + * + * This handler is not suited for cases where more fine-grained control of responses is required. + * + * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through + * a common gateway. + * + * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed + * origins. + * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. + * @param inner Route into which CORS headers will be injected. + */ + def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { + optionalHeaderValueByType[Origin](()) { maybeOrigin => + val allowedOrigins: HttpOriginRange = maybeOrigin match { + // Note that this is not a security issue: the client will never send credentials if the allowed + // origin is set to *. This case allows us to deal with clients that do not send an origin header. + case None => HttpOriginRange.* + case Some(requestOrigin) => + val allowedOrigin = requestOrigin.origins.find(origin => + allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) + allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) + } + + respondWithHeaders( + `Access-Control-Allow-Origin`.forRange(allowedOrigins), + `Access-Control-Allow-Credentials`(true), + `Access-Control-Allow-Headers`(allowedHeaders: _*), + `Access-Control-Expose-Headers`(allowedHeaders: _*) + ) { + options { // options is used during preflight check + complete( + HttpResponse(StatusCodes.OK) + .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) + } ~ inner // in case of non-preflight check we don't do anything special + } + } + } + +} + +object CorsDirectives extends CorsDirectives diff --git a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala new file mode 100644 index 0000000..0cd4ef1 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala @@ -0,0 +1,6 @@ +package xyz.driver.core +package rest +package directives + +trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers +object Directives extends Directives diff --git a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala new file mode 100644 index 0000000..07e32b0 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala @@ -0,0 +1,73 @@ +package xyz.driver.core +package rest +package directives + +import java.time.Instant +import java.util.UUID + +import akka.http.scaladsl.model.Uri.Path +import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} +import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import xyz.driver.core.time.Time + +import scala.util.control.NonFatal + +/** Akka-HTTP path matchers for suctom core types */ +trait PathMatchers { + + private def UuidInPath[T]: PathMatcher1[Id[T]] = + AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) + + def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) + case _ => Unmatched + } + } + + def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) + case _ => Unmatched + } + } + + private def timestampInPath: PathMatcher1[Long] = + PathMatcher("""[+-]?\d*""".r) flatMap { string => + try Some(string.toLong) + catch { case _: IllegalArgumentException => None } + } + + def InstantInPath: PathMatcher1[Instant] = + new PathMatcher1[Instant] { + def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { + case Path.Segment(head, tail) => + try Matched(tail, Tuple1(Instant.parse(head))) + catch { + case NonFatal(_) => Unmatched + } + case _ => Unmatched + } + } | timestampInPath.map(Instant.ofEpochMilli) + + def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) + + def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + refineV[NonEmpty](segment) match { + case Left(_) => Unmatched + case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) + } + case _ => Unmatched + } + } + + def RevisionInPath[T]: PathMatcher1[Revision[T]] = + PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => + Some(Revision[T](string)) + } + +} diff --git a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala new file mode 100644 index 0000000..6c45d15 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala @@ -0,0 +1,40 @@ +package xyz.driver.core +package rest +package directives + +import java.util.UUID + +import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} +import akka.http.scaladsl.unmarshalling.Unmarshaller +import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} + +/** Akka-HTTP unmarshallers for custom core types. */ +trait Unmarshallers { + + implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = + Unmarshaller.strict[String, Id[A]] { str => + Id[A](UUID.fromString(str).toString) + } + + implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = + Unmarshaller.firstOf( + Unmarshaller.strict((JsString(_: String)) andThen reader.read), + stringToValueUnmarshaller[T] + ) + + implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = + Unmarshaller.strict[String, Revision[T]](Revision[T]) + + val jsValueToStringMarshaller: Marshaller[JsValue, String] = + Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) + + def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = + jsValueToStringMarshaller.compose[T](jsonFormat.write) + + val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = + Unmarshaller.strict[String, JsValue](value => JsonParser(value)) + + def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = + stringToJsValueUnmarshaller.map[T](jsonFormat.read) + +} diff --git a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala new file mode 100644 index 0000000..9d470ad --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala @@ -0,0 +1,33 @@ +package xyz.driver.core +package rest +package headers + +import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} + +import scala.util.Try + +/** Encapsulates trace context in an HTTP header for propagation across services. + * + * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) + * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is + * ignored. + */ +final case class Traceparent(traceId: String, spanId: String) extends ModeledCustomHeader[Traceparent] { + override def renderInRequests = true + override def renderInResponses = true + override val companion: Traceparent.type = Traceparent + override def value: String = f"01-$traceId-$spanId-00" +} +object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { + override val name = "traceparent" + override def parse(value: String) = Try { + val Array(version, traceId, spanId, _) = value.split("-") + require( + version == "01", + s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") + new Traceparent( + traceId, + spanId + ) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index c778b62..104261a 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -22,6 +22,8 @@ import scala.util.Try trait Service +object Service + trait HttpClient { def makeRequest(request: HttpRequest): Future[HttpResponse] } @@ -117,7 +119,8 @@ object `package` { "X-Content-Type-Options", "Strict-Transport-Security", AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader + AuthProvider.SetPermissionsTokenHeader, + "Traceparent" ) def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = |