aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
blob: 58a41437bc1e6827f47c83a34104295d77e6ff4c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
package xyz.driver.core.rest

import java.sql.SQLException

import akka.http.scaladsl.model._
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import org.slf4j.MDC
import xyz.driver.core.rest
import xyz.driver.core.rest.errors._

import scala.compat.Platform.ConcurrentModificationException

/**
  * Base trait for a service's HTTP routes.
  *
  * Implementations provide `route`; [[routeWithDefaults]] wraps it with the tracking-id and CORS
  * response headers, the CORS-aware rejection handler, the exception handler defined here and a
  * fallback route answering `OPTIONS` requests.
  */
trait DriverRoute {
  def log: Logger

  /**
    * Must contain `application.cors.allowedOrigins` (a list of objects with string fields
    * `scheme` and `hostSuffix`) and `application.cors.allowedMethods` (a list of method names).
    */
  def config: Config

  /** The service-specific route, without any of the defaults added by [[routeWithDefaults]]. */
  def route: Route

  /**
    * `route` followed by [[defaultOptionsRoute]], wrapped with [[defaultResponseHeaders]],
    * [[rejectionHandler]] and [[exceptionHandler]].
    */
  def routeWithDefaults: Route = {
    (defaultResponseHeaders & handleRejections(rejectionHandler) & handleExceptions(ExceptionHandler(exceptionHandler))) {
      route ~ defaultOptionsRoute
    }
  }

  /** Allowed CORS origins (scheme + host suffix) read from `application.cors.allowedOrigins`. */
  protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = {
    import scala.collection.JavaConverters._
    config
      .getConfigList("application.cors.allowedOrigins")
      .asScala
      .map { c =>
        HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix")))
      }(scala.collection.breakOut)
  }

  /**
    * Methods listed in `application.cors.allowedMethods`; names that `HttpMethods.getForKey`
    * does not recognise are silently dropped.
    */
  protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = {
    import scala.collection.JavaConverters._
    config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey)
  }

  /** An `Origin` listing every configured allowed origin; used when the request's origin is not allowed. */
  protected lazy val defaultCorsAllowedOrigin: Origin = {
    Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq])
  }

  /**
    * Builds the `Access-Control-Allow-Origin` header for a request.
    *
    * The request's origins are echoed back when at least one of them has the same scheme as a
    * configured allowed origin and a host covered by its `hostSuffix` (see [[hostMatchesSuffix]]).
    * Otherwise the header lists every configured origin.
    * NOTE(review): browsers accept only a single origin in this header, so the multi-origin
    * fallback presumably just causes the cross-origin request to be refused — confirm intent.
    *
    * @param origin the request's `Origin` header, if present
    * @return the `Access-Control-Allow-Origin` response header
    */
  protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = {
    def isAllowed(requestOrigin: HttpOrigin): Boolean =
      allowedCorsDomainSuffixes.exists { allowedOriginSuffix =>
        requestOrigin.scheme == allowedOriginSuffix.scheme &&
        hostMatchesSuffix(requestOrigin.host.host.address, allowedOriginSuffix.host.host.address)
      }

    val allowedOrigin =
      origin
        .filter(_.origins.exists(isAllowed))
        .getOrElse(defaultCorsAllowedOrigin)

    `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*))
  }

  /**
    * Whether `host` is covered by the configured host `suffix`.
    *
    * Matching happens on DNS label boundaries: suffix `driver.xyz` accepts `driver.xyz` and
    * `app.driver.xyz` but not `evildriver.xyz`. A suffix that is empty or already starts with `.`
    * is matched with a plain `endsWith`.
    */
  private def hostMatchesSuffix(host: String, suffix: String): Boolean =
    if (suffix.isEmpty || suffix.startsWith(".")) host.endsWith(suffix)
    else host == suffix || host.endsWith("." + suffix)

  /**
    * Adds the CORS allowed-headers and allowed-methods headers (via `respondWithCorsAllowedHeaders`
    * and `respondWithCorsAllowedMethodHeaders`, defined outside this file — presumably in the
    * `rest` package object) plus `Access-Control-Allow-Origin` derived from the request's `Origin`.
    */
  protected def respondWithAllCorsHeaders: Directive0 = {
    respondWithCorsAllowedHeaders tflatMap { _ =>
      respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ =>
        optionalHeaderValueByType[Origin](()) flatMap { origin =>
          respondWithHeader(corsAllowedOriginHeader(origin))
        }
      }
    }
  }

  /** Answers any `OPTIONS` request with `OK` and all CORS headers, so CORS preflight requests succeed. */
  protected def defaultOptionsRoute: Route = options {
    respondWithAllCorsHeaders {
      complete("OK")
    }
  }

  /** Adds the tracking-id header (value from `rest.extractTrackingId`) and all CORS headers to every response. */
  protected def defaultResponseHeaders: Directive0 = {
    extractRequest flatMap { request =>
      val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
      respondWithHeader(tracingHeader) & respondWithAllCorsHeaders
    }
  }

  /**
    * Delegates each rejection to Akka's default handler, adding CORS headers so browsers can read
    * the error response. `.seal` covers the cases this builder does not handle.
    * NOTE(review): `.get` assumes `RejectionHandler.default` yields a route for every single
    * rejection, and presumably only one rejection of a group is rendered — confirm.
    */
  protected def rejectionHandler: RejectionHandler =
    RejectionHandler
      .newBuilder()
      .handle {
        case rejection =>
          respondWithAllCorsHeaders {
            RejectionHandler.default(collection.immutable.Seq(rejection)).get
          }
      }
      .result()
      .seal

  /**
    * Maps exceptions escaping `route` to error responses. Override for custom exception handling.
    *
    * Log lines are written with the request's tracking id in the SLF4J MDC under `trackingId`.
    * NOTE(review): the generic `Exception` case returns the raw exception message to the client.
    *
    * @return exception handling route for each handled exception type
    */
  protected def exceptionHandler: PartialFunction[Throwable, Route] = {
    case serviceException: ServiceException =>
      serviceExceptionHandler(serviceException)

    case is: IllegalStateException =>
      ctx =>
        withTrackingIdInMdc(ctx) {
          log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is)
        }
        errorResponse(ctx, StatusCodes.BadRequest, message = is.getMessage, is)(ctx)

    case cm: ConcurrentModificationException =>
      ctx =>
        withTrackingIdInMdc(ctx) {
          log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm)
        }
        errorResponse(
          ctx,
          StatusCodes.Conflict,
          "Resource was changed concurrently, try requesting a newer version",
          cm)(ctx)

    case se: SQLException =>
      ctx =>
        withTrackingIdInMdc(ctx) {
          log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se)
        }
        errorResponse(ctx, StatusCodes.InternalServerError, "Data access error", se)(ctx)

    case t: Exception =>
      ctx =>
        withTrackingIdInMdc(ctx) {
          log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t)
        }
        errorResponse(ctx, StatusCodes.InternalServerError, t.getMessage, t)(ctx)
  }

  /**
    * Maps each [[ServiceException]] subtype to a status code, logs it (client errors at `info`,
    * server-side failures at `error`) with the request's tracking id in the MDC, and responds with
    * the exception's `message`.
    */
  protected def serviceExceptionHandler(serviceException: ServiceException): Route = {
    (ctx: RequestContext) =>
      val statusCode: StatusCode = withTrackingIdInMdc(ctx) {
        serviceException match {
          case e: InvalidInputException =>
            log.info("Invalid client input error", e)
            StatusCodes.BadRequest
          case e: InvalidActionException =>
            log.info("Invalid client action error", e)
            StatusCodes.Forbidden
          case e: ResourceNotFoundException =>
            log.info("Resource not found error", e)
            StatusCodes.NotFound
          case e: ExternalServiceException =>
            log.error("Error while calling another service", e)
            StatusCodes.InternalServerError
          case e: ExternalServiceTimeoutException =>
            log.error("Service timeout error", e)
            StatusCodes.GatewayTimeout
          case e: DatabaseException =>
            log.error("Database error", e)
            StatusCodes.InternalServerError
        }
      }
      errorResponse(ctx, statusCode, serviceException.message, serviceException)(ctx)
  }

  /**
    * Completes the request with `statusCode` and `message` as the response entity.
    *
    * A null `message` (e.g. from `Throwable.getMessage`) is replaced by the status code's default
    * message; building an entity from a null string would otherwise fail inside the exception handler.
    *
    * @param ctx       the failing request; unused here but available to overrides
    * @param exception the exception being reported; unused here but available to overrides
    */
  protected def errorResponse[T <: Exception](
      ctx: RequestContext,
      statusCode: StatusCode,
      message: String,
      exception: T): Route = {
    val body: String = Option(message).getOrElse(statusCode.defaultMessage)
    complete(HttpResponse(statusCode, entity = body))
  }

  /**
    * Evaluates `body` with the request's tracking id stored in the SLF4J MDC under `trackingId`,
    * then restores whatever value (or absence) the key had before. This keeps the id off the
    * pooled thread once the logging is done.
    */
  private def withTrackingIdInMdc[A](ctx: RequestContext)(body: => A): A = {
    val key      = "trackingId"
    val previous = Option(MDC.get(key))
    MDC.put(key, rest.extractTrackingId(ctx.request))
    try body
    finally previous.fold(MDC.remove(key))(MDC.put(key, _))
  }
}