diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core')
6 files changed, 206 insertions, 149 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 901d6e2..751bef7 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -1,35 +1,31 @@ package xyz.driver.core.app -import java.sql.SQLException - import akka.actor.ActorSystem import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes.{BadRequest, Conflict, InternalServerError, MethodNotAllowed} +import akka.http.scaladsl.model.StatusCodes._ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers._ import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server.RouteResult.route2HandlerFlow +import akka.http.scaladsl.server.RouteResult._ import akka.http.scaladsl.server._ import akka.http.scaladsl.{Http, HttpExt} import akka.stream.ActorMaterializer -import com.github.swagger.akka.SwaggerHttpService.{logger, toJavaTypeSet} +import com.github.swagger.akka.SwaggerHttpService._ import com.typesafe.config.Config import com.typesafe.scalalogging.Logger import io.swagger.models.Scheme import io.swagger.util.Json import org.slf4j.{LoggerFactory, MDC} import xyz.driver.core -import xyz.driver.core.rest -import xyz.driver.core.rest.{ContextHeaders, Swagger} +import xyz.driver.core.rest._ import xyz.driver.core.stats.SystemStats import xyz.driver.core.time.Time import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} -import xyz.driver.tracing.TracingDirectives.trace -import xyz.driver.tracing.{NoTracer, Tracer} +import xyz.driver.tracing.TracingDirectives._ +import xyz.driver.tracing._ -import scala.compat.Platform.ConcurrentModificationException import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{Await, ExecutionContext} import scala.reflect.runtime.universe._ import scala.util.Try import scala.util.control.NonFatal @@ -64,8 +60,7 @@ class DriverApp(appName: String, def stop(): Unit = { http.shutdownAllConnectionPools().onComplete { _ => Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer - val _ = actorSystem.terminate() - val terminated = Await.result(actorSystem.whenTerminated, 30.seconds) + val terminated = Await.result(actorSystem.terminate(), 30.seconds) val addressTerminated = if (terminated.addressTerminated) "is" else "is not" Console.print(s"${this.getClass.getName} App $addressTerminated stopped ") } @@ -74,69 +69,41 @@ class DriverApp(appName: String, private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - protected def bindHttp(modules: Seq[Module]): Unit = { + protected def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = swaggerOverride(serviceTypes) - val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI + val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI val versionRt = versionRoute(version, gitHash, time.currentTime()) + val combinedRoute = modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoute)(_ ~ _) - val _ = Future { - http.bindAndHandle( - route2HandlerFlow(extractHost { origin => - trace(tracer) { - extractClientIP { ip => - optionalHeaderValueByType[Origin](()) { - originHeader => - { - ctx => - val trackingId = rest.extractTrackingId(ctx.request) - MDC.put("trackingId", trackingId) - - val updatedStacktrace = - (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") - MDC.put("stack", updatedStacktrace) - - storeRequestContextToMdc(ctx.request, origin, ip) - - def requestLogging: Future[Unit] = Future { - log.info( - s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") - } - - val contextWithTrackingId = - ctx.withRequest( - ctx.request - .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) - .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) - - handleExceptions(ExceptionHandler(exceptionHandler))({ - c => - requestLogging.flatMap { _ => - val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - - val responseHeaders = List[HttpHeader]( - trackingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _) - }(c) - } - })(contextWithTrackingId) - } - } - } - } - }), - interface, - port - )(materializer) + (extractHost & extractClientIP & trace(tracer)) { + case (origin, ip) => + ctx => + val trackingId = extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) + + val updatedStacktrace = + (extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") + MDC.put("stack", updatedStacktrace) + + storeRequestContextToMdc(ctx.request, origin, ip) + + log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""") + + val contextWithTrackingId = + ctx.withRequest( + ctx.request + .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) + .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) + + combinedRoute(contextWithTrackingId) } } + protected def bindHttp(modules: Seq[Module]): Unit = { + val _ = http.bindAndHandle(route2HandlerFlow(appRoute), interface, port)(materializer) + } + private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = { MDC.put("origin", origin) @@ -181,58 +148,6 @@ class DriverApp(appName: String, } } - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler: PartialFunction[Throwable, Route] = { - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(ctx, InternalServerError, "Data access error", se)(ctx) - - case t: Throwable => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx) - } - - protected def errorResponse[T <: Throwable](ctx: RequestContext, - statusCode: StatusCode, - message: String, - exception: T): Route = { - - val trackingId = rest.extractTrackingId(ctx.request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request)) - - MDC.put("trackingId", trackingId) - - optionalHeaderValueByType[Origin](()) { originHeader => - val responseHeaders = List[HttpHeader]( - tracingHeader, - allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) - ) - - respondWithHeaders(responseHeaders) { - complete(HttpResponse(statusCode, entity = message)) - } - } - } - protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = { import spray.json._ import DefaultJsonProtocol._ @@ -334,11 +249,6 @@ class DriverApp(appName: String, } object DriverApp { - - private def allowOrigin(originHeader: Option[Origin]) = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - implicit def rejectionHandler: RejectionHandler = RejectionHandler .newBuilder() @@ -352,8 +262,8 @@ object DriverApp { Allow(methods), `Access-Control-Allow-Methods`(methods), allowOrigin(originHeader), - `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*), - `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*) + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) )) { complete(s"Supported methods: $names.") } @@ -362,5 +272,4 @@ object DriverApp { complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!") } .result() - } diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala index c6f979f..bbb29f4 100644 --- a/src/main/scala/xyz/driver/core/app/module.scala +++ b/src/main/scala/xyz/driver/core/app/module.scala @@ -3,7 +3,8 @@ package xyz.driver.core.app import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.server.Directives.complete import akka.http.scaladsl.server.{Route, RouteConcatenation} -import xyz.driver.core.rest.{NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} +import com.typesafe.scalalogging.Logger +import xyz.driver.core.rest.{DriverRoute, NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} import scala.reflect.runtime.universe._ @@ -21,17 +22,22 @@ trait Module { class EmptyModule extends Module { override val name: String = "Nothing" - override def route: Route = complete(StatusCodes.OK) - + override def route: Route = complete(StatusCodes.OK) override def routeTypes: Seq[Type] = Seq.empty[Type] } -class SimpleModule(override val name: String, override val route: Route, routeType: Type) extends Module { - def routeTypes: Seq[Type] = Seq(routeType) +class SimpleModule(override val name: String, theRoute: Route, routeType: Type) extends Module { + private val driverRoute: DriverRoute = new DriverRoute { + override def route: Route = theRoute + override val log: Logger = xyz.driver.core.logging.NoLogger + } + + override def route: Route = driverRoute.routeWithDefaults + override def routeTypes: Seq[Type] = Seq(routeType) } /** - * Module implementation which may be used to composed a few + * Module implementation which may be used to compose multiple modules * * @param name more general name of the composite module, * must be provided as there is no good way to automatically @@ -39,12 +45,8 @@ class SimpleModule(override val name: String, override val route: Route, routeTy * @param modules modules to compose into a single one */ class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { - - override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) - + override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) - - override def activate(): Unit = modules.foreach(_.activate()) - - override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) + override def activate(): Unit = modules.foreach(_.activate()) + override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) } diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala new file mode 100644 index 0000000..eb9a31a --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -0,0 +1,109 @@ +package xyz.driver.core.rest + +import java.sql.SQLException + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route} +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.rest +import xyz.driver.core.rest.errors._ + +import scala.compat.Platform.ConcurrentModificationException + +trait DriverRoute { + def log: Logger + + def route: Route + + def routeWithDefaults: Route = { + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route) + } + + protected def defaultResponseHeaders: Directive0 = { + (extractRequest & optionalHeaderValueByType[Origin](())) tflatMap { + case (request, originHeader) => + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request)) + val responseHeaders = List[HttpHeader]( + tracingHeader, + allowOrigin(originHeader), + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + ) + + respondWithHeaders(responseHeaders) + } + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + case serviceException: ServiceException => + serviceExceptionHandler(serviceException) + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(ctx, StatusCodes.BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(ctx, + StatusCodes.Conflict, + "Resource was changed concurrently, try requesting a newer version", + cm)(ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(ctx, StatusCodes.InternalServerError, "Data access error", se)(ctx) + + case t: Exception => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(ctx, StatusCodes.InternalServerError, t.getMessage, t)(ctx) + } + + protected def serviceExceptionHandler(serviceException: ServiceException): Route = { + val statusCode = serviceException match { + case e: InvalidInputException => + log.info("Invalid client input error", e) + StatusCodes.BadRequest + case e: InvalidActionException => + log.info("Invalid client action error", e) + StatusCodes.Forbidden + case e: ResourceNotFoundException => + log.info("Resource not found error", e) + StatusCodes.NotFound + case e: ExternalServiceException => + log.error("Error while calling another service", e) + StatusCodes.InternalServerError + case e: ExternalServiceTimeoutException => + log.error("Service timeout error", e) + StatusCodes.GatewayTimeout + case e: DatabaseException => + log.error("Database error", e) + StatusCodes.InternalServerError + } + + { (ctx: RequestContext) => + errorResponse(ctx, statusCode, serviceException.message, serviceException)(ctx) + } + } + + protected def errorResponse[T <: Exception](ctx: RequestContext, + statusCode: StatusCode, + message: String, + exception: T): Route = { + val trackingId = rest.extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) + complete(HttpResponse(statusCode, entity = message)) + } +} diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala index 1e95811..376b154 100644 --- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala +++ b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -4,9 +4,12 @@ import akka.actor.ActorSystem import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.RawHeader import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.TcpIdleTimeoutException import com.typesafe.scalalogging.Logger import org.slf4j.MDC import xyz.driver.core.Name +import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} import xyz.driver.core.time.provider.TimeProvider import scala.concurrent.{ExecutionContext, Future} @@ -55,18 +58,27 @@ class HttpRestServiceTransport(applicationName: Name[App], log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) }(executionContext) - response + response.recoverWith { + case _: TcpIdleTimeoutException => + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Future.failed(ExternalServiceTimeoutException(serviceCalled)) + case t: Throwable => Future.failed(t) + } } - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = { + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { - sendRequestGetResponse(context)(requestStub) map { response => + sendRequestGetResponse(context)(requestStub) flatMap { response => if (response.status == StatusCodes.NotFound) { - Unmarshal(HttpEntity.Empty: ResponseEntity) + Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) } else if (response.status.isFailure()) { - throw new Exception(s"Http status is failure ${response.status} for ${requestStub.method} ${requestStub.uri}") + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Unmarshal(response.entity).to[String] flatMap { error => + Future.failed(ExternalServiceException(serviceCalled, error)) + } } else { - Unmarshal(response.entity) + Future.successful(Unmarshal(response.entity)) } } } diff --git a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala new file mode 100644 index 0000000..e91a3c2 --- /dev/null +++ b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala @@ -0,0 +1,23 @@ +package xyz.driver.core.rest.errors + +sealed abstract class ServiceException extends Exception { + def message: String +} + +final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException + +final case class InvalidActionException(override val message: String = "This action is not allowed") + extends ServiceException + +final case class ResourceNotFoundException(override val message: String = "Resource not found") + extends ServiceException + +final case class ExternalServiceException(serviceName: String, serviceMessage: String) extends ServiceException { + override def message = s"Error while calling '$serviceName': $serviceMessage" +} + +final case class ExternalServiceTimeoutException(serviceName: String) extends ServiceException { + override def message = s"$serviceName took too long to respond" +} + +final case class DatabaseException(override val message: String = "Database access error") extends ServiceException diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala index 942ca3a..531cd8a 100644 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ b/src/main/scala/xyz/driver/core/rest/package.scala @@ -6,6 +6,7 @@ import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, Stat import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server._ import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer import akka.stream.scaladsl.Flow import akka.util.ByteString import xyz.driver.tracing.TracingDirectives @@ -25,7 +26,8 @@ trait ServiceTransport { def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] } final case class Pagination(pageSize: Int, pageNumber: Int) |