aboutsummaryrefslogblamecommitdiff
path: root/src/main/scala/xyz/driver/core/app/DriverApp.scala
blob: cb3a38ed99bee4cc4ce5e828bd8004523d45f1b5 (plain) (tree)
1
2
3
4
5

                           

                                                                
                                                            





















                                                                         
                                  
                                                 

















                                                                                                                   
                    
























                                                                                                         
                                   




                                                                           



































                                                                                                                  
           

         


     



                                                                                          











































                                                                                                         


































































































                                                                                                           































                                                                                                 
package xyz.driver.core.app

import akka.actor.ActorSystem
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport
import akka.http.scaladsl.model.StatusCodes.MethodNotAllowed
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.RouteResult.route2HandlerFlow
import akka.http.scaladsl.server._
import akka.http.scaladsl.{Http, HttpExt}
import akka.stream.ActorMaterializer
import com.github.swagger.akka.SwaggerHttpService.{logger, toJavaTypeSet}
import com.typesafe.config.Config
import com.typesafe.scalalogging.Logger
import io.swagger.models.Scheme
import io.swagger.util.Json
import org.slf4j.{LoggerFactory, MDC}
import xyz.driver.core
import xyz.driver.core.rest
import xyz.driver.core.rest.{ContextHeaders, Swagger}
import xyz.driver.core.stats.SystemStats
import xyz.driver.core.time.Time
import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider}
import xyz.driver.tracing.TracingDirectives.trace
import xyz.driver.tracing.{NoTracer, Tracer}

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext}
import scala.reflect.runtime.universe._
import scala.util.Try
import scala.util.control.NonFatal
import scalaz.Scalaz.stringInstance
import scalaz.syntax.equal._

class DriverApp(appName: String,
                version: String,
                gitHash: String,
                modules: Seq[Module],
                time: TimeProvider = new SystemTimeProvider(),
                log: Logger = Logger(LoggerFactory.getLogger(classOf[DriverApp])),
                config: Config = core.config.loadDefaultConfig,
                interface: String = "::0",
                baseUrl: String = "localhost:8080",
                scheme: String = "http",
                port: Int = 8080,
                tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) {
  import DriverApp._

  implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem)
  private lazy val http: HttpExt                            = Http()(actorSystem)
  val appEnvironment: String                                = config.getString("application.environment")

  def run(): Unit = {
    activateServices(modules)
    scheduleServicesDeactivation(modules)
    bindHttp(modules)
    Console.print(s"${this.getClass.getName} App is started\n")
  }

  def stop(): Unit = {
    http.shutdownAllConnectionPools().onComplete { _ =>
      Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer
      val _                 = actorSystem.terminate()
      val terminated        = Await.result(actorSystem.whenTerminated, 30.seconds)
      val addressTerminated = if (terminated.addressTerminated) "is" else "is not"
      Console.print(s"${this.getClass.getName} App $addressTerminated stopped ")
    }
  }

  private def extractHeader(request: HttpRequest)(headerName: String): Option[String] =
    request.headers.find(_.name().toLowerCase === headerName).map(_.value())

  protected def appRoute: Route = {
    val serviceTypes   = modules.flatMap(_.routeTypes)
    val swaggerService = swaggerOverride(serviceTypes)
    val swaggerRoutes  = swaggerService.routes ~ swaggerService.swaggerUI
    val versionRt      = versionRoute(version, gitHash, time.currentTime())

    extractHost { origin =>
      extractClientIP { ip =>
        optionalHeaderValueByType[Origin](()) { originHeader =>
          trace(tracer) { ctx =>
            val trackingId = rest.extractTrackingId(ctx.request)
            MDC.put("trackingId", trackingId)

            val updatedStacktrace =
              (rest.extractStacktrace(ctx.request) ++ Array(appName)).mkString("->")
            MDC.put("stack", updatedStacktrace)

            storeRequestContextToMdc(ctx.request, origin, ip)

            val trackingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId)

            val responseHeaders = List[HttpHeader](
              trackingHeader,
              allowOrigin(originHeader),
              `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*),
              `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*)
            )

            log.info(s"""Received request {"method":"${ctx.request.method.value}","url": "${ctx.request.uri}"}""")

            val contextWithTrackingId =
              ctx.withRequest(
                ctx.request
                  .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId))
                  .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace)))

            respondWithHeaders(responseHeaders) {
              modules
                .flatMap(_.routes)
                .map(_.routeWithDefaults)
                .foldLeft(versionRt ~ healthRoute ~ swaggerRoutes)(_ ~ _)
            }(contextWithTrackingId)
          }
        }
      }
    }
  }

  protected def bindHttp(modules: Seq[Module]): Unit = {
    val _ = http.bindAndHandle(route2HandlerFlow(appRoute), interface, port)(materializer)
  }

  private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = {

    MDC.put("origin", origin)
    MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown"))
    MDC.put("remoteHost", ip.toOption.map(_.getHostName).getOrElse("unknown"))

    MDC.put("xForwardedFor",
            extractHeader(request)("x-forwarded-for")
              .orElse(extractHeader(request)("x_forwarded_for"))
              .getOrElse("unknown"))
    MDC.put("remoteAddress", extractHeader(request)("remote-address").getOrElse("unknown"))
    MDC.put("userAgent", extractHeader(request)("user-agent").getOrElse("unknown"))
  }

  protected def swaggerOverride(apiTypes: Seq[Type]): Swagger = {
    new Swagger(baseUrl, Scheme.forValue(scheme), version, actorSystem, apiTypes, config) {
      override def generateSwaggerJson: String = {
        import io.swagger.models.Swagger

        import scala.collection.JavaConverters._

        try {
          val swagger: Swagger = reader.read(toJavaTypeSet(apiTypes).asJava)

          // Removing trailing spaces
          swagger.setPaths(
            swagger.getPaths.asScala
              .map {
                case (key, path) =>
                  key.trim -> path
              }
              .toMap
              .asJava)

          Json.pretty().writeValueAsString(swagger)
        } catch {
          case NonFatal(t) =>
            logger.error("Issue with creating swagger.json", t)
            throw t
        }
      }
    }
  }

  protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = {
    import spray.json._
    import DefaultJsonProtocol._
    import SprayJsonSupport._

    path("version") {
      val currentTime = time.currentTime().millis
      complete(
        Map(
          "version"      -> version.toJson,
          "gitHash"      -> gitHash.toJson,
          "modules"      -> modules.map(_.name).toJson,
          "dependencies" -> collectAppDependencies().toJson,
          "startupTime"  -> startupTime.millis.toString.toJson,
          "serverTime"   -> currentTime.toString.toJson,
          "uptime"       -> (currentTime - startupTime.millis).toString.toJson
        ).toJson)
    }
  }

  protected def collectAppDependencies(): Map[String, String] = {

    def serviceWithLocation(serviceName: String): (String, String) =
      serviceName -> Try(config.getString(s"services.$serviceName.baseUrl")).getOrElse("not-detected")

    modules.flatMap(module => module.serviceDiscovery.getUsedServices.map(serviceWithLocation).toSeq).toMap
  }

  protected def healthRoute: Route = {
    import spray.json._
    import DefaultJsonProtocol._
    import SprayJsonSupport._
    import spray.json._

    val memoryUsage = SystemStats.memoryUsage
    val gcStats     = SystemStats.garbageCollectorStats

    path("health") {
      complete(
        Map(
          "availableProcessors" -> SystemStats.availableProcessors.toJson,
          "memoryUsage" -> Map(
            "free"  -> memoryUsage.free.toJson,
            "total" -> memoryUsage.total.toJson,
            "max"   -> memoryUsage.max.toJson
          ).toJson,
          "gcStats" -> Map(
            "garbageCollectionTime"   -> gcStats.garbageCollectionTime.toJson,
            "totalGarbageCollections" -> gcStats.totalGarbageCollections.toJson
          ).toJson,
          "fileSystemSpace" -> SystemStats.fileSystemSpace.map { f =>
            Map("path"        -> f.path.toJson,
                "freeSpace"   -> f.freeSpace.toJson,
                "totalSpace"  -> f.totalSpace.toJson,
                "usableSpace" -> f.usableSpace.toJson)
          }.toJson,
          "operatingSystem" -> SystemStats.operatingSystemStats.toJson
        ))
    }
  }

  /**
    * Initializes services
    */
  protected def activateServices(services: Seq[Module]): Unit = {
    services.foreach { service =>
      Console.print(s"Service ${service.name} starts ...")
      try {
        service.activate()
      } catch {
        case t: Throwable =>
          log.error(s"Service ${service.name} failed to activate", t)
          Console.print(" Failed! (check log)")
      }
      Console.print(" Done\n")
    }
  }

  /**
    * Schedules services to be deactivated on the app shutdown
    */
  protected def scheduleServicesDeactivation(services: Seq[Module]): Unit = {
    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        services.foreach { service =>
          Console.print(s"Service ${service.name} shutting down ...\n")
          try {
            service.deactivate()
          } catch {
            case t: Throwable =>
              log.error(s"Service ${service.name} failed to deactivate", t)
              Console.print(" Failed! (check log)")
          }
          Console.print(s"Service ${service.name} is shut down\n")
        }
      }
    })
  }
}

object DriverApp {

  private def allowOrigin(originHeader: Option[Origin]) =
    `Access-Control-Allow-Origin`(
      originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))

  implicit def rejectionHandler: RejectionHandler =
    RejectionHandler
      .newBuilder()
      .handleAll[MethodRejection] { rejections =>
        val methods    = rejections map (_.supported)
        lazy val names = methods map (_.name) mkString ", "

        options { ctx =>
          optionalHeaderValueByType[Origin](()) { originHeader =>
            respondWithHeaders(List[HttpHeader](
              Allow(methods),
              `Access-Control-Allow-Methods`(methods),
              allowOrigin(originHeader),
              `Access-Control-Allow-Headers`(rest.AllowedHeaders: _*),
              `Access-Control-Expose-Headers`(rest.AllowedHeaders: _*)
            )) {
              complete(s"Supported methods: $names.")
            }
          }(ctx)
        } ~
          complete(MethodNotAllowed -> s"HTTP method not allowed, supported methods: $names!")
      }
      .result()

}