package xyz.driver.core.rest
import java.sql.SQLException
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route}
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import org.slf4j.MDC
import xyz.driver.core.rest
import xyz.driver.core.rest.errors._
import scala.compat.Platform.ConcurrentModificationException
trait DriverRoute {
def log: Logger
def config: Config
def route: Route
def routeWithDefaults: Route = {
(defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) {
route ~ defaultOptionsRoute
}
}
protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
import scala.collection.JavaConverters._
config
.getConfigList("application.cors.allowedOrigins")
.asScala
.map { c =>
HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
}(scala.collection.breakOut)
}
protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
import scala.collection.JavaConverters._
config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
}
protected lazy val defaultCorsAllowedOrigin: Origin =
Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
val allowedOrigin =
origin
.filter { requestOrigin =>
allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
requestOrigin.origins.exists(o =>
o.scheme == allowedOriginSuffix.scheme &&
o.host.host.address.endsWith(allowedOriginSuffix.host.host.address()))
}
}
.getOrElse(defaultCorsAllowedOrigin)
`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
}
protected def respondWithAllCorsHeaders: Directive0 = {
respondWithCorsAllowedHeaders tflatMap { _ =>
respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
optionalHeaderValueByType[Origin](()) flatMap { origin =>
respondWithHeader(corsAllowedOriginHeader(origin))
}
}
}
}
protected def defaultOptionsRoute: Route = options {
respondWithAllCorsHeaders {
complete("OK")
}
}
protected def defaultResponseHeaders: Directive0 = {
extractRequest flatMap { request =>
val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
respondWithHeader(tracingHeader) & respondWithAllCorsHeaders
}
}
/**
* Override me for custom exception handling
*
* @return Exception handling route for exception type
*/
protected def exceptionHandler: PartialFunction[Throwable, Route] = {
case serviceException: ServiceException =>
serviceExceptionHandler(serviceException)
case is: IllegalStateException =>
ctx =>
log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is)
errorResponse(ctx, StatusCodes.BadRequest, message = is.getMessage, is)(ctx)
case cm: ConcurrentModificationException =>
ctx =>
log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm)
errorResponse(
ctx,
StatusCodes.Conflict,
"Resource was changed concurrently, try requesting a newer version",
cm)(ctx)
case se: SQLException =>
ctx =>
log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se)
errorResponse(ctx, StatusCodes.InternalServerError, "Data access error", se)(ctx)
case t: Exception =>
ctx =>
log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t)
errorResponse(ctx, StatusCodes.InternalServerError, t.getMessage, t)(ctx)
}
protected def serviceExceptionHandler(serviceException: ServiceException): Route = {
val statusCode = serviceException match {
case e: InvalidInputException =>
log.info("Invalid client input error", e)
StatusCodes.BadRequest
case e: InvalidActionException =>
log.info("Invalid client action error", e)
StatusCodes.Forbidden
case e: ResourceNotFoundException =>
log.info("Resource not found error", e)
StatusCodes.NotFound
case e: ExternalServiceException =>
log.error("Error while calling another service", e)
StatusCodes.InternalServerError
case e: ExternalServiceTimeoutException =>
log.error("Service timeout error", e)
StatusCodes.GatewayTimeout
case e: DatabaseException =>
log.error("Database error", e)
StatusCodes.InternalServerError
}
{ (ctx: RequestContext) =>
errorResponse(ctx, statusCode, serviceException.message, serviceException)(ctx)
}
}
protected def errorResponse[T <: Exception](
ctx: RequestContext,
statusCode: StatusCode,
message: String,
exception: T): Route = {
val trackingId = rest.extractTrackingId(ctx.request)
MDC.put("trackingId", trackingId)
complete(HttpResponse(statusCode, entity = message))
}
}