aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzachdriver <zach@driver.xyz>2017-11-01 10:23:25 -0700
committerGitHub <noreply@github.com>2017-11-01 10:23:25 -0700
commitd480ea203d836739534fd8d27005a1e0a168c30f (patch)
treea444fbafefccf232e68e72b3f592cd3e0bca3396
parent0cb06d70bd91e1e6a4ab9d97851ef9db7aaedfd6 (diff)
parent595d199f5e41c8e48131cec98b23452bc7ed6ef1 (diff)
downloaddriver-core-d480ea203d836739534fd8d27005a1e0a168c30f.tar.gz
driver-core-d480ea203d836739534fd8d27005a1e0a168c30f.tar.bz2
driver-core-d480ea203d836739534fd8d27005a1e0a168c30f.zip
Merge pull request #73 from drivergroup/zsmith/route-traitv1.6.0v1.5.2
Add DriverRoute trait and APIError types
-rw-r--r--src/main/scala/xyz/driver/core/app/DriverApp.scala167
-rw-r--r--src/main/scala/xyz/driver/core/app/module.scala28
-rw-r--r--src/main/scala/xyz/driver/core/rest/DriverRoute.scala109
-rw-r--r--src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala24
-rw-r--r--src/main/scala/xyz/driver/core/rest/errors/serviceException.scala23
-rw-r--r--src/main/scala/xyz/driver/core/rest/package.scala4
-rw-r--r--src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala89
-rw-r--r--src/test/scala/xyz/driver/core/rest/RestTest.scala (renamed from src/test/scala/xyz/driver/core/RestTest.scala)3
8 files changed, 296 insertions, 151 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala
index 901d6e2..751bef7 100644
--- a/src/main/scala/xyz/driver/core/app/DriverApp.scala
+++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala
@@ -1,35 +1,31 @@
package xyz.driver.core.app
-import java.sql.SQLException
-
import akka.actor.ActorSystem
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
-import akka.http.scaladsl.model.StatusCodes.{BadRequest, Conflict, InternalServerError, MethodNotAllowed}
+import akka.http.scaladsl.model.StatusCodes._
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
-import akka.http.scaladsl.server.RouteResult.route2HandlerFlow
+import akka.http.scaladsl.server.RouteResult._
import akka.http.scaladsl.server._
import akka.http.scaladsl.{Http, HttpExt}
import akka.stream.ActorMaterializer
-import com.github.swagger.akka.SwaggerHttpService.{logger, toJavaTypeSet}
+import com.github.swagger.akka.SwaggerHttpService._
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import io.swagger.models.Scheme
import io.swagger.util.Json
import org.slf4j.{LoggerFactory, MDC}
import xyz.driver.core
-import xyz.driver.core.rest
-import xyz.driver.core.rest.{ContextHeaders, Swagger}
+import xyz.driver.core.rest._
import xyz.driver.core.stats.SystemStats
import xyz.driver.core.time.Time
import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider}
-import xyz.driver.tracing.TracingDirectives.trace
-import xyz.driver.tracing.{NoTracer, Tracer}
+import xyz.driver.tracing.TracingDirectives._
+import xyz.driver.tracing._
-import scala.compat.Platform.ConcurrentModificationException
import scala.concurrent.duration._
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.{Await, ExecutionContext}
import scala.reflect.runtime.universe._
import scala.util.Try
import scala.util.control.NonFatal
@@ -64,8 +60,7 @@ class DriverApp(appName: String,
def stop(): Unit = {
http.shutdownAllConnectionPools().onComplete { _ =>
Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer
- val _ = actorSystem.terminate()
- val terminated = Await.result(actorSystem.whenTerminated, 30.seconds)
+ val terminated = Await.result(actorSystem.terminate(), 30.seconds)
val addressTerminated = if (terminated.addressTerminated) "is" else "is not"
Console.print(s"${this.getClass.getName} App $addressTerminated stopped ")
}
@@ -74,69 +69,41 @@ class DriverApp(appName: String,
private def extractHeader(request: HttpRequest)(headerName: String): Option[String] =
request.headers.find(_.name().toLowerCase === headerName).map(_.value())
- protected def bindHttp(modules: Seq[Module]): Unit = {
+ protected def appRoute: Route = {
val serviceTypes = modules.flatMap(_.routeTypes)
val swaggerService = swaggerOverride(serviceTypes)
- val swaggerRoutes = swaggerService.routes ~ swaggerService.swaggerUI
+ val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI
val versionRt = versionRoute(version, gitHash, time.currentTime())
+ val combinedRoute = modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoute)(_ ~ _)
- val _ = Future {
- http.bindAndHandle(
- route2HandlerFlow(extractHost { origin =>
- trace(tracer) {
- extractClientIP { ip =>
- optionalHeaderValueByType[Origin](()) {
- originHeader =>
- {
- ctx =>
- val trackingId = rest.extractTrackingId(ctx.request)
- MDC.put("trackingId", trackingId)
-
- val updatedStacktrace =
- (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->")
- MDC.put("stack", updatedStacktrace)
-
- storeRequestContextToMdc(ctx.request, origin, ip)
-
- def requestLogging: Future[Unit] = Future {
- log.info(
- s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""")
- }
-
- val contextWithTrackingId =
- ctx.withRequest(
- ctx.request
- .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId))
- .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace)))
-
- handleExceptions(ExceptionHandler(exceptionHandler))({
- c =>
- requestLogging.flatMap { _ =>
- val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId)
-
- val responseHeaders = List[HttpHeader](
- trackingHeader,
- allowOrigin(originHeader),
- `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*),
- `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*)
- )
-
- respondWithHeaders(responseHeaders) {
- modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _)
- }(c)
- }
- })(contextWithTrackingId)
- }
- }
- }
- }
- }),
- interface,
- port
- )(materializer)
+ (extractHost & extractClientIP & trace(tracer)) {
+ case (origin, ip) =>
+ ctx =>
+ val trackingId = extractTrackingId(ctx.request)
+ MDC.put("trackingId", trackingId)
+
+ val updatedStacktrace =
+ (extractStacktrace(ctx.request) ++ Array(appName)).mkString("->")
+ MDC.put("stack", updatedStacktrace)
+
+ storeRequestContextToMdc(ctx.request, origin, ip)
+
+ log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""")
+
+ val contextWithTrackingId =
+ ctx.withRequest(
+ ctx.request
+ .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId))
+ .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace)))
+
+ combinedRoute(contextWithTrackingId)
}
}
+ protected def bindHttp(modules: Seq[Module]): Unit = {
+ val _ = http.bindAndHandle(route2HandlerFlow(appRoute), interface, port)(materializer)
+ }
+
private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = {
MDC.put("origin", origin)
@@ -181,58 +148,6 @@ class DriverApp(appName: String,
}
}
- /**
- * Override me for custom exception handling
- *
- * @return Exception handling route for exception type
- */
- protected def exceptionHandler: PartialFunction[Throwable, Route] = {
-
- case is: IllegalStateException =>
- ctx =>
- log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is)
- errorResponse(ctx, BadRequest, message = is.getMessage, is)(ctx)
-
- case cm: ConcurrentModificationException =>
- ctx =>
- log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm)
- errorResponse(ctx, Conflict, "Resource was changed concurrently, try requesting a newer version", cm)(ctx)
-
- case se: SQLException =>
- ctx =>
- log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se)
- errorResponse(ctx, InternalServerError, "Data access error", se)(ctx)
-
- case t: Throwable =>
- ctx =>
- log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t)
- errorResponse(ctx, InternalServerError, t.getMessage, t)(ctx)
- }
-
- protected def errorResponse[T <: Throwable](ctx: RequestContext,
- statusCode: StatusCode,
- message: String,
- exception: T): Route = {
-
- val trackingId = rest.extractTrackingId(ctx.request)
- val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(ctx.request))
-
- MDC.put("trackingId", trackingId)
-
- optionalHeaderValueByType[Origin](()) { originHeader =>
- val responseHeaders = List[HttpHeader](
- tracingHeader,
- allowOrigin(originHeader),
- `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*),
- `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*)
- )
-
- respondWithHeaders(responseHeaders) {
- complete(HttpResponse(statusCode, entity = message))
- }
- }
- }
-
protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = {
import spray.json._
import DefaultJsonProtocol._
@@ -334,11 +249,6 @@ class DriverApp(appName: String,
}
object DriverApp {
-
- private def allowOrigin(originHeader: Option[Origin]) =
- `Access-Control-Allow-Origin`(
- originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))
-
implicit def rejectionHandler: RejectionHandler =
RejectionHandler
.newBuilder()
@@ -352,8 +262,8 @@ object DriverApp {
Allow(methods),
`Access-Control-Allow-Methods`(methods),
allowOrigin(originHeader),
- `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*),
- `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*)
+ `Access-Control-Allow-Headers`(AllowedHeaders: _*),
+ `Access-Control-Expose-Headers`(AllowedHeaders: _*)
)) {
complete(s"Supported methods: $names.")
}
@@ -362,5 +272,4 @@ object DriverApp {
complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!")
}
.result()
-
}
diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala
index c6f979f..bbb29f4 100644
--- a/src/main/scala/xyz/driver/core/app/module.scala
+++ b/src/main/scala/xyz/driver/core/app/module.scala
@@ -3,7 +3,8 @@ package xyz.driver.core.app
import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.server.Directives.complete
import akka.http.scaladsl.server.{Route, RouteConcatenation}
-import xyz.driver.core.rest.{NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery}
+import com.typesafe.scalalogging.Logger
+import xyz.driver.core.rest.{DriverRoute, NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery}
import scala.reflect.runtime.universe._
@@ -21,17 +22,22 @@ trait Module {
class EmptyModule extends Module {
override val name: String = "Nothing"
- override def route: Route = complete(StatusCodes.OK)
-
+ override def route: Route = complete(StatusCodes.OK)
override def routeTypes: Seq[Type] = Seq.empty[Type]
}
-class SimpleModule(override val name: String, override val route: Route, routeType: Type) extends Module {
- def routeTypes: Seq[Type] = Seq(routeType)
+class SimpleModule(override val name: String, theRoute: Route, routeType: Type) extends Module {
+ private val driverRoute: DriverRoute = new DriverRoute {
+ override def route: Route = theRoute
+ override val log: Logger = xyz.driver.core.logging.NoLogger
+ }
+
+ override def route: Route = driverRoute.routeWithDefaults
+ override def routeTypes: Seq[Type] = Seq(routeType)
}
/**
- * Module implementation which may be used to composed a few
+ * Module implementation which may be used to compose multiple modules
*
* @param name more general name of the composite module,
* must be provided as there is no good way to automatically
@@ -39,12 +45,8 @@ class SimpleModule(override val name: String, override val route: Route, routeTy
* @param modules modules to compose into a single one
*/
class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation {
-
- override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*)
-
+ override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*)
override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes)
-
- override def activate(): Unit = modules.foreach(_.activate())
-
- override def deactivate(): Unit = modules.reverse.foreach(_.deactivate())
+ override def activate(): Unit = modules.foreach(_.activate())
+ override def deactivate(): Unit = modules.reverse.foreach(_.deactivate())
}
diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
new file mode 100644
index 0000000..eb9a31a
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala
@@ -0,0 +1,109 @@
+package xyz.driver.core.rest
+
+import java.sql.SQLException
+
+import akka.http.scaladsl.model._
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.model.headers._
+import akka.http.scaladsl.server.Directives._
+import akka.http.scaladsl.server.{Directive0, ExceptionHandler, RequestContext, Route}
+import com.typesafe.scalalogging.Logger
+import org.slf4j.MDC
+import xyz.driver.core.rest
+import xyz.driver.core.rest.errors._
+
+import scala.compat.Platform.ConcurrentModificationException
+
+trait DriverRoute {
+ def log: Logger
+
+ def route: Route
+
+ def routeWithDefaults: Route = {
+ (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler)))(route)
+ }
+
+ protected def defaultResponseHeaders: Directive0 = {
+ (extractRequest & optionalHeaderValueByType[Origin](())) tflatMap {
+ case (request, originHeader) =>
+ val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, rest.extractTrackingId(request))
+ val responseHeaders = List[HttpHeader](
+ tracingHeader,
+ allowOrigin(originHeader),
+ `Access-Control-Allow-Headers`(AllowedHeaders: _*),
+ `Access-Control-Expose-Headers`(AllowedHeaders: _*)
+ )
+
+ respondWithHeaders(responseHeaders)
+ }
+ }
+
+ /**
+ * Override me for custom exception handling
+ *
+ * @return Exception handling route for exception type
+ */
+ protected def exceptionHandler: PartialFunction[Throwable, Route] = {
+ case serviceException: ServiceException =>
+ serviceExceptionHandler(serviceException)
+
+ case is: IllegalStateException =>
+ ctx =>
+ log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is)
+ errorResponse(ctx, StatusCodes.BadRequest, message = is.getMessage, is)(ctx)
+
+ case cm: ConcurrentModificationException =>
+ ctx =>
+ log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm)
+ errorResponse(ctx,
+ StatusCodes.Conflict,
+ "Resource was changed concurrently, try requesting a newer version",
+ cm)(ctx)
+
+ case se: SQLException =>
+ ctx =>
+ log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se)
+ errorResponse(ctx, StatusCodes.InternalServerError, "Data access error", se)(ctx)
+
+ case t: Exception =>
+ ctx =>
+ log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t)
+ errorResponse(ctx, StatusCodes.InternalServerError, t.getMessage, t)(ctx)
+ }
+
+ protected def serviceExceptionHandler(serviceException: ServiceException): Route = {
+ val statusCode = serviceException match {
+ case e: InvalidInputException =>
+ log.info("Invalid client input error", e)
+ StatusCodes.BadRequest
+ case e: InvalidActionException =>
+ log.info("Invalid client action error", e)
+ StatusCodes.Forbidden
+ case e: ResourceNotFoundException =>
+ log.info("Resource not found error", e)
+ StatusCodes.NotFound
+ case e: ExternalServiceException =>
+ log.error("Error while calling another service", e)
+ StatusCodes.InternalServerError
+ case e: ExternalServiceTimeoutException =>
+ log.error("Service timeout error", e)
+ StatusCodes.GatewayTimeout
+ case e: DatabaseException =>
+ log.error("Database error", e)
+ StatusCodes.InternalServerError
+ }
+
+ { (ctx: RequestContext) =>
+ errorResponse(ctx, statusCode, serviceException.message, serviceException)(ctx)
+ }
+ }
+
+ protected def errorResponse[T <: Exception](ctx: RequestContext,
+ statusCode: StatusCode,
+ message: String,
+ exception: T): Route = {
+ val trackingId = rest.extractTrackingId(ctx.request)
+ MDC.put("trackingId", trackingId)
+ complete(HttpResponse(statusCode, entity = message))
+ }
+}
diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala
index 1e95811..376b154 100644
--- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala
+++ b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala
@@ -4,9 +4,12 @@ import akka.actor.ActorSystem
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.unmarshalling.Unmarshal
+import akka.stream.Materializer
+import akka.stream.scaladsl.TcpIdleTimeoutException
import com.typesafe.scalalogging.Logger
import org.slf4j.MDC
import xyz.driver.core.Name
+import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException}
import xyz.driver.core.time.provider.TimeProvider
import scala.concurrent.{ExecutionContext, Future}
@@ -55,18 +58,27 @@ class HttpRestServiceTransport(applicationName: Name[App],
log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t)
}(executionContext)
- response
+ response.recoverWith {
+ case _: TcpIdleTimeoutException =>
+ val serviceCalled = s"${requestStub.method} ${requestStub.uri}"
+ Future.failed(ExternalServiceTimeoutException(serviceCalled))
+ case t: Throwable => Future.failed(t)
+ }
}
- def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]] = {
+ def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)(
+ implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = {
- sendRequestGetResponse(context)(requestStub) map { response =>
+ sendRequestGetResponse(context)(requestStub) flatMap { response =>
if (response.status == StatusCodes.NotFound) {
- Unmarshal(HttpEntity.Empty: ResponseEntity)
+ Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity))
} else if (response.status.isFailure()) {
- throw new Exception(s"Http status is failure ${response.status} for ${requestStub.method} ${requestStub.uri}")
+ val serviceCalled = s"${requestStub.method} ${requestStub.uri}"
+ Unmarshal(response.entity).to[String] flatMap { error =>
+ Future.failed(ExternalServiceException(serviceCalled, error))
+ }
} else {
- Unmarshal(response.entity)
+ Future.successful(Unmarshal(response.entity))
}
}
}
diff --git a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala
new file mode 100644
index 0000000..e91a3c2
--- /dev/null
+++ b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala
@@ -0,0 +1,23 @@
+package xyz.driver.core.rest.errors
+
+sealed abstract class ServiceException extends Exception {
+ def message: String
+}
+
+final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException
+
+final case class InvalidActionException(override val message: String = "This action is not allowed")
+ extends ServiceException
+
+final case class ResourceNotFoundException(override val message: String = "Resource not found")
+ extends ServiceException
+
+final case class ExternalServiceException(serviceName: String, serviceMessage: String) extends ServiceException {
+ override def message = s"Error while calling '$serviceName': $serviceMessage"
+}
+
+final case class ExternalServiceTimeoutException(serviceName: String) extends ServiceException {
+ override def message = s"$serviceName took too long to respond"
+}
+
+final case class DatabaseException(override val message: String = "Database access error") extends ServiceException
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala
index 942ca3a..531cd8a 100644
--- a/src/main/scala/xyz/driver/core/rest/package.scala
+++ b/src/main/scala/xyz/driver/core/rest/package.scala
@@ -6,6 +6,7 @@ import akka.http.scaladsl.model.{HttpRequest, HttpResponse, ResponseEntity, Stat
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
+import akka.stream.Materializer
import akka.stream.scaladsl.Flow
import akka.util.ByteString
import xyz.driver.tracing.TracingDirectives
@@ -25,7 +26,8 @@ trait ServiceTransport {
def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse]
- def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest): Future[Unmarshal[ResponseEntity]]
+ def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)(
+ implicit mat: Materializer): Future[Unmarshal[ResponseEntity]]
}
final case class Pagination(pageSize: Int, pageNumber: Int)
diff --git a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
new file mode 100644
index 0000000..c239fb6
--- /dev/null
+++ b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala
@@ -0,0 +1,89 @@
+package xyz.driver.core.rest
+
+import akka.http.scaladsl.model.StatusCodes
+import akka.http.scaladsl.server.Directives.{complete => akkaComplete}
+import akka.http.scaladsl.server.Route
+import akka.http.scaladsl.testkit.ScalatestRouteTest
+import com.typesafe.scalalogging.Logger
+import org.scalatest.{AsyncFlatSpec, Matchers}
+import xyz.driver.core.logging.NoLogger
+import xyz.driver.core.rest.errors._
+
+import scala.concurrent.Future
+
+class DriverRouteTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers {
+ class TestRoute(override val route: Route) extends DriverRoute {
+ override def log: Logger = NoLogger
+ }
+
+ "DriverRoute" should "respond with 200 OK for a basic route" in {
+ val route = new TestRoute(akkaComplete(StatusCodes.OK))
+
+ Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.OK
+ }
+ }
+
+ it should "respond with a 400 for an InvalidInputException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.BadRequest
+ responseAs[String] shouldBe "Invalid input"
+ }
+ }
+
+ it should "respond with a 400 for InvalidActionException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.Forbidden
+ responseAs[String] shouldBe "This action is not allowed"
+ }
+ }
+
+ it should "respond with a 404 for ResourceNotFoundException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.NotFound
+ responseAs[String] shouldBe "Resource not found"
+ }
+ }
+
+ it should "respond with a 500 for ExternalServiceException" in {
+ val error = ExternalServiceException("GET /api/v1/users/", "Permission denied")
+ val route = new TestRoute(akkaComplete(Future.failed[String](error)))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.InternalServerError
+ responseAs[String] shouldBe "Error while calling 'GET /api/v1/users/': Permission denied"
+ }
+ }
+
+ it should "respond with a 503 for ExternalServiceTimeoutException" in {
+ val error = ExternalServiceTimeoutException("GET /api/v1/users/")
+ val route = new TestRoute(akkaComplete(Future.failed[String](error)))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.GatewayTimeout
+ responseAs[String] shouldBe "GET /api/v1/users/ took too long to respond"
+ }
+ }
+
+ it should "respond with a 500 for DatabaseException" in {
+ val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException())))
+
+ Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check {
+ handled shouldBe true
+ status shouldBe StatusCodes.InternalServerError
+ responseAs[String] shouldBe "Database access error"
+ }
+ }
+}
diff --git a/src/test/scala/xyz/driver/core/RestTest.scala b/src/test/scala/xyz/driver/core/rest/RestTest.scala
index efb9d07..2c3fb7f 100644
--- a/src/test/scala/xyz/driver/core/RestTest.scala
+++ b/src/test/scala/xyz/driver/core/rest/RestTest.scala
@@ -1,8 +1,7 @@
package xyz.driver.core.rest
-import org.scalatest.{FlatSpec, Matchers}
-
import akka.util.ByteString
+import org.scalatest.{FlatSpec, Matchers}
class RestTest extends FlatSpec with Matchers {
"`escapeScriptTags` function" should "escap script tags properly" in {