diff options
Diffstat (limited to 'jvm/src/main/scala/xyz/driver/core/rest')
18 files changed, 1234 insertions, 0 deletions
diff --git a/jvm/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/jvm/src/main/scala/xyz/driver/core/rest/DriverRoute.scala new file mode 100644 index 0000000..55f39ba --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -0,0 +1,111 @@ +package xyz.driver.core.rest + +import java.sql.SQLException + +import akka.http.scaladsl.model.{StatusCodes, _} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.rest +import xyz.driver.core.rest.errors._ + +import scala.compat.Platform.ConcurrentModificationException + +trait DriverRoute { + def log: Logger + + def route: Route + + def routeWithDefaults: Route = { + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route + } + } + + protected def defaultResponseHeaders: Directive0 = { + extractRequest flatMap { request => + // Needs to happen before any request processing, so all the log messages + // associated with processing of this request are having this `trackingId` + val trackingId = rest.extractTrackingId(request) + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + MDC.put("trackingId", trackingId) + + // This header will eliminate the risk of LB trying to reuse a connection + // that already timed out on the server side by completely rejecting keep-alive + val rejectKeepAlive = Connection("close") + + respondWithHeaders(tracingHeader, rejectKeepAlive) + } + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + case serviceException: ServiceException => + serviceExceptionHandler(serviceException) + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( + ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) + + case t: Exception => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) + } + + protected def serviceExceptionHandler(serviceException: ServiceException): Route = { + val statusCode = serviceException match { + case e: InvalidInputException => + log.info("Invalid client input error", e) + StatusCodes.BadRequest + case e: InvalidActionException => + log.info("Invalid client action error", e) + StatusCodes.Forbidden + case e: ResourceNotFoundException => + log.info("Resource not found error", e) + StatusCodes.NotFound + case e: ExternalServiceException => + log.error("Error while calling another service", e) + StatusCodes.InternalServerError + case e: ExternalServiceTimeoutException => + log.error("Service timeout error", e) + StatusCodes.GatewayTimeout + case e: DatabaseException => + log.error("Database error", e) + StatusCodes.InternalServerError + } + + { (ctx: RequestContext) => + import xyz.driver.core.json.serviceExceptionFormat + val entity = + HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) + errorResponse(statusCode, entity, serviceException)(ctx) + } + } + + protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = + errorResponse(statusCode, HttpEntity(message), exception) + + protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { + complete(HttpResponse(statusCode, entity = entity)) + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/jvm/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala new file mode 100644 index 0000000..788729a --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -0,0 +1,89 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.TcpIdleTimeoutException +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.Name +import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} +import xyz.driver.core.time.provider.TimeProvider + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +class HttpRestServiceTransport( + applicationName: Name[App], + applicationVersion: String, + actorSystem: ActorSystem, + executionContext: ExecutionContext, + log: Logger, + time: TimeProvider) + extends ServiceTransport { + + protected implicit val execution: ExecutionContext = executionContext + + protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + + val requestTime = time.currentTime() + + val request = requestStub + .withHeaders(context.contextHeaders.toSeq.map { + case (ContextHeaders.TrackingIdHeader, _) => + RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) + case (ContextHeaders.StacktraceHeader, _) => + RawHeader( + ContextHeaders.StacktraceHeader, + Option(MDC.get("stack")) + .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) + .getOrElse("")) + case (header, headerValue) => RawHeader(header, headerValue) + }: _*) + + log.debug(s"Sending request to ${request.method} ${request.uri}") + + val response = httpClient.makeRequest(request) + + response.onComplete { + case Success(r) => + val responseLatency = requestTime.durationTo(time.currentTime()) + log.debug(s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") + + case Failure(t: Throwable) => + val responseLatency = requestTime.durationTo(time.currentTime()) + log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) + }(executionContext) + + response.recoverWith { + case _: TcpIdleTimeoutException => + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Future.failed(ExternalServiceTimeoutException(serviceCalled)) + case t: Throwable => Future.failed(t) + } + } + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { + + sendRequestGetResponse(context)(requestStub) flatMap { response => + if (response.status == StatusCodes.NotFound) { + Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) + } else if (response.status.isFailure()) { + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Unmarshal(response.entity).to[String] flatMap { errorString => + import spray.json._ + import xyz.driver.core.json._ + val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption + Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) + } + } else { + Future.successful(Unmarshal(response.entity)) + } + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/jvm/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..f33bf9d --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => + val oldValue = oldObj.fields.getOrElse(key, JsNull) + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/jvm/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/jvm/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala new file mode 100644 index 0000000..2854257 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala @@ -0,0 +1,67 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.scaladsl.{Keep, Sink, Source} +import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} +import xyz.driver.core.Name + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.concurrent.duration._ +import scala.util.{Failure, Success} + +class PooledHttpClient( + baseUri: Uri, + applicationName: Name[App], + applicationVersion: String, + requestRateLimit: Int = 64, + requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) + extends HttpClient { + + private val host = baseUri.authority.host.toString() + private val port = baseUri.effectivePort + private val scheme = baseUri.scheme + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + private val pool = if (scheme.equalsIgnoreCase("https")) { + Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } else { + Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } + + private val queue = Source + .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) + .via(pool) + .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) + .toMat(Sink.foreach({ + case ((Success(resp), p)) => p.success(resp) + case ((Failure(e), p)) => p.failure(e) + }))(Keep.left) + .run + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + val responsePromise = Promise[HttpResponse]() + + queue.offer(request -> responsePromise).flatMap { + case QueueOfferResult.Enqueued => + responsePromise.future + case QueueOfferResult.Dropped => + Future.failed(new Exception(s"Request queue to the host $host is overflown")) + case QueueOfferResult.Failure(ex) => + Future.failed(ex) + case QueueOfferResult.QueueClosed => + Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/jvm/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala new file mode 100644 index 0000000..c0e9f99 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala @@ -0,0 +1,26 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} +import com.typesafe.config.Config +import xyz.driver.core.Name + +import scala.concurrent.ExecutionContext + +trait ProxyRoute extends DriverRoute { + implicit val executionContext: ExecutionContext + val config: Config + val httpClient: HttpClient + + protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => + val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") + val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") + + val originalUri = ctx.request.uri + val originalRequest = ctx.request + + val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) + val newRequest = originalRequest.withUri(newUri) + + httpClient.makeRequest(newRequest).map(RouteResult.Complete) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/RestService.scala b/jvm/src/main/scala/xyz/driver/core/rest/RestService.scala new file mode 100644 index 0000000..8d46d72 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/RestService.scala @@ -0,0 +1,72 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} +import akka.stream.Materializer + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +trait RestService extends Service { + + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json._ + + protected implicit val exec: ExecutionContext + protected implicit val materializer: Materializer + + implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { + def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = + if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] + } + + protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = + OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) + + protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) + + protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) + + protected def jsonEntity(json: JsValue): RequestEntity = + HttpEntity(ContentTypes.`application/json`, json.compactPrint) + + protected def mergePatchJsonEntity(json: JsValue): RequestEntity = + HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) + + protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) + + protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) + + protected def postJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) + + protected def putJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) + + protected def patchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) + + protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) + + protected def endpointUri(baseUri: Uri, path: String): Uri = + baseUri.withPath(Uri.Path(path)) + + protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = + baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/jvm/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala new file mode 100644 index 0000000..964a5a2 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala @@ -0,0 +1,29 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.ActorMaterializer +import xyz.driver.core.Name + +import scala.concurrent.Future + +class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) + extends HttpClient { + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private val client = Http()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + client.singleRequest(request, settings = connectionPoolSettings) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/Swagger.scala b/jvm/src/main/scala/xyz/driver/core/rest/Swagger.scala new file mode 100644 index 0000000..a3d942c --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -0,0 +1,127 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} +import akka.http.scaladsl.server.Route +import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile +import akka.stream.ActorAttributes +import akka.stream.scaladsl.{Framing, StreamConverters} +import akka.util.ByteString +import com.github.swagger.akka.SwaggerHttpService +import com.github.swagger.akka.model._ +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import io.swagger.models.Scheme +import io.swagger.util.Json + +import scala.reflect.runtime.universe +import scala.reflect.runtime.universe.Type +import scala.util.control.NonFatal + +class Swagger( + override val host: String, + override val schemes: List[Scheme], + version: String, + val apiTypes: Seq[Type], + val config: Config, + val logger: Logger) + extends SwaggerHttpService { + + lazy val mirror = universe.runtimeMirror(getClass.getClassLoader) + + override val apiClasses = apiTypes.map { tpe => + mirror.runtimeClass(tpe.typeSymbol.asClass) + }.toSet + + // Note that the reason for overriding this is a subtle chain of causality: + // + // 1. Some of our endpoints require a single trailing slash and will not + // function if it is omitted + // 2. Swagger omits trailing slashes in its generated api doc + // 3. To work around that, a space is added after the trailing slash in the + // swagger Path annotations + // 4. This space is removed manually in the code below + // + // TODO: Ideally we'd like to drop this custom override and fix the issue in + // 1, by dropping the slash requirement and accepting api endpoints with and + // without trailing slashes. This will require inspecting and potentially + // fixing all service endpoints. + override def generateSwaggerJson: String = { + import io.swagger.models.{Swagger => JSwagger} + + import scala.collection.JavaConverters._ + try { + val swagger: JSwagger = reader.read(apiClasses.asJava) + + // Removing trailing spaces + swagger.setPaths( + swagger.getPaths.asScala + .map { + case (key, path) => + key.trim -> path + } + .toMap + .asJava) + + Json.pretty().writeValueAsString(swagger) + } catch { + case NonFatal(t) => + logger.error("Issue with creating swagger.json", t) + throw t + } + } + + override val basePath: String = config.getString("swagger.basePath") + override val apiDocsPath: String = config.getString("swagger.docsPath") + + override val info = Info( + config.getString("swagger.apiInfo.description"), + version, + config.getString("swagger.apiInfo.title"), + config.getString("swagger.apiInfo.termsOfServiceUrl"), + contact = Some( + Contact( + config.getString("swagger.apiInfo.contact.name"), + config.getString("swagger.apiInfo.contact.url"), + config.getString("swagger.apiInfo.contact.email") + )), + license = Some( + License( + config.getString("swagger.apiInfo.license"), + config.getString("swagger.apiInfo.licenseUrl") + )), + vendorExtensions = Map.empty[String, AnyRef] + ) + + /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ + private def getTemplatedResource( + resourceName: String, + contentType: ContentType, + substitution: (String, String)): Route = get { + Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { + case Some(ResourceFile(url, length @ _, _)) => + extractSettings { settings => + val stream = StreamConverters + .fromInputStream(() => url.openStream()) + .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) + .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) + .map { line => + line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) + } + .map(line => ByteString(line + "\n")) + complete( + HttpEntity(contentType, stream) + ) + } + case None => reject + } + } + + def swaggerUI: Route = + pathEndOrSingleSlash { + getTemplatedResource( + "swagger-ui/index.html", + ContentTypes.`text/html(UTF-8)`, + "title" -> config.getString("swagger.apiInfo.title")) + } ~ getFromResourceDirectory("swagger-ui") + +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala new file mode 100644 index 0000000..5007774 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala @@ -0,0 +1,14 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val permissionsMap = permissions.map(_ -> true).toMap + Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala new file mode 100644 index 0000000..82edcc7 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala @@ -0,0 +1,73 @@ +package xyz.driver.core.rest.auth + +import akka.http.scaladsl.model.headers.HttpChallenges +import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import com.typesafe.scalalogging.Logger +import xyz.driver.core._ +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ServiceRequestContext, serviceContext} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +import scalaz.Scalaz.futureInstance +import scalaz.OptionT + +abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( + implicit execution: ExecutionContext) { + + import akka.http.scaladsl.server._ + import Directives._ + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param ctx set of request values which can be relevant to authenticate user + * @return authenticated user + */ + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] + + /** + * Verifies if a service context is authenticated and authorized to have `permissions` + */ + def authorize( + context: ServiceRequestContext, + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + onComplete { + (for { + authToken <- OptionT.optionT(Future.successful(context.authToken)) + user <- authenticatedUser(context) + authCtx = context.withAuthenticatedUser(authToken, user) + authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT + + cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) + allAuthorized = permissions.forall(authorizationResult.authorized.getOrElse(_, false)) + } yield (cachedPermissionsAuthCtx, allAuthorized)).run + } flatMap { + case Success(Some((authCtx, true))) => provide(authCtx) + case Success(Some((authCtx, false))) => + val challenge = + HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") + log.warn( + s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + case Success(None) => + val challenge = HttpChallenges.basic("Failed to authenticate user") + log.warn(s"Failed to authenticate user to verify ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + case Failure(t) => + log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t) + reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) + } + } + + /** + * Verifies if request is authenticated and authorized to have `permissions` + */ + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + serviceContext flatMap { ctx => + authorize(ctx, permissions: _*) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala new file mode 100644 index 0000000..1a5e9be --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala @@ -0,0 +1,11 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +trait Authorization[U <: User] { + def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala new file mode 100644 index 0000000..efe28c9 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala @@ -0,0 +1,22 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, PermissionsToken} + +import scalaz.Scalaz.mapMonoid +import scalaz.Semigroup +import scalaz.syntax.semigroup._ + +final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) +object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) + + implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { + private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) + private implicit val permissionsTokenSemigroup = + Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) + + override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { + AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala new file mode 100644 index 0000000..66de4ef --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala @@ -0,0 +1,55 @@ +package xyz.driver.core.rest.auth + +import java.nio.file.{Files, Path} +import java.security.{KeyFactory, PublicKey} +import java.security.spec.X509EncodedKeySpec + +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future +import scalaz.syntax.std.boolean._ + +class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + import spray.json._ + + def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = + tokenObject.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) + + authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } +} + +object CachedTokenAuthorization { + def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { + lazy val publicKey: PublicKey = { + val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim + val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) + val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) + KeyFactory.getInstance("RSA").generatePublic(spec) + } + new CachedTokenAuthorization[U](publicKey, issuer) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala new file mode 100644 index 0000000..131e7fc --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.Scalaz.{futureInstance, listInstance} +import scalaz.syntax.semigroup._ +import scalaz.syntax.traverse._ + +class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = + permissions.forall(permissionsMap.getOrElse(_, false)) + + authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { + (authResult, authorization) => + if (allAuthorized(authResult.authorized)) Future.successful(authResult) + else { + authorization.userHasPermissions(user, permissions).map(authResult |+| _) + } + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/jvm/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala new file mode 100644 index 0000000..db289de --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala @@ -0,0 +1,23 @@ +package xyz.driver.core.rest.errors + +sealed abstract class ServiceException(val message: String) extends Exception(message) + +final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) + +final case class InvalidActionException(override val message: String = "This action is not allowed") + extends ServiceException(message) + +final case class ResourceNotFoundException(override val message: String = "Resource not found") + extends ServiceException(message) + +final case class ExternalServiceException( + serviceName: String, + serviceMessage: String, + serviceException: Option[ServiceException]) + extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") + +final case class ExternalServiceTimeoutException(serviceName: String) + extends ServiceException(s"$serviceName took too long to respond") + +final case class DatabaseException(override val message: String = "Database access error") + extends ServiceException(message) diff --git a/jvm/src/main/scala/xyz/driver/core/rest/package.scala b/jvm/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..f85c39a --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,286 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.Flow +import akka.util.ByteString +import xyz.driver.tracing.TracingDirectives + +import scala.concurrent.Future +import scala.util.Try +import scalaz.{Functor, OptionT} +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ + +trait Service + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] +} + +sealed trait SortingOrder +object SortingOrder { + case object Asc extends SortingOrder + case object Desc extends SortingOrder +} + +final case class SortingField(name: String, sortingOrder: SortingOrder) +final case class Sorting(sortingFields: Seq[SortingField]) + +final case class Pagination(pageSize: Int, pageNumber: Int) { + require(pageSize > 0, "Page size must be greater than zero") + require(pageNumber > 0, "Page number must be greater than zero") + + def offset: Int = pageSize * (pageNumber - 1) +} + +final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) + +object ListResponse { + + def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = + ListResponse( + items = items, + meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) + + final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) + + object Meta { + def apply(itemsCount: Int, pagination: Pagination): Meta = + Meta(itemsCount, pagination.pageNumber, pagination.pageSize) + } + +} + +object `package` { + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } + + object ContextHeaders { + val AuthenticationTokenHeader: String = "Authorization" + val PermissionsTokenHeader: String = "Permissions" + val AuthenticationHeaderPrefix: String = "Bearer" + val ClientFingerprintHeader: String = "X-Client-Fingerprint" + val TrackingIdHeader: String = "X-Trace" + val StacktraceHeader: String = "X-Stacktrace" + val OriginatingIpHeader: String = "X-Forwarded-For" + val ResourceCount: String = "X-Resource-Count" + val PageCount: String = "X-Page-Count" + val TraceHeaderName: String = TracingDirectives.TraceHeaderName + val SpanHeaderName: String = TracingDirectives.SpanHeaderName + } + + object AuthProvider { + val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader: String = "set-authorization" + val SetPermissionsTokenHeader: String = "set-permissions" + } + + val AllowedHeaders: Seq[String] = + Seq( + "Origin", + "X-Requested-With", + "Content-Type", + "Content-Length", + "Accept", + "X-Trace", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Server", + "Date", + ContextHeaders.ClientFingerprintHeader, + ContextHeaders.TrackingIdHeader, + ContextHeaders.TraceHeaderName, + ContextHeaders.SpanHeaderName, + ContextHeaders.StacktraceHeader, + ContextHeaders.AuthenticationTokenHeader, + ContextHeaders.OriginatingIpHeader, + ContextHeaders.ResourceCount, + ContextHeaders.PageCount, + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + AuthProvider.SetAuthenticationTokenHeader, + AuthProvider.SetPermissionsTokenHeader + ) + + def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + + def serviceContext: Directive1[ServiceRequestContext] = { + extractClientIP flatMap { remoteAddress => + extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + } + } + + def respondWithCorsAllowedHeaders: Directive0 = { + respondWithHeaders( + List[HttpHeader]( + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + )) + } + + def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { + respondWithHeader { + `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) + } + } + + def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { + respondWithHeaders( + List[HttpHeader]( + Allow(methods.to[collection.immutable.Seq]), + `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) + )) + } + + def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = + new ServiceRequestContext( + extractTrackingId(request), + extractOriginatingIP(request, remoteAddress), + extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name === ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractFingerprintHash(request: HttpRequest): Option[String] = { + request.headers + .find(_.name === ContextHeaders.ClientFingerprintHeader) + .map(_.value()) + } + + def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { + request.headers + .find(_.name === ContextHeaders.OriginatingIpHeader) + .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) + .orElse(remoteAddress.toOption) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || + h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } + + val paginated: Directive1[Pagination] = + parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) + + private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = + (pageSizeOpt, pageNumberOpt) match { + case (Some(size), Some(number)) => Option(Pagination(size, number)) + case (None, None) => Option.empty[Pagination] + case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") + } + + val optionalPagination: Directive1[Option[Pagination]] = + parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) + + def paginationQuery(pagination: Pagination) = + Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) + + private def extractSorting(sortingString: Option[String]): Sorting = { + val sortingFields = sortingString.fold(Seq.empty[SortingField])( + _.split(",") + .filter(_.length > 0) + .map { sortingParam => + if (sortingParam.startsWith("-")) { + SortingField(sortingParam.substring(1), SortingOrder.Desc) + } else { + val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam + SortingField(fieldName, SortingOrder.Asc) + } + } + .toSeq) + + Sorting(sortingFields) + } + + val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) + + def sortingQuery(sorting: Sorting): Seq[(String, String)] = { + val sortingString = sorting.sortingFields + .map { sortingField => + sortingField.sortingOrder match { + case SortingOrder.Asc => sortingField.name + case SortingOrder.Desc => s"-${sortingField.name}" + } + } + .mkString(",") + Seq("sort" -> sortingString) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/jvm/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala new file mode 100644 index 0000000..55f1a2e --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala @@ -0,0 +1,24 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait ServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T +} + +trait SavingUsedServiceDiscovery { + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } +} + +class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/jvm/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala new file mode 100644 index 0000000..775106e --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -0,0 +1,74 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} +import xyz.driver.core.generators + +import scalaz.Scalaz.{mapEqual, stringInstance} +import scalaz.syntax.equal._ + +class ServiceRequestContext( + val trackingId: String = generators.nextUuid().toString, + val originatingIp: Option[InetAddress] = None, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + + def withAuthToken(authToken: AuthToken): ServiceRequestContext = + new ServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, originatingIp, contextHeaders) + .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => + trackingId === ctx.trackingId && + originatingIp == originatingIp && + contextHeaders === ctx.contextHeaders + case _ => false + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" +} + +class AuthorizedServiceRequestContext[U <: User]( + override val trackingId: String = generators.nextUuid().toString, + override val originatingIp: Option[InetAddress] = None, + override val contextHeaders: Map[String, String] = Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} |