diff options
Diffstat (limited to 'src')
6 files changed, 208 insertions, 4 deletions
diff --git a/src/main/scala/xyz/driver/core/app/DriverApp.scala b/src/main/scala/xyz/driver/core/app/DriverApp.scala index 751bef7..5297c90 100644 --- a/src/main/scala/xyz/driver/core/app/DriverApp.scala +++ b/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -44,6 +44,7 @@ class DriverApp(appName: String, scheme: String = "http", port: Int = 8080, tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) { + self => import DriverApp._ implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) @@ -69,12 +70,16 @@ class DriverApp(appName: String, private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = request.headers.find(_.name().toLowerCase === headerName).map(_.value()) - protected def appRoute: Route = { + def appRoute: Route = { val serviceTypes = modules.flatMap(_.routeTypes) val swaggerService = swaggerOverride(serviceTypes) val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI val versionRt = versionRoute(version, gitHash, time.currentTime()) - val combinedRoute = modules.map(_.route).foldLeft(versionRt ~ healthRoute ~ swaggerRoute)(_ ~ _) + val basicRoutes = new DriverRoute { + override def log: Logger = self.log + override def route: Route = versionRt ~ healthRoute ~ swaggerRoute + } + val combinedRoute = modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) (extractHost & extractClientIP & trace(tracer)) { case (origin, ip) => diff --git a/src/main/scala/xyz/driver/core/app/init.scala b/src/main/scala/xyz/driver/core/app/init.scala new file mode 100644 index 0000000..36eaeda --- /dev/null +++ b/src/main/scala/xyz/driver/core/app/init.scala @@ -0,0 +1,118 @@ +package xyz.driver.core.app + +import java.nio.file.{Files, Paths} +import java.util.concurrent.{Executor, Executors} + +import akka.actor.ActorSystem +import akka.stream.ActorMaterializer +import com.typesafe.config.{Config, ConfigFactory} +import com.typesafe.scalalogging.Logger +import org.slf4j.LoggerFactory +import xyz.driver.core.logging.MdcExecutionContext +import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} +import xyz.driver.tracing.{GoogleTracer, NoTracer, Tracer} + +import scala.concurrent.ExecutionContext +import scala.util.Try + +object init { + + type RequiredBuildInfo = { + val name: String + val version: String + val gitHeadCommit: scala.Option[String] + } + + case class ApplicationContext(config: Config, time: TimeProvider, log: Logger) + + /** NOTE: This needs to be the first that is run when application starts. + * Otherwise if another command causes the logger to be instantiated, + * it will default to logback.xml, and not honor this configuration + */ + def configureLogging() = { + scala.sys.env.get("JSON_LOGGING") match { + case Some("true") => + System.setProperty("logback.configurationFile", "deployed-logback.xml") + case _ => + System.setProperty("logback.configurationFile", "logback.xml") + } + } + + def getEnvironmentSpecificConfig(): Config = { + scala.sys.env.get("APPLICATION_CONFIG_TYPE") match { + case Some("deployed") => + ConfigFactory.load(this.getClass.getClassLoader, "deployed-application.conf") + case _ => + xyz.driver.core.config.loadDefaultConfig + } + } + + def configureTracer(actorSystem: ActorSystem, applicationContext: ApplicationContext): Tracer = { + + val serviceAccountKeyFile = + Paths.get(applicationContext.config.getString("tracing.google.serviceAccountKeyfile")) + + if (Files.exists(serviceAccountKeyFile)) { + val materializer = ActorMaterializer()(actorSystem) + new GoogleTracer( + projectId = applicationContext.config.getString("tracing.google.projectId"), + serviceAccountFile = serviceAccountKeyFile + )(actorSystem, materializer) + } else { + applicationContext.log.warn(s"Tracing file $serviceAccountKeyFile was not found, using NoTracer!") + NoTracer + } + } + + def serviceActorSystem(serviceName: String, executionContext: ExecutionContext, config: Config) = { + val actorSystem = + ActorSystem(s"$serviceName-actors", Option(config), Option.empty[ClassLoader], Option(executionContext)) + + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = Try(actorSystem.terminate()) + }) + + actorSystem + } + + def toMdcExecutionContext(executor: Executor) = + new MdcExecutionContext(ExecutionContext.fromExecutor(executor)) + + def newFixedMdcExecutionContext(capacity: Int): MdcExecutionContext = + toMdcExecutionContext(Executors.newFixedThreadPool(capacity)) + + def defaultApplicationContext() = { + val config = getEnvironmentSpecificConfig() + + val time = new SystemTimeProvider() + val log = Logger(LoggerFactory.getLogger(classOf[DriverApp])) + + ApplicationContext(config, time, log) + } + + def createDefaultApplication(modules: Seq[Module], + buildInfo: RequiredBuildInfo, + actorSystem: ActorSystem, + tracer: Tracer, + context: ApplicationContext) = { + val scheme = context.config.getString("application.scheme") + val baseUrl = context.config.getString("application.baseUrl") + val port = context.config.getInt("application.port") + + new DriverApp( + buildInfo.name, + buildInfo.version, + buildInfo.gitHeadCommit.getOrElse("None"), + modules = modules, + context.time, + context.log, + context.config, + interface = "::0", + baseUrl, + scheme, + port, + tracer + )(actorSystem, actorSystem.dispatcher) + } + +} diff --git a/src/main/scala/xyz/driver/core/app/module.scala b/src/main/scala/xyz/driver/core/app/module.scala index bbb29f4..7be38eb 100644 --- a/src/main/scala/xyz/driver/core/app/module.scala +++ b/src/main/scala/xyz/driver/core/app/module.scala @@ -3,7 +3,9 @@ package xyz.driver.core.app import akka.http.scaladsl.model.StatusCodes import akka.http.scaladsl.server.Directives.complete import akka.http.scaladsl.server.{Route, RouteConcatenation} +import com.typesafe.config.Config import com.typesafe.scalalogging.Logger +import xyz.driver.core.database.Database import xyz.driver.core.rest.{DriverRoute, NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} import scala.reflect.runtime.universe._ @@ -36,6 +38,22 @@ class SimpleModule(override val name: String, theRoute: Route, routeType: Type) override def routeTypes: Seq[Type] = Seq(routeType) } +trait SingleDatabaseModule { self: Module => + + val databaseName: String + val config: Config + + def database = Database.fromConfig(config, databaseName) + + override def deactivate(): Unit = { + try { + database.database.close() + } finally { + self.deactivate() + } + } +} + /** * Module implementation which may be used to compose multiple modules * diff --git a/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala b/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala index 9f8db3e..df21b48 100644 --- a/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala +++ b/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala @@ -16,8 +16,7 @@ class MdcExecutionContext(executionContext: ExecutionContext) extends ExecutionC executionContext.execute(new Runnable { def run(): Unit = { // copy caller thread diagnostic context to execution thread - // scalastyle:off - if (callerMdc != null) MDC.setContextMap(callerMdc) + Option(callerMdc).foreach(MDC.setContextMap) try { runnable.run() } finally { diff --git a/src/main/scala/xyz/driver/core/swagger.scala b/src/main/scala/xyz/driver/core/swagger.scala index a97e0ac..44ca6e1 100644 --- a/src/main/scala/xyz/driver/core/swagger.scala +++ b/src/main/scala/xyz/driver/core/swagger.scala @@ -15,6 +15,13 @@ import spray.json._ object swagger { + def configureCustomSwaggerModels(customPropertiesExamples: Map[Class[_], Property], + customObjectsExamples: Map[Class[_], JsValue]) = { + ModelConverters + .getInstance() + .addConverter(new CustomSwaggerJsonConverter(Json.mapper(), customPropertiesExamples, customObjectsExamples)) + } + object CustomSwaggerJsonConverter { def stringProperty(pattern: Option[String] = None, example: Option[String] = None): Property = { diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala new file mode 100644 index 0000000..82cc8cd --- /dev/null +++ b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -0,0 +1,57 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpMethods, StatusCodes} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server.Route +import akka.http.scaladsl.settings.RoutingSettings +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import org.scalatest.{FlatSpec, Matchers} +import xyz.driver.core.app.{DriverApp, Module} + +import scala.reflect.runtime.universe._ + +class DriverAppTest extends FlatSpec with ScalatestRouteTest with Matchers { + class TestRoute extends DriverRoute { + override def log: Logger = xyz.driver.core.logging.NoLogger + override def route: Route = path("api" / "v1" / "test")(post(complete("OK"))) + } + + val module: Module = new Module { + val testRoute = new TestRoute + override def route: Route = testRoute.routeWithDefaults + override def routeTypes: Seq[Type] = Seq(typeOf[TestRoute]) + override val name: String = "test-module" + } + + val app: DriverApp = new DriverApp( + appName = "test-app", + version = "0.1", + gitHash = "deadb33f", + modules = Seq(module) + ) + + val config: Config = xyz.driver.core.config.loadDefaultConfig + val routingSettings: RoutingSettings = RoutingSettings(config) + val appRoute: Route = Route.seal(app.appRoute)(routingSettings = routingSettings, rejectionHandler = DriverApp.rejectionHandler) + + "DriverApp" should "respond with the correct CORS headers for the swagger OPTIONS route" in { + Options(s"/api-docs/swagger.json") ~> appRoute ~> check { + status shouldBe StatusCodes.OK + info(response.toString()) + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) + headers should contain(`Access-Control-Allow-Methods`(HttpMethods.GET)) + } + } + + it should "respond with the correct CORS headers for the test route" in { + Options(s"/api/v1/test") ~> appRoute ~> check { + status shouldBe StatusCodes.OK + info(response.toString()) + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) + headers should contain(`Access-Control-Allow-Methods`(HttpMethods.GET, HttpMethods.POST)) + } + } +} |