diff options
Diffstat (limited to 'jvm/src/main/scala/xyz/driver/core')
52 files changed, 4568 insertions, 0 deletions
diff --git a/jvm/src/main/scala/xyz/driver/core/app/DriverApp.scala b/jvm/src/main/scala/xyz/driver/core/app/DriverApp.scala new file mode 100644 index 0000000..6dd98e3 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/app/DriverApp.scala @@ -0,0 +1,294 @@ +package xyz.driver.core.app + +import akka.actor.ActorSystem +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server.RouteResult._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.{Http, HttpExt} +import akka.stream.ActorMaterializer +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import io.swagger.models.Scheme +import org.slf4j.{LoggerFactory, MDC} +import xyz.driver.core +import xyz.driver.core.rest._ +import xyz.driver.core.stats.SystemStats +import xyz.driver.core.time.Time +import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} +import xyz.driver.tracing.TracingDirectives._ +import xyz.driver.tracing._ + +import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext} +import scala.util.Try +import scalaz.Scalaz.stringInstance +import scalaz.syntax.equal._ + +class DriverApp( + appName: String, + version: String, + gitHash: String, + modules: Seq[Module], + time: TimeProvider = new SystemTimeProvider(), + log: Logger = Logger(LoggerFactory.getLogger(classOf[DriverApp])), + config: Config = core.config.loadDefaultConfig, + interface: String = "::0", + baseUrl: String = "localhost:8080", + scheme: String = "http", + port: Int = 8080, + tracer: Tracer = NoTracer)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) { + self => + + implicit private lazy val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private lazy val http: HttpExt = Http()(actorSystem) + val appEnvironment: String = config.getString("application.environment") + + def run(): Unit = { + activateServices(modules) + scheduleServicesDeactivation(modules) + bindHttp(modules) + Console.print(s"${this.getClass.getName} App is started\n") + } + + def stop(): Unit = { + http.shutdownAllConnectionPools().onComplete { _ => + Await.result(tracer.close(), 15.seconds) // flush out any remaining traces from the buffer + val terminated = Await.result(actorSystem.terminate(), 30.seconds) + val addressTerminated = if (terminated.addressTerminated) "is" else "is not" + Console.print(s"${this.getClass.getName} App $addressTerminated stopped ") + } + } + + protected lazy val allowedCorsDomainSuffixes: Set[HttpOrigin] = { + import scala.collection.JavaConverters._ + config + .getConfigList("application.cors.allowedOrigins") + .asScala + .map { c => + HttpOrigin(c.getString("scheme"), Host(c.getString("hostSuffix"))) + }(scala.collection.breakOut) + } + + protected lazy val defaultCorsAllowedMethods: Set[HttpMethod] = { + import scala.collection.JavaConverters._ + config.getStringList("application.cors.allowedMethods").asScala.toSet.flatMap(HttpMethods.getForKey) + } + + protected lazy val defaultCorsAllowedOrigin: Origin = { + Origin(allowedCorsDomainSuffixes.to[collection.immutable.Seq]) + } + + protected def corsAllowedOriginHeader(origin: Option[Origin]): HttpHeader = { + val allowedOrigin = + origin + .filter { requestOrigin => + allowedCorsDomainSuffixes.exists { allowedOriginSuffix => + requestOrigin.origins.exists(o => + o.scheme == allowedOriginSuffix.scheme && + o.host.host.address.endsWith(allowedOriginSuffix.host.host.address())) + } + } + .getOrElse(defaultCorsAllowedOrigin) + + `Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigin.origins: _*)) + } + + protected def respondWithAllCorsHeaders: Directive0 = { + respondWithCorsAllowedHeaders tflatMap { _ => + respondWithCorsAllowedMethodHeaders(defaultCorsAllowedMethods) tflatMap { _ => + optionalHeaderValueByType[Origin](()) flatMap { origin => + respondWithHeader(corsAllowedOriginHeader(origin)) + } + } + } + } + + private def extractHeader(request: HttpRequest)(headerName: String): Option[String] = + request.headers.find(_.name().toLowerCase === headerName).map(_.value()) + + protected def defaultOptionsRoute: Route = options { + respondWithAllCorsHeaders { + complete("OK") + } + } + + def appRoute: Route = { + val serviceTypes = modules.flatMap(_.routeTypes) + val swaggerService = new Swagger(baseUrl, Scheme.forValue(scheme) :: Nil, version, serviceTypes, config, log) + val swaggerRoute = swaggerService.routes ~ swaggerService.swaggerUI + val versionRt = versionRoute(version, gitHash, time.currentTime()) + val basicRoutes = new DriverRoute { + override def log: Logger = self.log + override def route: Route = versionRt ~ healthRoute ~ swaggerRoute + } + val combinedRoute = + Route.seal(modules.map(_.route).foldLeft(basicRoutes.routeWithDefaults)(_ ~ _) ~ defaultOptionsRoute) + + (extractHost & extractClientIP & trace(tracer) & handleRejections(authenticationRejectionHandler)) { + case (origin, ip) => + ctx => + val trackingId = extractTrackingId(ctx.request) + MDC.put("trackingId", trackingId) + + val updatedStacktrace = + (extractStacktrace(ctx.request) ++ Array(appName)).mkString("->") + MDC.put("stack", updatedStacktrace) + + storeRequestContextToMdc(ctx.request, origin, ip) + + log.info(s"""Received request ${ctx.request.method.value} ${ctx.request.uri} (trace: $trackingId)""") + + val contextWithTrackingId = + ctx.withRequest( + ctx.request + .addHeader(RawHeader(ContextHeaders.TrackingIdHeader, trackingId)) + .addHeader(RawHeader(ContextHeaders.StacktraceHeader, updatedStacktrace))) + + val logResponses = mapRouteResult { + case c @ Complete(response) => + log.info( + s"Responded to ${ctx.request.method.value} ${ctx.request.uri} " + + s"with ${response.status.toString} (trace: $trackingId)") + c + case r @ Rejected(rejections) => + log.warn( + s"Request ${ctx.request.method.value} ${ctx.request.uri} " + + s"(trace: $trackingId) is rejected:\n${rejections.mkString(",\n")}") + r + } + + respondWithAllCorsHeaders(logResponses(combinedRoute))(contextWithTrackingId) + } + } + + protected def authenticationRejectionHandler: RejectionHandler = + RejectionHandler + .newBuilder() + .handle { + case AuthenticationFailedRejection(_, challenge) => + complete(HttpResponse(StatusCodes.Unauthorized, entity = challenge.realm)) + } + .result() + + protected def bindHttp(modules: Seq[Module]): Unit = { + val _ = http.bindAndHandle(route2HandlerFlow(appRoute), interface, port)(materializer) + } + + private def storeRequestContextToMdc(request: HttpRequest, origin: String, ip: RemoteAddress): Unit = { + + MDC.put("origin", origin) + MDC.put("ip", ip.toOption.map(_.getHostAddress).getOrElse("unknown")) + MDC.put("remoteHost", ip.toOption.map(_.getHostName).getOrElse("unknown")) + + MDC.put( + "xForwardedFor", + extractHeader(request)("x-forwarded-for") + .orElse(extractHeader(request)("x_forwarded_for")) + .getOrElse("unknown")) + MDC.put("remoteAddress", extractHeader(request)("remote-address").getOrElse("unknown")) + MDC.put("userAgent", extractHeader(request)("user-agent").getOrElse("unknown")) + } + + protected def versionRoute(version: String, gitHash: String, startupTime: Time): Route = { + import spray.json._ + import DefaultJsonProtocol._ + import SprayJsonSupport._ + + path("version") { + val currentTime = time.currentTime().millis + complete( + Map( + "version" -> version.toJson, + "gitHash" -> gitHash.toJson, + "modules" -> modules.map(_.name).toJson, + "dependencies" -> collectAppDependencies().toJson, + "startupTime" -> startupTime.millis.toString.toJson, + "serverTime" -> currentTime.toString.toJson, + "uptime" -> (currentTime - startupTime.millis).toString.toJson + ).toJson) + } + } + + protected def collectAppDependencies(): Map[String, String] = { + + def serviceWithLocation(serviceName: String): (String, String) = + serviceName -> Try(config.getString(s"services.$serviceName.baseUrl")).getOrElse("not-detected") + + modules.flatMap(module => module.serviceDiscovery.getUsedServices.map(serviceWithLocation).toSeq).toMap + } + + protected def healthRoute: Route = { + import spray.json._ + import DefaultJsonProtocol._ + import SprayJsonSupport._ + import spray.json._ + + val memoryUsage = SystemStats.memoryUsage + val gcStats = SystemStats.garbageCollectorStats + + path("health") { + complete( + Map( + "availableProcessors" -> SystemStats.availableProcessors.toJson, + "memoryUsage" -> Map( + "free" -> memoryUsage.free.toJson, + "total" -> memoryUsage.total.toJson, + "max" -> memoryUsage.max.toJson + ).toJson, + "gcStats" -> Map( + "garbageCollectionTime" -> gcStats.garbageCollectionTime.toJson, + "totalGarbageCollections" -> gcStats.totalGarbageCollections.toJson + ).toJson, + "fileSystemSpace" -> SystemStats.fileSystemSpace.map { f => + Map( + "path" -> f.path.toJson, + "freeSpace" -> f.freeSpace.toJson, + "totalSpace" -> f.totalSpace.toJson, + "usableSpace" -> f.usableSpace.toJson) + }.toJson, + "operatingSystem" -> SystemStats.operatingSystemStats.toJson + )) + } + } + + /** + * Initializes services + */ + protected def activateServices(services: Seq[Module]): Unit = { + services.foreach { service => + Console.print(s"Service ${service.name} starts ...") + try { + service.activate() + } catch { + case t: Throwable => + log.error(s"Service ${service.name} failed to activate", t) + Console.print(" Failed! (check log)") + } + Console.print(" Done\n") + } + } + + /** + * Schedules services to be deactivated on the app shutdown + */ + protected def scheduleServicesDeactivation(services: Seq[Module]): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = { + services.foreach { service => + Console.print(s"Service ${service.name} shutting down ...\n") + try { + service.deactivate() + } catch { + case t: Throwable => + log.error(s"Service ${service.name} failed to deactivate", t) + Console.print(" Failed! (check log)") + } + Console.print(s"Service ${service.name} is shut down\n") + } + } + }) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/app/init.scala b/jvm/src/main/scala/xyz/driver/core/app/init.scala new file mode 100644 index 0000000..119c91a --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/app/init.scala @@ -0,0 +1,119 @@ +package xyz.driver.core.app + +import java.nio.file.{Files, Paths} +import java.util.concurrent.{Executor, Executors} + +import akka.actor.ActorSystem +import akka.stream.ActorMaterializer +import com.typesafe.config.{Config, ConfigFactory} +import com.typesafe.scalalogging.Logger +import org.slf4j.LoggerFactory +import xyz.driver.core.logging.MdcExecutionContext +import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} +import xyz.driver.tracing.{GoogleTracer, NoTracer, Tracer} + +import scala.concurrent.ExecutionContext +import scala.util.Try + +object init { + + type RequiredBuildInfo = { + val name: String + val version: String + val gitHeadCommit: scala.Option[String] + } + + case class ApplicationContext(config: Config, time: TimeProvider, log: Logger) + + /** NOTE: This needs to be the first that is run when application starts. + * Otherwise if another command causes the logger to be instantiated, + * it will default to logback.xml, and not honor this configuration + */ + def configureLogging(): Unit = { + scala.sys.env.get("JSON_LOGGING") match { + case Some("true") => + System.setProperty("logback.configurationFile", "deployed-logback.xml") + case _ => + System.setProperty("logback.configurationFile", "logback.xml") + } + } + + def getEnvironmentSpecificConfig(): Config = { + scala.sys.env.get("APPLICATION_CONFIG_TYPE") match { + case Some("deployed") => + ConfigFactory.load(this.getClass.getClassLoader, "deployed-application.conf") + case _ => + xyz.driver.core.config.loadDefaultConfig + } + } + + def configureTracer(actorSystem: ActorSystem, applicationContext: ApplicationContext): Tracer = { + + val serviceAccountKeyFile = + Paths.get(applicationContext.config.getString("tracing.google.serviceAccountKeyfile")) + + if (Files.exists(serviceAccountKeyFile)) { + val materializer = ActorMaterializer()(actorSystem) + new GoogleTracer( + projectId = applicationContext.config.getString("tracing.google.projectId"), + serviceAccountFile = serviceAccountKeyFile + )(actorSystem, materializer) + } else { + applicationContext.log.warn(s"Tracing file $serviceAccountKeyFile was not found, using NoTracer!") + NoTracer + } + } + + def serviceActorSystem(serviceName: String, executionContext: ExecutionContext, config: Config): ActorSystem = { + val actorSystem = + ActorSystem(s"$serviceName-actors", Option(config), Option.empty[ClassLoader], Option(executionContext)) + + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = Try(actorSystem.terminate()) + }) + + actorSystem + } + + def toMdcExecutionContext(executor: Executor) = + new MdcExecutionContext(ExecutionContext.fromExecutor(executor)) + + def newFixedMdcExecutionContext(capacity: Int): MdcExecutionContext = + toMdcExecutionContext(Executors.newFixedThreadPool(capacity)) + + def defaultApplicationContext(): ApplicationContext = { + val config = getEnvironmentSpecificConfig() + + val time = new SystemTimeProvider() + val log = Logger(LoggerFactory.getLogger(classOf[DriverApp])) + + ApplicationContext(config, time, log) + } + + def createDefaultApplication( + modules: Seq[Module], + buildInfo: RequiredBuildInfo, + actorSystem: ActorSystem, + tracer: Tracer, + context: ApplicationContext): DriverApp = { + val scheme = context.config.getString("application.scheme") + val baseUrl = context.config.getString("application.baseUrl") + val port = context.config.getInt("application.port") + + new DriverApp( + buildInfo.name, + buildInfo.version, + buildInfo.gitHeadCommit.getOrElse("None"), + modules = modules, + context.time, + context.log, + context.config, + interface = "0.0.0.0", + baseUrl, + scheme, + port, + tracer + )(actorSystem, actorSystem.dispatcher) + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/app/module.scala b/jvm/src/main/scala/xyz/driver/core/app/module.scala new file mode 100644 index 0000000..7be38eb --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/app/module.scala @@ -0,0 +1,70 @@ +package xyz.driver.core.app + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.Directives.complete +import akka.http.scaladsl.server.{Route, RouteConcatenation} +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import xyz.driver.core.database.Database +import xyz.driver.core.rest.{DriverRoute, NoServiceDiscovery, SavingUsedServiceDiscovery, ServiceDiscovery} + +import scala.reflect.runtime.universe._ + +trait Module { + val name: String + def route: Route + def routeTypes: Seq[Type] + + val serviceDiscovery: ServiceDiscovery with SavingUsedServiceDiscovery = new NoServiceDiscovery() + + def activate(): Unit = {} + def deactivate(): Unit = {} +} + +class EmptyModule extends Module { + override val name: String = "Nothing" + + override def route: Route = complete(StatusCodes.OK) + override def routeTypes: Seq[Type] = Seq.empty[Type] +} + +class SimpleModule(override val name: String, theRoute: Route, routeType: Type) extends Module { + private val driverRoute: DriverRoute = new DriverRoute { + override def route: Route = theRoute + override val log: Logger = xyz.driver.core.logging.NoLogger + } + + override def route: Route = driverRoute.routeWithDefaults + override def routeTypes: Seq[Type] = Seq(routeType) +} + +trait SingleDatabaseModule { self: Module => + + val databaseName: String + val config: Config + + def database = Database.fromConfig(config, databaseName) + + override def deactivate(): Unit = { + try { + database.database.close() + } finally { + self.deactivate() + } + } +} + +/** + * Module implementation which may be used to compose multiple modules + * + * @param name more general name of the composite module, + * must be provided as there is no good way to automatically + * generalize the name from the composed modules' names + * @param modules modules to compose into a single one + */ +class CompositeModule(override val name: String, modules: Seq[Module]) extends Module with RouteConcatenation { + override def route: Route = RouteConcatenation.concat(modules.map(_.route): _*) + override def routeTypes: Seq[Type] = modules.flatMap(_.routeTypes) + override def activate(): Unit = modules.foreach(_.activate()) + override def deactivate(): Unit = modules.reverse.foreach(_.deactivate()) +} diff --git a/jvm/src/main/scala/xyz/driver/core/auth.scala b/jvm/src/main/scala/xyz/driver/core/auth.scala new file mode 100644 index 0000000..896bd89 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/auth.scala @@ -0,0 +1,43 @@ +package xyz.driver.core + +import xyz.driver.core.domain.Email +import xyz.driver.core.time.Time +import scalaz.Equal + +object auth { + + trait Permission + + final case class Role(id: Id[Role], name: Name[Role]) { + + def oneOf(roles: Role*): Boolean = roles.contains(this) + + def oneOf(roles: Set[Role]): Boolean = roles.contains(this) + } + + object Role { + implicit def idEqual: Equal[Role] = Equal.equal[Role](_ == _) + } + + trait User { + def id: Id[User] + } + + final case class AuthToken(value: String) + + final case class AuthTokenUserInfo( + id: Id[User], + email: Email, + emailVerified: Boolean, + audience: String, + roles: Set[Role], + expirationTime: Time) + extends User + + final case class RefreshToken(value: String) + final case class PermissionsToken(value: String) + + final case class PasswordHash(value: String) + + final case class AuthCredentials(identifier: String, password: String) +} diff --git a/jvm/src/main/scala/xyz/driver/core/cache.scala b/jvm/src/main/scala/xyz/driver/core/cache.scala new file mode 100644 index 0000000..3500a2a --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/cache.scala @@ -0,0 +1,110 @@ +package xyz.driver.core + +import java.util.concurrent.{Callable, TimeUnit} + +import com.google.common.cache.{CacheBuilder, Cache => GuavaCache} +import com.typesafe.scalalogging.Logger + +import scala.concurrent.duration.{Duration, _} +import scala.concurrent.{ExecutionContext, Future} + +object cache { + + /** + * FutureCache is used to represent an in-memory, in-process, asynchronous cache. + * + * Every cache operation is atomic. + * + * This implementation evicts failed results, + * and doesn't interrupt the underlying request that has been fired off. + */ + class AsyncCache[K, V](name: String, cache: GuavaCache[K, Future[V]])(implicit executionContext: ExecutionContext) { + + private[this] val log = Logger(s"AsyncCache.$name") + private[this] val underlying = cache.asMap() + + private[this] def evictOnFailure(key: K, f: Future[V]): Future[V] = { + f.failed foreach { + case ex: Throwable => + log.debug(s"Evict key $key due to exception $ex") + evict(key, f) + } + f // we return the original future to make evict(k, f) easier to work with. + } + + /** + * Equivalent to getOrElseUpdate + */ + def apply(key: K)(value: => Future[V]): Future[V] = getOrElseUpdate(key)(value) + + /** + * Gets the cached Future. + * + * @return None if a value hasn't been specified for that key yet + * Some(ksync computation) if the value has been specified. Just + * because this returns Some(..) doesn't mean that it has been + * satisfied, but if it hasn't been satisfied, it's probably + * in-flight. + */ + def get(key: K): Option[Future[V]] = Option(underlying.get(key)) + + /** + * Gets the cached Future, or if it hasn't been returned yet, computes it and + * returns that value. + */ + def getOrElseUpdate(key: K)(compute: => Future[V]): Future[V] = { + log.debug(s"Try to retrieve key $key from cache") + evictOnFailure(key, cache.get(key, new Callable[Future[V]] { + def call(): Future[V] = { + log.debug(s"Cache miss, load the key: $key") + compute + } + })) + } + + /** + * Unconditionally sets a value for a given key + */ + def set(key: K, value: Future[V]): Unit = { + cache.put(key, value) + evictOnFailure(key, value) + } + + /** + * Evicts the contents of a `key` if the old value is `value`. + * + * Since `scala.concurrent.Future` uses reference equality, you must use the + * same object reference to evict a value. + * + * @return true if the key was evicted + * false if the key was not evicted + */ + def evict(key: K, value: Future[V]): Boolean = underlying.remove(key, value) + + /** + * @return the number of results that have been computed successfully or are in flight. + */ + def size: Int = cache.size.toInt + } + + object AsyncCache { + val DEFAULT_CAPACITY: Long = 10000L + val DEFAULT_READ_EXPIRATION: Duration = 10 minutes + val DEFAULT_WRITE_EXPIRATION: Duration = 1 hour + + def apply[K <: AnyRef, V <: AnyRef]( + name: String, + capacity: Long = DEFAULT_CAPACITY, + readExpiration: Duration = DEFAULT_READ_EXPIRATION, + writeExpiration: Duration = DEFAULT_WRITE_EXPIRATION)( + implicit executionContext: ExecutionContext): AsyncCache[K, V] = { + val guavaCache = CacheBuilder + .newBuilder() + .maximumSize(capacity) + .expireAfterAccess(readExpiration.toSeconds, TimeUnit.SECONDS) + .expireAfterWrite(writeExpiration.toSeconds, TimeUnit.SECONDS) + .build[K, Future[V]]() + new AsyncCache(name, guavaCache) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/config.scala b/jvm/src/main/scala/xyz/driver/core/config.scala new file mode 100644 index 0000000..be81408 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/config.scala @@ -0,0 +1,24 @@ +package xyz.driver.core + +import java.io.File +import com.typesafe.config.{Config, ConfigFactory} + +object config { + + def loadDefaultConfig: Config = { + val configDefaults = ConfigFactory.load(this.getClass.getClassLoader, "application.conf") + + scala.sys.env.get("APPLICATION_CONFIG").orElse(scala.sys.props.get("application.config")) match { + + case Some(filename) => + val configFile = new File(filename) + if (configFile.exists()) { + ConfigFactory.parseFile(configFile).withFallback(configDefaults) + } else { + throw new IllegalStateException(s"No config found at $filename") + } + + case None => configDefaults + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/core.scala b/jvm/src/main/scala/xyz/driver/core/core.scala new file mode 100644 index 0000000..72237b9 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/core.scala @@ -0,0 +1,128 @@ +package xyz.driver + +import scalaz.{Equal, Monad, OptionT} +import eu.timepit.refined.api.{Refined, Validate} +import eu.timepit.refined.collection.NonEmpty +import xyz.driver.core.rest.errors.ExternalServiceException + +import scala.concurrent.{ExecutionContext, Future} + +package object core { + + import scala.language.reflectiveCalls + + def make[T](v: => T)(f: T => Unit): T = { + val value = v + f(value) + value + } + + def using[R <: { def close() }, P](r: => R)(f: R => P): P = { + val resource = r + try { + f(resource) + } finally { + resource.close() + } + } + + object tagging { + private[core] trait Tagged[+V, +Tag] + + implicit class Taggable[V <: Any](val v: V) extends AnyVal { + def tagged[Tag]: V @@ Tag = v.asInstanceOf[V @@ Tag] + } + } + type @@[+V, +Tag] = V with tagging.Tagged[V, Tag] + + implicit class OptionTExtensions[H[_]: Monad, T](optionTValue: OptionT[H, T]) { + + def returnUnit: H[Unit] = optionTValue.fold[Unit](_ => (), ()) + + def continueIgnoringNone: OptionT[H, Unit] = + optionTValue.map(_ => ()).orElse(OptionT.some[H, Unit](())) + + def subflatMap[B](f: T => Option[B]): OptionT[H, B] = + OptionT.optionT[H](implicitly[Monad[H]].map(optionTValue.run)(_.flatMap(f))) + } + + implicit class MonadicExtensions[H[_]: Monad, T](monadicValue: H[T]) { + private implicit val monadT = implicitly[Monad[H]] + + def returnUnit: H[Unit] = monadT(monadicValue)(_ => ()) + + def toOptionT: OptionT[H, T] = + OptionT.optionT[H](monadT(monadicValue)(value => Option(value))) + + def toUnitOptionT: OptionT[H, Unit] = + OptionT.optionT[H](monadT(monadicValue)(_ => Option(()))) + } + + implicit class FutureExtensions[T](future: Future[T]) { + def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = + future.transform(identity, { + case ExternalServiceException(_, _, Some(e)) => e + case t: Throwable => t + }) + } +} + +package core { + + final case class Id[+Tag](value: String) extends AnyVal { + @inline def length: Int = value.length + override def toString: String = value + } + + @SuppressWarnings(Array("org.wartremover.warts.ImplicitConversion")) + object Id { + implicit def idEqual[T]: Equal[Id[T]] = Equal.equal[Id[T]](_ == _) + implicit def idOrdering[T]: Ordering[Id[T]] = Ordering.by[Id[T], String](_.value) + + sealed class Mapper[E, R] { + def apply[T >: E](id: Id[R]): Id[T] = Id[E](id.value) + def apply[T >: R](id: Id[E])(implicit dummy: DummyImplicit): Id[T] = Id[R](id.value) + } + object Mapper { + def apply[E, R] = new Mapper[E, R] + } + implicit def convertRE[R, E](id: Id[R])(implicit mapper: Mapper[E, R]): Id[E] = mapper[E](id) + implicit def convertER[E, R](id: Id[E])(implicit mapper: Mapper[E, R]): Id[R] = mapper[R](id) + } + + final case class Name[+Tag](value: String) extends AnyVal { + @inline def length: Int = value.length + override def toString: String = value + } + + object Name { + implicit def nameEqual[T]: Equal[Name[T]] = Equal.equal[Name[T]](_ == _) + implicit def nameOrdering[T]: Ordering[Name[T]] = Ordering.by(_.value) + + implicit def nameValidator[T, P](implicit stringValidate: Validate[String, P]): Validate[Name[T], P] = { + Validate.instance[Name[T], P, stringValidate.R]( + name => stringValidate.validate(name.value), + name => stringValidate.showExpr(name.value)) + } + } + + final case class NonEmptyName[+Tag](value: String Refined NonEmpty) { + @inline def length: Int = value.value.length + override def toString: String = value.value + } + + object NonEmptyName { + implicit def nonEmptyNameEqual[T]: Equal[NonEmptyName[T]] = + Equal.equal[NonEmptyName[T]](_.value.value == _.value.value) + + implicit def nonEmptyNameOrdering[T]: Ordering[NonEmptyName[T]] = Ordering.by(_.value.value) + } + + final case class Revision[T](id: String) + + object Revision { + implicit def revisionEqual[T]: Equal[Revision[T]] = Equal.equal[Revision[T]](_.id == _.id) + } + + final case class Base64(value: String) +} diff --git a/jvm/src/main/scala/xyz/driver/core/database/Converters.scala b/jvm/src/main/scala/xyz/driver/core/database/Converters.scala new file mode 100644 index 0000000..ad79abf --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/Converters.scala @@ -0,0 +1,26 @@ +package xyz.driver.core.database + +import xyz.driver.core.rest.errors.DatabaseException + +import scala.reflect.ClassTag + +/** + * Helper methods for converting between table rows and Scala objects + */ +trait Converters { + def fromStringOrThrow[ADT](entityStr: String, mapper: (String => Option[ADT]), entityName: String): ADT = + mapper(entityStr).getOrElse(throw DatabaseException(s"Invalid $entityName in database: $entityStr")) + + def expectValid[ADT](mapper: String => Option[ADT], query: String)(implicit ct: ClassTag[ADT]): ADT = + fromStringOrThrow[ADT](query, mapper, ct.toString()) + + def expectExistsAndValid[ADT](mapper: String => Option[ADT], query: Option[String], contextMsg: String = "")( + implicit ct: ClassTag[ADT]): ADT = { + expectValid[ADT](mapper, query.getOrElse(throw DatabaseException(contextMsg))) + } + + def expectValidOrEmpty[ADT](mapper: String => Option[ADT], query: Option[String], contextMsg: String = "")( + implicit ct: ClassTag[ADT]): Option[ADT] = { + query.map(expectValid[ADT](mapper, _)) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala b/jvm/src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala new file mode 100644 index 0000000..5939efb --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/MdcAsyncExecutor.scala @@ -0,0 +1,53 @@ +/** Code ported from "de.geekonaut" %% "slickmdc" % "1.0.0" + * License: @see https://github.com/AVGP/slickmdc/blob/master/LICENSE + * Blog post: @see http://50linesofco.de/post/2016-07-01-slick-and-slf4j-mdc-logging-in-scala.html + */ +package xyz.driver.core +package database + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent._ +import com.typesafe.scalalogging.StrictLogging +import slick.util.AsyncExecutor + +import logging.MdcExecutionContext + +/** Taken from the original Slick AsyncExecutor and simplified + * @see https://github.com/slick/slick/blob/3.1/slick/src/main/scala/slick/util/AsyncExecutor.scala + */ +object MdcAsyncExecutor extends StrictLogging { + + /** Create an AsyncExecutor with a fixed-size thread pool. + * + * @param name The name for the thread pool. + * @param numThreads The number of threads in the pool. + */ + def apply(name: String, numThreads: Int): AsyncExecutor = { + new AsyncExecutor { + val tf = new DaemonThreadFactory(name + "-") + + lazy val executionContext = { + new MdcExecutionContext(ExecutionContext.fromExecutor(Executors.newFixedThreadPool(numThreads, tf))) + } + + def close(): Unit = {} + } + } + + def default(name: String = "AsyncExecutor.default"): AsyncExecutor = apply(name, 20) + + private class DaemonThreadFactory(namePrefix: String) extends ThreadFactory { + private[this] val group = + Option(System.getSecurityManager).fold(Thread.currentThread.getThreadGroup)(_.getThreadGroup) + private[this] val threadNumber = new AtomicInteger(1) + + def newThread(r: Runnable): Thread = { + val t = new Thread(group, r, namePrefix + threadNumber.getAndIncrement, 0) + if (!t.isDaemon) t.setDaemon(true) + if (t.getPriority != Thread.NORM_PRIORITY) t.setPriority(Thread.NORM_PRIORITY) + t + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/database/PatchedHsqldbProfile.scala b/jvm/src/main/scala/xyz/driver/core/database/PatchedHsqldbProfile.scala new file mode 100644 index 0000000..e2efd32 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/PatchedHsqldbProfile.scala @@ -0,0 +1,16 @@ +package xyz.driver.core.database + +import slick.jdbc.{HsqldbProfile, JdbcType} +import slick.ast.FieldSymbol +import slick.relational.RelationalProfile + +trait PatchedHsqldbProfile extends HsqldbProfile { + override def defaultSqlTypeName(tmd: JdbcType[_], sym: Option[FieldSymbol]): String = tmd.sqlType match { + case java.sql.Types.VARCHAR => + val size = sym.flatMap(_.findColumnOption[RelationalProfile.ColumnOption.Length]) + size.fold("LONGVARCHAR")(l => if (l.varying) s"VARCHAR(${l.length})" else s"CHAR(${l.length})") + case _ => super.defaultSqlTypeName(tmd, sym) + } +} + +object PatchedHsqldbProfile extends PatchedHsqldbProfile diff --git a/jvm/src/main/scala/xyz/driver/core/database/Repository.scala b/jvm/src/main/scala/xyz/driver/core/database/Repository.scala new file mode 100644 index 0000000..31c79ad --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/Repository.scala @@ -0,0 +1,73 @@ +package xyz.driver.core.database + +import scalaz.std.scalaFuture._ +import scalaz.{ListT, Monad, OptionT} +import slick.lifted.{AbstractTable, CanBeQueryCondition, RunnableCompiled} +import slick.{lifted => sl} + +import scala.concurrent.{ExecutionContext, Future} + +trait Repository { + type T[D] + implicit def monadT: Monad[T] + + def execute[D](operations: T[D]): Future[D] + def noAction[V](v: V): T[V] + def customAction[R](action: => Future[R]): T[R] + + def customAction[R](action: => OptionT[Future, R]): OptionT[T, R] = + OptionT[T, R](customAction(action.run)) +} + +class FutureRepository(executionContext: ExecutionContext) extends Repository { + implicit val exec: ExecutionContext = executionContext + override type T[D] = Future[D] + implicit val monadT: Monad[Future] = implicitly[Monad[Future]] + + def execute[D](operations: T[D]): Future[D] = operations + def noAction[V](v: V): T[V] = Future.successful(v) + def customAction[R](action: => Future[R]): T[R] = action +} + +class SlickRepository(database: Database, executionContext: ExecutionContext) extends Repository { + import database.profile.api._ + implicit val exec: ExecutionContext = executionContext + + override type T[D] = slick.dbio.DBIO[D] + + implicit protected class QueryOps[+E, U](query: Query[E, U, Seq]) { + def resultT: ListT[T, U] = ListT[T, U](query.result.map(_.toList)) + + def maybeFilter[V, R: CanBeQueryCondition](data: Option[V])(f: V => E => R): sl.Query[E, U, Seq] = + data.map(v => query.withFilter(f(v))).getOrElse(query) + } + + implicit protected class CompiledQueryOps[U](compiledQuery: RunnableCompiled[_, Seq[U]]) { + def resultT: ListT[T, U] = ListT.listT[T](compiledQuery.result.map(_.toList)) + } + + private val dbioMonad = new Monad[T] { + override def point[A](a: => A): T[A] = DBIO.successful(a) + + override def bind[A, B](fa: T[A])(f: A => T[B]): T[B] = fa.flatMap(f) + } + + override implicit def monadT: Monad[T] = dbioMonad + + override def execute[D](readOperations: T[D]): Future[D] = { + database.database.run(readOperations.transactionally) + } + + override def noAction[V](v: V): T[V] = DBIO.successful(v) + + override def customAction[R](action: => Future[R]): T[R] = DBIO.from(action) + + def affectsRows(updatesCount: Int): Option[Unit] = { + if (updatesCount > 0) Some(()) else None + } + + def insertReturning[AT <: AbstractTable[_], V](table: TableQuery[AT])( + row: AT#TableElementType): slick.dbio.DBIO[AT#TableElementType] = { + table.returning(table) += row + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/database/SlickGetResultSupport.scala b/jvm/src/main/scala/xyz/driver/core/database/SlickGetResultSupport.scala new file mode 100644 index 0000000..8293371 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/SlickGetResultSupport.scala @@ -0,0 +1,30 @@ +package xyz.driver.core.database + +import slick.jdbc.GetResult +import xyz.driver.core.date.Date +import xyz.driver.core.time.Time +import xyz.driver.core.{Id, Name} + +trait SlickGetResultSupport { + implicit def GetId[U]: GetResult[Id[U]] = + GetResult(r => Id[U](r.nextString())) + implicit def GetIdOption[U]: GetResult[Option[Id[U]]] = + GetResult(_.nextStringOption().map(Id.apply[U])) + + implicit def GetName[U]: GetResult[Name[U]] = + GetResult(r => Name[U](r.nextString())) + implicit def GetNameOption[U]: GetResult[Option[Name[U]]] = + GetResult(_.nextStringOption().map(Name.apply[U])) + + implicit val GetTime: GetResult[Time] = + GetResult(r => Time(r.nextTimestamp.getTime)) + implicit val GetTimeOption: GetResult[Option[Time]] = + GetResult(_.nextTimestampOption().map(t => Time(t.getTime))) + + implicit val GetDate: GetResult[Date] = + GetResult(r => sqlDateToDate(r.nextDate())) + implicit val GetDateOption: GetResult[Option[Date]] = + GetResult(_.nextDateOption().map(sqlDateToDate)) +} + +object SlickGetResultSupport extends SlickGetResultSupport diff --git a/jvm/src/main/scala/xyz/driver/core/database/database.scala b/jvm/src/main/scala/xyz/driver/core/database/database.scala new file mode 100644 index 0000000..ae06517 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/database.scala @@ -0,0 +1,165 @@ +package xyz.driver.core + +import slick.basic.DatabaseConfig +import slick.jdbc.JdbcProfile +import xyz.driver.core.date.Date +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import com.typesafe.config.Config + +package database { + + import java.sql.SQLDataException + + import eu.timepit.refined.api.{Refined, Validate} + import eu.timepit.refined.refineV + + trait Database { + val profile: JdbcProfile + val database: JdbcProfile#Backend#Database + } + + object Database { + def fromConfig(config: Config, databaseName: String): Database = { + val dbConfig: DatabaseConfig[JdbcProfile] = DatabaseConfig.forConfig(databaseName, config) + + new Database { + val profile: JdbcProfile = dbConfig.profile + val database: JdbcProfile#Backend#Database = dbConfig.db + } + } + + def fromConfig(databaseName: String): Database = { + fromConfig(com.typesafe.config.ConfigFactory.load(), databaseName) + } + } + + trait ColumnTypes { + val profile: JdbcProfile + } + + trait NameColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.Name.columnType`[T]: BaseColumnType[Name[T]] + } + + object NameColumnTypes { + trait StringName extends NameColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.Name.columnType`[T]: BaseColumnType[Name[T]] = + MappedColumnType.base[Name[T], String](_.value, Name[T]) + } + } + + trait DateColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.time.Date.columnType`: BaseColumnType[Date] + } + + object DateColumnTypes { + trait SqlDate extends DateColumnTypes { + import profile.api._ + override implicit def `xyz.driver.core.time.Date.columnType`: BaseColumnType[Date] = + MappedColumnType.base[Date, java.sql.Date](dateToSqlDate, sqlDateToDate) + } + } + + trait RefinedColumnTypes[T, Predicate] extends ColumnTypes { + import profile.api._ + implicit def `eu.timepit.refined.api.Refined`( + implicit columnType: BaseColumnType[T], + validate: Validate[T, Predicate]): BaseColumnType[T Refined Predicate] + } + + object RefinedColumnTypes { + trait RefinedValue[T, Predicate] extends RefinedColumnTypes[T, Predicate] { + import profile.api._ + override implicit def `eu.timepit.refined.api.Refined`( + implicit columnType: BaseColumnType[T], + validate: Validate[T, Predicate]): BaseColumnType[T Refined Predicate] = + MappedColumnType.base[T Refined Predicate, T]( + _.value, { dbValue => + refineV[Predicate](dbValue) match { + case Left(refinementError) => + throw new SQLDataException( + s"Value in the database doesn't match the refinement constraints: $refinementError") + case Right(refinedValue) => + refinedValue + } + } + ) + } + } + + trait IdColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.Id.columnType`[T]: BaseColumnType[Id[T]] + } + + object IdColumnTypes { + trait UUID extends IdColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.Id.columnType`[T] = + MappedColumnType + .base[Id[T], java.util.UUID](id => java.util.UUID.fromString(id.value), uuid => Id[T](uuid.toString)) + } + trait SerialId extends IdColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.Id.columnType`[T] = + MappedColumnType.base[Id[T], Long](_.value.toLong, serialId => Id[T](serialId.toString)) + } + trait NaturalId extends IdColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.Id.columnType`[T] = + MappedColumnType.base[Id[T], String](_.value, Id[T]) + } + } + + trait TimestampColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] + } + + object TimestampColumnTypes { + trait SqlTimestamp extends TimestampColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] = + MappedColumnType.base[Time, java.sql.Timestamp]( + time => new java.sql.Timestamp(time.millis), + timestamp => Time(timestamp.getTime)) + } + + trait PrimitiveTimestamp extends TimestampColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] = + MappedColumnType.base[Time, Long](_.millis, Time(_)) + } + } + + trait KeyMappers extends ColumnTypes { + import profile.api._ + + def uuidKeyMapper[T] = + MappedColumnType + .base[Id[T], java.util.UUID](id => java.util.UUID.fromString(id.value), uuid => Id[T](uuid.toString)) + def serialKeyMapper[T] = MappedColumnType.base[Id[T], Long](_.value.toLong, serialId => Id[T](serialId.toString)) + def naturalKeyMapper[T] = MappedColumnType.base[Id[T], String](_.value, Id[T]) + } + + trait DatabaseObject extends ColumnTypes { + def createTables(): Future[Unit] + def disconnect(): Unit + } + + abstract class DatabaseObjectAdapter extends DatabaseObject { + def createTables(): Future[Unit] = Future.successful(()) + def disconnect(): Unit = {} + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/database/package.scala b/jvm/src/main/scala/xyz/driver/core/database/package.scala new file mode 100644 index 0000000..aee14c6 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/database/package.scala @@ -0,0 +1,61 @@ +package xyz.driver.core + +import java.sql.{Date => SqlDate} +import java.util.Calendar + +import date.{Date, Month} +import slick.dbio._ +import slick.jdbc.JdbcProfile +import slick.relational.RelationalProfile + +package object database { + + type Schema = { + def create: DBIOAction[Unit, NoStream, Effect.Schema] + def drop: DBIOAction[Unit, NoStream, Effect.Schema] + } + + @deprecated( + "sbt-slick-codegen 0.11.0+ no longer needs to generate these methods. Please use the new `CodegenTables` trait when upgrading.", + "driver-core 1.8.12") + type GeneratedTables = { + // structure of Slick data model traits generated by sbt-slick-codegen + val profile: JdbcProfile + def schema: profile.SchemaDescription + + def createNamespaceSchema: StreamingDBIO[Vector[Unit], Unit] + def dropNamespaceSchema: StreamingDBIO[Vector[Unit], Unit] + } + + /** A structural type for schema traits generated by sbt-slick-codegen. + * This will compile with codegen versions before 0.11.0, but note + * that methods in [[GeneratedTables]] are no longer generated. + */ + type CodegenTables[Profile <: RelationalProfile] = { + val profile: Profile + def schema: profile.SchemaDescription + } + + private[database] def sqlDateToDate(sqlDate: SqlDate): Date = { + // NOTE: SQL date does not have a time component, so this date + // should only be interpreted in the running JVMs timezone. + val cal = Calendar.getInstance() + cal.setTime(sqlDate) + Date(cal.get(Calendar.YEAR), Month(cal.get(Calendar.MONTH)), cal.get(Calendar.DAY_OF_MONTH)) + } + + private[database] def dateToSqlDate(date: Date): SqlDate = { + val cal = Calendar.getInstance() + cal.set(date.year, date.month, date.day, 0, 0, 0) + new SqlDate(cal.getTime.getTime) + } + + @deprecated("Dal is deprecated. Please use Repository trait instead!", "1.8.26") + type Dal = Repository + + @deprecated("SlickDal is deprecated. Please use SlickRepository class instead!", "1.8.26") + type SlickDal = SlickRepository + + @deprecated("FutureDal is deprecated. Please use FutureRepository class instead!", "1.8.26") + type FutureDal = FutureRepository +} diff --git a/jvm/src/main/scala/xyz/driver/core/date.scala b/jvm/src/main/scala/xyz/driver/core/date.scala new file mode 100644 index 0000000..5454093 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/date.scala @@ -0,0 +1,109 @@ +package xyz.driver.core + +import java.util.Calendar + +import enumeratum._ +import scalaz.std.anyVal._ +import scalaz.syntax.equal._ + +import scala.collection.immutable.IndexedSeq +import scala.util.Try + +/** + * Driver Date type and related validators/extractors. + * Day, Month, and Year extractors are from ISO 8601 strings => driver...Date integers. + * TODO: Decouple extractors from ISO 8601, as we might want to parse other formats. + */ +object date { + + sealed trait DayOfWeek extends EnumEntry + object DayOfWeek extends Enum[DayOfWeek] { + case object Monday extends DayOfWeek + case object Tuesday extends DayOfWeek + case object Wednesday extends DayOfWeek + case object Thursday extends DayOfWeek + case object Friday extends DayOfWeek + case object Saturday extends DayOfWeek + case object Sunday extends DayOfWeek + + val values: IndexedSeq[DayOfWeek] = findValues + + val All: Set[DayOfWeek] = values.toSet + + def fromString(day: String): Option[DayOfWeek] = withNameInsensitiveOption(day) + } + + type Day = Int @@ Day.type + + object Day { + def apply(value: Int): Day = { + require(1 to 31 contains value, "Day must be in range 1 <= value <= 31") + value.asInstanceOf[Day] + } + + def unapply(dayString: String): Option[Int] = { + require(dayString.length === 2, s"ISO 8601 day string, DD, must have length 2: $dayString") + Try(dayString.toInt).toOption.map(apply) + } + } + + type Month = Int @@ Month.type + + object Month { + def apply(value: Int): Month = { + require(0 to 11 contains value, "Month is zero-indexed: 0 <= value <= 11") + value.asInstanceOf[Month] + } + val JANUARY = Month(Calendar.JANUARY) + val FEBRUARY = Month(Calendar.FEBRUARY) + val MARCH = Month(Calendar.MARCH) + val APRIL = Month(Calendar.APRIL) + val MAY = Month(Calendar.MAY) + val JUNE = Month(Calendar.JUNE) + val JULY = Month(Calendar.JULY) + val AUGUST = Month(Calendar.AUGUST) + val SEPTEMBER = Month(Calendar.SEPTEMBER) + val OCTOBER = Month(Calendar.OCTOBER) + val NOVEMBER = Month(Calendar.NOVEMBER) + val DECEMBER = Month(Calendar.DECEMBER) + + def unapply(monthString: String): Option[Month] = { + require(monthString.length === 2, s"ISO 8601 month string, MM, must have length 2: $monthString") + Try(monthString.toInt).toOption.map(isoM => apply(isoM - 1)) + } + } + + type Year = Int @@ Year.type + + object Year { + def apply(value: Int): Year = value.asInstanceOf[Year] + + def unapply(yearString: String): Option[Int] = { + require(yearString.length === 4, s"ISO 8601 year string, YYYY, must have length 4: $yearString") + Try(yearString.toInt).toOption.map(apply) + } + } + + final case class Date(year: Int, month: Month, day: Int) { + override def toString = f"$year%04d-${month + 1}%02d-$day%02d" + } + + object Date { + implicit def dateOrdering: Ordering[Date] = Ordering.fromLessThan { (date1, date2) => + if (date1.year != date2.year) { + date1.year < date2.year + } else if (date1.month != date2.month) { + date1.month < date2.month + } else { + date1.day < date2.day + } + } + + def fromString(dateString: String): Option[Date] = { + dateString.split('-') match { + case Array(Year(year), Month(month), Day(day)) => Some(Date(year, month, day)) + case _ => None + } + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/domain.scala b/jvm/src/main/scala/xyz/driver/core/domain.scala new file mode 100644 index 0000000..fa3b5c4 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/domain.scala @@ -0,0 +1,46 @@ +package xyz.driver.core + +import com.google.i18n.phonenumbers.PhoneNumberUtil +import scalaz.Equal +import scalaz.std.string._ +import scalaz.syntax.equal._ + +object domain { + + final case class Email(username: String, domain: String) { + override def toString: String = username + "@" + domain + } + + object Email { + implicit val emailEqual: Equal[Email] = Equal.equal { + case (left, right) => left.toString.toLowerCase === right.toString.toLowerCase + } + + def parse(emailString: String): Option[Email] = { + Some(emailString.split("@")) collect { + case Array(username, domain) => Email(username, domain) + } + } + } + + final case class PhoneNumber(countryCode: String = "1", number: String) { + override def toString: String = s"+$countryCode $number" + } + + object PhoneNumber { + + private val phoneUtil = PhoneNumberUtil.getInstance() + + def parse(phoneNumber: String): Option[PhoneNumber] = { + val phone = scala.util.Try(phoneUtil.parseAndKeepRawInput(phoneNumber, "US")).toOption + + val validated = phone match { + case None => None + case Some(pn) => + if (!phoneUtil.isValidNumber(pn)) None + else Some(pn) + } + validated.map(pn => PhoneNumber(pn.getCountryCode.toString, pn.getNationalNumber.toString)) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala b/jvm/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala new file mode 100644 index 0000000..ce26fe4 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/file/FileSystemStorage.scala @@ -0,0 +1,76 @@ +package xyz.driver.core.file + +import akka.NotUsed +import akka.stream.scaladsl.{FileIO, Source} +import akka.util.ByteString +import java.io.File +import java.nio.file.{Files, Path, Paths} + +import xyz.driver.core.{Name, Revision} +import xyz.driver.core.time.Time + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +@deprecated("Consider using xyz.driver.core.storage.FileSystemBlobStorage instead", "driver-core 1.8.14") +class FileSystemStorage(executionContext: ExecutionContext) extends FileStorage { + implicit private val execution = executionContext + + override def upload(localSource: File, destination: Path): Future[Unit] = Future { + checkSafeFileName(destination) { + val destinationFile = destination.toFile + + if (destinationFile.getParentFile.exists() || destinationFile.getParentFile.mkdirs()) { + if (localSource.renameTo(destinationFile)) () + else { + throw new Exception( + s"Failed to move file from `${localSource.getCanonicalPath}` to `${destinationFile.getCanonicalPath}`") + } + } else { + throw new Exception(s"Failed to create parent directories for file `${destinationFile.getCanonicalPath}`") + } + } + } + + override def download(filePath: Path): OptionT[Future, File] = + OptionT.optionT(Future { + Option(new File(filePath.toString)).filter(file => file.exists() && file.isFile) + }) + + override def stream(filePath: Path): OptionT[Future, Source[ByteString, NotUsed]] = + OptionT.optionT(Future { + if (Files.exists(filePath)) { + Some(FileIO.fromPath(filePath).mapMaterializedValue(_ => NotUsed)) + } else { + None + } + }) + + override def delete(filePath: Path): Future[Unit] = Future { + val file = new File(filePath.toString) + if (file.delete()) () + else { + throw new Exception(s"Failed to delete file $file" + (if (!file.exists()) ", file does not exist." else ".")) + } + } + + override def list(path: Path): ListT[Future, FileLink] = + ListT.listT(Future { + val file = new File(path.toString) + if (file.isDirectory) { + file.listFiles().toList.filter(_.isFile).map { file => + FileLink( + Name[File](file.getName), + Paths.get(file.getPath), + Revision[File](file.hashCode.toString), + Time(file.lastModified()), + file.length()) + } + } else List.empty[FileLink] + }) + + override def exists(path: Path): Future[Boolean] = Future { + Files.exists(path) + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/file/GcsStorage.scala b/jvm/src/main/scala/xyz/driver/core/file/GcsStorage.scala new file mode 100644 index 0000000..5c94645 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/file/GcsStorage.scala @@ -0,0 +1,135 @@ +package xyz.driver.core.file + +import akka.NotUsed +import akka.stream.scaladsl.Source +import akka.util.ByteString +import com.google.cloud.ReadChannel +import java.io.{BufferedOutputStream, File, FileInputStream, FileOutputStream} +import java.net.URL +import java.nio.ByteBuffer +import java.nio.file.{Path, Paths} +import java.util.concurrent.TimeUnit + +import com.google.cloud.storage.Storage.BlobListOption +import com.google.cloud.storage.{Option => _, _} +import xyz.driver.core.time.Time +import xyz.driver.core.{Name, Revision, generators} + +import scala.collection.JavaConverters._ +import scala.concurrent.duration.Duration +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +@deprecated("Consider using xyz.driver.core.storage.GcsBlobStorage instead", "driver-core 1.8.14") +class GcsStorage( + storageClient: Storage, + bucketName: Name[Bucket], + executionContext: ExecutionContext, + chunkSize: Int = 4096) + extends SignedFileStorage { + implicit private val execution: ExecutionContext = executionContext + + override def upload(localSource: File, destination: Path): Future[Unit] = Future { + checkSafeFileName(destination) { + val blobId = BlobId.of(bucketName.value, destination.toString) + def acl = Bucket.BlobWriteOption.predefinedAcl(Storage.PredefinedAcl.PUBLIC_READ) + + storageClient.get(bucketName.value).create(blobId.getName, new FileInputStream(localSource), acl) + } + } + + override def download(filePath: Path): OptionT[Future, File] = { + OptionT.optionT(Future { + Option(storageClient.get(bucketName.value, filePath.toString)).filterNot(_.getSize == 0).map { + blob => + val tempDir = System.getProperty("java.io.tmpdir") + val randomFolderName = generators.nextUuid().toString + val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString) + + if (!tempDestinationFile.getParentFile.mkdirs()) { + throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`") + } else { + val target = new BufferedOutputStream(new FileOutputStream(tempDestinationFile)) + try target.write(blob.getContent()) + finally target.close() + tempDestinationFile + } + } + }) + } + + override def stream(filePath: Path): OptionT[Future, Source[ByteString, NotUsed]] = + OptionT.optionT(Future { + def readChunk(rc: ReadChannel): Option[ByteString] = { + val buffer = ByteBuffer.allocate(chunkSize) + val length = rc.read(buffer) + if (length > 0) { + buffer.flip() + Some(ByteString.fromByteBuffer(buffer)) + } else { + None + } + } + + Option(storageClient.get(bucketName.value, filePath.toString)).map { blob => + Source.unfoldResource[ByteString, ReadChannel]( + create = () => blob.reader(), + read = channel => readChunk(channel), + close = channel => channel.close() + ) + } + }) + + override def delete(filePath: Path): Future[Unit] = Future { + storageClient.delete(BlobId.of(bucketName.value, filePath.toString)) + } + + override def list(directoryPath: Path): ListT[Future, FileLink] = + ListT.listT(Future { + val directory = s"$directoryPath/" + val page = storageClient.list( + bucketName.value, + BlobListOption.currentDirectory(), + BlobListOption.prefix(directory) + ) + + page + .iterateAll() + .asScala + .filter(_.getName != directory) + .map(blobToFileLink(directoryPath, _)) + .toList + }) + + protected def blobToFileLink(path: Path, blob: Blob): FileLink = { + def nullError(property: String) = throw new IllegalStateException(s"Blob $blob at $path does not have $property") + val name = Option(blob.getName).getOrElse(nullError("a name")) + val generation = Option(blob.getGeneration).getOrElse(nullError("a generation")) + val updateTime = Option(blob.getUpdateTime).getOrElse(nullError("an update time")) + val size = Option(blob.getSize).getOrElse(nullError("a size")) + + FileLink( + Name(name.split('/').last), + Paths.get(name), + Revision(generation.toString), + Time(updateTime), + size + ) + } + + override def exists(path: Path): Future[Boolean] = Future { + val blob = Option( + storageClient.get( + bucketName.value, + path.toString + )) + blob.isDefined + } + + override def signedFileUrl(filePath: Path, duration: Duration): OptionT[Future, URL] = + OptionT.optionT(Future { + Option(storageClient.get(bucketName.value, filePath.toString)).filterNot(_.getSize == 0).map { blob => + blob.signUrl(duration.toSeconds, TimeUnit.SECONDS) + } + }) +} diff --git a/jvm/src/main/scala/xyz/driver/core/file/S3Storage.scala b/jvm/src/main/scala/xyz/driver/core/file/S3Storage.scala new file mode 100644 index 0000000..5158d4d --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/file/S3Storage.scala @@ -0,0 +1,87 @@ +package xyz.driver.core.file + +import akka.NotUsed +import akka.stream.scaladsl.{Source, StreamConverters} +import akka.util.ByteString +import java.io.File +import java.nio.file.{Path, Paths} +import java.util.UUID.randomUUID + +import com.amazonaws.services.s3.AmazonS3 +import com.amazonaws.services.s3.model.{Bucket, GetObjectRequest, ListObjectsV2Request} +import xyz.driver.core.{Name, Revision} +import xyz.driver.core.time.Time + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +@deprecated( + "Blob storage functionality has been reimplemented in xyz.driver.core.storage.BlobStorage. " + + "It has not been ported to S3 storage. Please raise an issue if this required for your use-case.", + "driver-core 1.8.14" +) +class S3Storage(s3: AmazonS3, bucket: Name[Bucket], executionContext: ExecutionContext, chunkSize: Int = 4096) + extends FileStorage { + implicit private val execution = executionContext + + override def upload(localSource: File, destination: Path): Future[Unit] = Future { + checkSafeFileName(destination) { + val _ = s3.putObject(bucket.value, destination.toString, localSource).getETag + } + } + + override def download(filePath: Path): OptionT[Future, File] = + OptionT.optionT(Future { + val tempDir = System.getProperty("java.io.tmpdir") + val randomFolderName = randomUUID().toString + val tempDestinationFile = new File(Paths.get(tempDir, randomFolderName, filePath.toString).toString) + + if (!tempDestinationFile.getParentFile.mkdirs()) { + throw new Exception(s"Failed to create temp directory to download file `$tempDestinationFile`") + } else { + Option(s3.getObject(new GetObjectRequest(bucket.value, filePath.toString), tempDestinationFile)).map { _ => + tempDestinationFile + } + } + }) + + override def stream(filePath: Path): OptionT[Future, Source[ByteString, NotUsed]] = + OptionT.optionT(Future { + Option(s3.getObject(new GetObjectRequest(bucket.value, filePath.toString))).map { elem => + StreamConverters.fromInputStream(() => elem.getObjectContent(), chunkSize).mapMaterializedValue(_ => NotUsed) + } + }) + + override def delete(filePath: Path): Future[Unit] = Future { + s3.deleteObject(bucket.value, filePath.toString) + } + + override def list(path: Path): ListT[Future, FileLink] = + ListT.listT(Future { + import scala.collection.JavaConverters._ + val req = new ListObjectsV2Request().withBucketName(bucket.value).withPrefix(path.toString).withMaxKeys(2) + + def isInSubFolder(path: Path)(fileLink: FileLink) = + fileLink.location.toString.replace(path.toString + "/", "").contains("/") + + Iterator.continually(s3.listObjectsV2(req)).takeWhile { result => + req.setContinuationToken(result.getNextContinuationToken) + result.isTruncated + } flatMap { result => + result.getObjectSummaries.asScala.toList.map { summary => + FileLink( + Name[File](summary.getKey), + Paths.get(path.toString + "/" + summary.getKey), + Revision[File](summary.getETag), + Time(summary.getLastModified.getTime), + summary.getSize + ) + } filterNot isInSubFolder(path) + } toList + }) + + override def exists(path: Path): Future[Boolean] = Future { + s3.doesObjectExist(bucket.value, path.toString) + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/file/package.scala b/jvm/src/main/scala/xyz/driver/core/file/package.scala new file mode 100644 index 0000000..58955e5 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/file/package.scala @@ -0,0 +1,68 @@ +package xyz.driver.core + +import java.io.File +import java.nio.file.Path + +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import scalaz.{ListT, OptionT} + +package file { + + import akka.NotUsed + import akka.stream.scaladsl.Source + import akka.util.ByteString + import java.net.URL + + import scala.concurrent.duration.Duration + + final case class FileLink( + name: Name[File], + location: Path, + revision: Revision[File], + lastModificationDate: Time, + fileSize: Long + ) + + trait FileService { + + def getFileLink(id: Name[File]): FileLink + + def getFile(fileLink: FileLink): File + } + + trait FileStorage { + + def upload(localSource: File, destination: Path): Future[Unit] + + def download(filePath: Path): OptionT[Future, File] + + def stream(filePath: Path): OptionT[Future, Source[ByteString, NotUsed]] + + def delete(filePath: Path): Future[Unit] + + /** List contents of a directory */ + def list(directoryPath: Path): ListT[Future, FileLink] + + def exists(path: Path): Future[Boolean] + + /** List of characters to avoid in S3 (I would say file names in general) + * + * @see http://stackoverflow.com/questions/7116450/what-are-valid-s3-key-names-that-can-be-accessed-via-the-s3-rest-api + */ + private val illegalChars = "\\^`><{}][#%~|&@:,$=+?; " + + protected def checkSafeFileName[T](filePath: Path)(f: => T): T = { + filePath.toString.find(c => illegalChars.contains(c)) match { + case Some(illegalCharacter) => + throw new IllegalArgumentException(s"File name cannot contain character `$illegalCharacter`") + case None => f + } + } + } + + trait SignedFileStorage extends FileStorage { + def signedFileUrl(filePath: Path, duration: Duration): OptionT[Future, URL] + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/future.scala b/jvm/src/main/scala/xyz/driver/core/future.scala new file mode 100644 index 0000000..1ee3576 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/future.scala @@ -0,0 +1,87 @@ +package xyz.driver.core + +import com.typesafe.scalalogging.Logger + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.{Failure, Success, Try} + +object future { + val log = Logger("Driver.Future") + + implicit class RichFuture[T](f: Future[T]) { + def mapAll[U](pf: PartialFunction[Try[T], U])(implicit executionContext: ExecutionContext): Future[U] = { + val p = Promise[U]() + f.onComplete(r => p.complete(Try(pf(r)))) + p.future + } + + def failFastZip[U](that: Future[U])(implicit executionContext: ExecutionContext): Future[(T, U)] = { + future.failFastZip(f, that) + } + } + + def failFastSequence[T](t: Iterable[Future[T]])(implicit ec: ExecutionContext): Future[Seq[T]] = { + t.foldLeft(Future.successful(Nil: List[T])) { (f, i) => + failFastZip(f, i).map { case (tail, h) => h :: tail } + } + .map(_.reverse) + } + + /** + * Standard scala zip waits forever on the left side, even if the right side fails + */ + def failFastZip[T, U](ft: Future[T], fu: Future[U])(implicit ec: ExecutionContext): Future[(T, U)] = { + type State = Either[(T, Promise[U]), (U, Promise[T])] + val middleState = Promise[State]() + + ft.onComplete { + case f @ Failure(err) => + if (!middleState.tryFailure(err)) { + // the right has already succeeded + middleState.future.foreach { + case Right((_, pt)) => pt.complete(f) + case Left((t1, _)) => // This should never happen + log.error(s"Logic error: tried to set Failure($err) but Left($t1) already set") + } + } + case Success(t) => + // Create the next promise: + val pu = Promise[U]() + if (!middleState.trySuccess(Left((t, pu)))) { + // we can't set, so the other promise beat us here. + middleState.future.foreach { + case Right((_, pt)) => pt.success(t) + case Left((t1, _)) => // This should never happen + log.error(s"Logic error: tried to set Left($t) but Left($t1) already set") + } + } + } + fu.onComplete { + case f @ Failure(err) => + if (!middleState.tryFailure(err)) { + // we can't set, so the other promise beat us here. + middleState.future.foreach { + case Left((_, pu)) => pu.complete(f) + case Right((u1, _)) => // This should never happen + log.error(s"Logic error: tried to set Failure($err) but Right($u1) already set") + } + } + case Success(u) => + // Create the next promise: + val pt = Promise[T]() + if (!middleState.trySuccess(Right((u, pt)))) { + // we can't set, so the other promise beat us here. + middleState.future.foreach { + case Left((_, pu)) => pu.success(u) + case Right((u1, _)) => // This should never happen + log.error(s"Logic error: tried to set Right($u) but Right($u1) already set") + } + } + } + + middleState.future.flatMap { + case Left((t, pu)) => pu.future.map((t, _)) + case Right((u, pt)) => pt.future.map((_, u)) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/generators.scala b/jvm/src/main/scala/xyz/driver/core/generators.scala new file mode 100644 index 0000000..d57980e --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/generators.scala @@ -0,0 +1,138 @@ +package xyz.driver.core + +import enumeratum._ +import java.math.MathContext +import java.util.UUID + +import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} +import xyz.driver.core.date.{Date, DayOfWeek} + +import scala.reflect.ClassTag +import scala.util.Random +import eu.timepit.refined.refineV +import eu.timepit.refined.api.Refined +import eu.timepit.refined.collection._ + +object generators { + + private val random = new Random + import random._ + private val secureRandom = new java.security.SecureRandom() + + private val DefaultMaxLength = 10 + private val StringLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ".toSet + private val NonAmbigiousCharacters = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789" + private val Numbers = "0123456789" + + private def nextTokenString(length: Int, chars: IndexedSeq[Char]): String = { + val builder = new StringBuilder + for (_ <- 0 until length) { + builder += chars(secureRandom.nextInt(chars.length)) + } + builder.result() + } + + /** Creates a random invitation token. + * + * This token is meant fo human input and avoids using ambiguous characters such as 'O' and '0'. It + * therefore contains less entropy and is not meant to be used as a cryptographic secret. */ + @deprecated( + "The term 'token' is too generic and security and readability conventions are not well defined. " + + "Services should implement their own version that suits their security requirements.", + "1.11.0" + ) + def nextToken(length: Int): String = nextTokenString(length, NonAmbigiousCharacters) + + @deprecated( + "The term 'token' is too generic and security and readability conventions are not well defined. " + + "Services should implement their own version that suits their security requirements.", + "1.11.0" + ) + def nextNumericToken(length: Int): String = nextTokenString(length, Numbers) + + def nextInt(maxValue: Int, minValue: Int = 0): Int = random.nextInt(maxValue - minValue) + minValue + + def nextBoolean(): Boolean = random.nextBoolean() + + def nextDouble(): Double = random.nextDouble() + + def nextId[T](): Id[T] = Id[T](nextUuid().toString) + + def nextId[T](maxLength: Int): Id[T] = Id[T](nextString(maxLength)) + + def nextNumericId[T](): Id[T] = Id[T](nextLong.abs.toString) + + def nextNumericId[T](maxValue: Int): Id[T] = Id[T](nextInt(maxValue).toString) + + def nextName[T](maxLength: Int = DefaultMaxLength): Name[T] = Name[T](nextString(maxLength)) + + def nextNonEmptyName[T](maxLength: Int = DefaultMaxLength): NonEmptyName[T] = + NonEmptyName[T](nextNonEmptyString(maxLength)) + + def nextUuid(): UUID = java.util.UUID.randomUUID + + def nextRevision[T](): Revision[T] = Revision[T](nextUuid().toString) + + def nextString(maxLength: Int = DefaultMaxLength): String = + (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString + + def nextNonEmptyString(maxLength: Int = DefaultMaxLength): String Refined NonEmpty = { + refineV[NonEmpty]( + (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString + ).right.get + } + + def nextOption[T](value: => T): Option[T] = if (nextBoolean()) Option(value) else None + + def nextPair[L, R](left: => L, right: => R): (L, R) = (left, right) + + def nextTriad[F, S, T](first: => F, second: => S, third: => T): (F, S, T) = (first, second, third) + + def nextTime(): Time = Time(math.abs(nextLong() % System.currentTimeMillis)) + + def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) + + def nextTimeRange(): TimeRange = { + val oneTime = nextTime() + val anotherTime = nextTime() + + TimeRange( + Time(scala.math.min(oneTime.millis, anotherTime.millis)), + Time(scala.math.max(oneTime.millis, anotherTime.millis))) + } + + def nextDate(): Date = nextTime().toDate(java.util.TimeZone.getTimeZone("UTC")) + + def nextDayOfWeek(): DayOfWeek = oneOf(DayOfWeek.All) + + def nextBigDecimal(multiplier: Double = 1000000.00, precision: Int = 2): BigDecimal = + BigDecimal(multiplier * nextDouble, new MathContext(precision)) + + def oneOf[T](items: T*): T = oneOf(items.toSet) + + def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) + + def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) + + def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = + Array.fill(nextInt(maxLength, minLength))(generator) + + def seqOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Seq[T] = + Seq.fill(nextInt(maxLength, minLength))(generator) + + def vectorOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Vector[T] = + Vector.fill(nextInt(maxLength, minLength))(generator) + + def listOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): List[T] = + List.fill(nextInt(maxLength, minLength))(generator) + + def setOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Set[T] = + seqOf(generator, maxLength, minLength).toSet + + def mapOf[K, V]( + keyGenerator: => K, + valueGenerator: => V, + maxLength: Int = DefaultMaxLength, + minLength: Int = 0): Map[K, V] = + seqOf(nextPair(keyGenerator, valueGenerator), maxLength, minLength).toMap +} diff --git a/jvm/src/main/scala/xyz/driver/core/json.scala b/jvm/src/main/scala/xyz/driver/core/json.scala new file mode 100644 index 0000000..de1df31 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/json.scala @@ -0,0 +1,401 @@ +package xyz.driver.core + +import java.net.InetAddress +import java.util.{TimeZone, UUID} + +import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} +import akka.http.scaladsl.model.Uri.Path +import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshaller +import enumeratum._ +import eu.timepit.refined.api.{Refined, Validate} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import spray.json._ +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.date.{Date, DayOfWeek, Month} +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.rest.errors._ +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.reflect.runtime.universe._ +import scala.util.Try + +object json { + import DefaultJsonProtocol._ + + private def UuidInPath[T]: PathMatcher1[Id[T]] = + PathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) + + def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) + case _ => Unmatched + } + } + + implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = + Unmarshaller.firstOf( + Unmarshaller.strict((JsString(_: String)) andThen reader.read), + stringToValueUnmarshaller[T] + ) + + implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { + def write(id: Id[T]) = JsString(id.value) + + def read(value: JsValue): Id[T] = value match { + case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase) + case JsString(id) => Id[T](id) + case _ => throw DeserializationException("Id expects string") + } + } + + implicit def taggedFormat[F, T](implicit underlying: JsonFormat[F]): JsonFormat[F @@ T] = new JsonFormat[F @@ T] { + import tagging._ + + override def write(obj: F @@ T): JsValue = underlying.write(obj) + + override def read(json: JsValue): F @@ T = underlying.read(json).tagged[T] + } + + def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) + case _ => Unmatched + } + } + + implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { + def write(name: Name[T]) = JsString(name.value) + + def read(value: JsValue): Name[T] = value match { + case JsString(name) => Name[T](name) + case _ => throw DeserializationException("Name expects string") + } + } + + def TimeInPath: PathMatcher1[Time] = + PathMatcher("""[+-]?\d*""".r) flatMap { string => + try Some(Time(string.toLong)) + catch { case _: IllegalArgumentException => None } + } + + implicit val timeFormat = new RootJsonFormat[Time] { + def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) + + def read(value: JsValue): Time = value match { + case JsObject(fields) => + fields + .get("timestamp") + .flatMap { + case JsNumber(millis) => Some(Time(millis.toLong)) + case _ => None + } + .getOrElse(throw DeserializationException("Time expects number")) + case _ => throw DeserializationException("Time expects number") + } + } + + implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { + private val formatter = TimeOfDay.getFormatter + def read(json: JsValue): java.time.LocalTime = json match { + case JsString(chars) => + java.time.LocalTime.parse(chars) + case _ => deserializationError(s"Expected time string got ${json.toString}") + } + + def write(obj: java.time.LocalTime): JsValue = { + JsString(obj.format(formatter)) + } + } + + implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { + override def write(obj: TimeZone): JsValue = { + JsString(obj.getID()) + } + + override def read(json: JsValue): TimeZone = json match { + case JsString(chars) => + java.util.TimeZone.getTimeZone(chars) + case _ => deserializationError(s"Expected time zone string got ${json.toString}") + } + } + + implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) + + implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) + + implicit val dateFormat = new RootJsonFormat[Date] { + def write(date: Date) = JsString(date.toString) + def read(value: JsValue): Date = value match { + case JsString(dateString) => + Date + .fromString(dateString) + .getOrElse( + throw DeserializationException(s"Misformated ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.")) + case _ => throw DeserializationException(s"Date expects a string, but got $value.") + } + } + + implicit val monthFormat = new RootJsonFormat[Month] { + def write(month: Month) = JsNumber(month) + def read(value: JsValue): Month = value match { + case JsNumber(month) if 0 <= month && month <= 11 => Month(month.toInt) + case _ => throw DeserializationException("Expected a number from 0 to 11") + } + } + + def RevisionInPath[T]: PathMatcher1[Revision[T]] = + PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => + Some(Revision[T](string)) + } + + implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = + Unmarshaller.strict[String, Revision[T]](Revision[T]) + + implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { + def write(revision: Revision[T]) = JsString(revision.id.toString) + + def read(value: JsValue): Revision[T] = value match { + case JsString(revision) => Revision[T](revision) + case _ => throw DeserializationException("Revision expects uuid string") + } + } + + implicit val base64Format = new RootJsonFormat[Base64] { + def write(base64Value: Base64) = JsString(base64Value.value) + + def read(value: JsValue): Base64 = value match { + case JsString(base64Value) => Base64(base64Value) + case _ => throw DeserializationException("Base64 format expects string") + } + } + + implicit val emailFormat = new RootJsonFormat[Email] { + def write(email: Email) = JsString(email.username + "@" + email.domain) + def read(json: JsValue): Email = json match { + + case JsString(value) => + Email.parse(value).getOrElse { + deserializationError("Expected '@' symbol in email string as Email, but got " + json.toString) + } + + case _ => + deserializationError("Expected string as Email, but got " + json.toString) + } + } + + implicit val phoneNumberFormat = jsonFormat2(PhoneNumber.apply) + + implicit val authCredentialsFormat = new RootJsonFormat[AuthCredentials] { + override def read(json: JsValue): AuthCredentials = { + json match { + case JsObject(fields) => + val emailField = fields.get("email") + val identifierField = fields.get("identifier") + val passwordField = fields.get("password") + + (emailField, identifierField, passwordField) match { + case (_, _, None) => + deserializationError("password field must be set") + case (Some(JsString(em)), _, Some(JsString(pw))) => + val email = Email.parse(em).getOrElse(throw deserializationError(s"failed to parse email $em")) + AuthCredentials(email.toString, pw) + case (_, Some(JsString(id)), Some(JsString(pw))) => AuthCredentials(id.toString, pw.toString) + case (None, None, _) => deserializationError("identifier must be provided") + case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") + } + case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") + } + } + + override def write(obj: AuthCredentials): JsValue = JsObject( + "identifier" -> JsString(obj.identifier), + "password" -> JsString(obj.password) + ) + } + + implicit object inetAddressFormat extends JsonFormat[InetAddress] { + override def read(json: JsValue): InetAddress = json match { + case JsString(ipString) => + Try(InetAddress.getByName(ipString)) + .getOrElse(deserializationError(s"Invalid IP Address: $ipString")) + case _ => deserializationError(s"Expected string for IP Address, got $json") + } + + override def write(obj: InetAddress): JsValue = + JsString(obj.getHostAddress) + } + + object enumeratum { + + def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + + trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => + + implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) + + implicit val unmarshaller: Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + } + + class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { + override def read(json: JsValue): T = json match { + case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + + override def write(obj: T): JsValue = JsString(obj.entryName) + } + + private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = + deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") + } + + class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { + private val map = mapping.toMap + + override def write(value: T): JsValue = { + map.find(_._2 == value).map(_._1) match { + case Some(name) => JsString(name) + case _ => serializationError(s"Value $value is not found in the mapping $map") + } + } + + override def read(json: JsValue): T = json match { + case JsString(name) => + map.getOrElse(name, throw DeserializationException(s"Value $name is not found in the mapping $map")) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + } + + class ValueClassFormat[T: TypeTag](writeValue: T => BigDecimal, create: BigDecimal => T) extends JsonFormat[T] { + def write(valueClass: T) = JsNumber(writeValue(valueClass)) + def read(json: JsValue): T = json match { + case JsNumber(value) => create(value) + case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json.toString) + } + } + + class GadtJsonFormat[T: TypeTag]( + typeField: String, + typeValue: PartialFunction[T, String], + jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) + extends RootJsonFormat[T] { + + def write(value: T): JsValue = { + + val valueType = typeValue.applyOrElse(value, { v: T => + deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v.toString) + }) + + val valueFormat = + jsonFormat.applyOrElse(valueType, { f: String => + deserializationError(s"No Json format for this type of $valueType") + }) + + valueFormat.asInstanceOf[JsonFormat[T]].write(value) match { + case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType))) + case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object") + } + } + + def read(json: JsValue): T = json match { + case JsObject(fields) => + val valueJson = JsObject(fields.filterNot(_._1 == typeField)) + fields(typeField) match { + case JsString(valueType) => + val valueFormat = jsonFormat.applyOrElse(valueType, { t: String => + deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") + }) + valueFormat.read(valueJson) + case _ => + deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") + } + case _ => + deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json.toString) + } + } + + object GadtJsonFormat { + + def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])( + jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = { + + new GadtJsonFormat[T](typeField, typeValue, jsonFormat) + } + } + + /** + * Provides the JsonFormat for the Refined types provided by the Refined library. + * + * @see https://github.com/fthomas/refined + */ + implicit def refinedJsonFormat[T, Predicate]( + implicit valueFormat: JsonFormat[T], + validate: Validate[T, Predicate]): JsonFormat[Refined[T, Predicate]] = + new JsonFormat[Refined[T, Predicate]] { + def write(x: T Refined Predicate): JsValue = valueFormat.write(x.value) + def read(value: JsValue): T Refined Predicate = { + refineV[Predicate](valueFormat.read(value))(validate) match { + case Right(refinedValue) => refinedValue + case Left(refinementError) => deserializationError(refinementError) + } + } + } + + def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + refineV[NonEmpty](segment) match { + case Left(_) => Unmatched + case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) + } + case _ => Unmatched + } + } + + implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = + new RootJsonFormat[NonEmptyName[T]] { + def write(name: NonEmptyName[T]) = JsString(name.value.value) + + def read(value: JsValue): NonEmptyName[T] = + NonEmptyName[T](nonEmptyStringFormat.read(value)) + } + + implicit val serviceExceptionFormat: RootJsonFormat[ServiceException] = + GadtJsonFormat.create[ServiceException]("type") { + case _: InvalidInputException => "InvalidInputException" + case _: InvalidActionException => "InvalidActionException" + case _: ResourceNotFoundException => "ResourceNotFoundException" + case _: ExternalServiceException => "ExternalServiceException" + case _: ExternalServiceTimeoutException => "ExternalServiceTimeoutException" + case _: DatabaseException => "DatabaseException" + } { + case "InvalidInputException" => jsonFormat(InvalidInputException, "message") + case "InvalidActionException" => jsonFormat(InvalidActionException, "message") + case "ResourceNotFoundException" => jsonFormat(ResourceNotFoundException, "message") + case "ExternalServiceException" => + jsonFormat(ExternalServiceException, "serviceName", "serviceMessage", "serviceException") + case "ExternalServiceTimeoutException" => jsonFormat(ExternalServiceTimeoutException, "message") + case "DatabaseException" => jsonFormat(DatabaseException, "message") + } + + val jsValueToStringMarshaller: Marshaller[JsValue, String] = + Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) + + def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = + jsValueToStringMarshaller.compose[T](jsonFormat.write) + + val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = + Unmarshaller.strict[String, JsValue](value => value.parseJson) + + def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = + stringToJsValueUnmarshaller.map[T](jsonFormat.read) +} diff --git a/jvm/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala b/jvm/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala new file mode 100644 index 0000000..df21b48 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/logging/MdcExecutionContext.scala @@ -0,0 +1,31 @@ +/** Code ported from "de.geekonaut" %% "slickmdc" % "1.0.0" + * License: @see https://github.com/AVGP/slickmdc/blob/master/LICENSE + * Blog post: @see http://50linesofco.de/post/2016-07-01-slick-and-slf4j-mdc-logging-in-scala.html + */ +package xyz.driver.core.logging + +import org.slf4j.MDC +import scala.concurrent.ExecutionContext + +/** + * Execution context proxy for propagating SLF4J diagnostic context from caller thread to execution thread. + */ +class MdcExecutionContext(executionContext: ExecutionContext) extends ExecutionContext { + override def execute(runnable: Runnable): Unit = { + val callerMdc = MDC.getCopyOfContextMap + executionContext.execute(new Runnable { + def run(): Unit = { + // copy caller thread diagnostic context to execution thread + Option(callerMdc).foreach(MDC.setContextMap) + try { + runnable.run() + } finally { + // the thread might be reused, so we clean up for the next use + MDC.clear() + } + } + }) + } + + override def reportFailure(cause: Throwable): Unit = executionContext.reportFailure(cause) +} diff --git a/jvm/src/main/scala/xyz/driver/core/logging/package.scala b/jvm/src/main/scala/xyz/driver/core/logging/package.scala new file mode 100644 index 0000000..2b6fc11 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/logging/package.scala @@ -0,0 +1,7 @@ +package xyz.driver.core + +import org.slf4j.helpers.NOPLogger + +package object logging { + val NoLogger = com.typesafe.scalalogging.Logger(NOPLogger.NOP_LOGGER) +} diff --git a/jvm/src/main/scala/xyz/driver/core/messages.scala b/jvm/src/main/scala/xyz/driver/core/messages.scala new file mode 100644 index 0000000..6b1bc7e --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/messages.scala @@ -0,0 +1,58 @@ +package xyz.driver.core + +import java.util.Locale + +import com.typesafe.config.{Config, ConfigException} +import com.typesafe.scalalogging.Logger + +/** + * Scala internationalization (i18n) support + */ +object messages { + + object Messages { + def messages(config: Config, log: Logger, locale: Locale = Locale.US): Messages = { + val map = config.getConfig(locale.getLanguage) + Messages(map, locale, log) + } + } + + final case class Messages(map: Config, locale: Locale, log: Logger) { + + /** + * Returns message for the key + * + * @param key key + * @return message + */ + def apply(key: String): String = { + try { + map.getString(key) + } catch { + case _: ConfigException => + log.error(s"Message with key '$key' not found for locale '${locale.getLanguage}'") + key + } + } + + /** + * Returns message for the key and formats that with parameters + * + * @example "Hello {0}!" with "Joe" will be "Hello Joe!" + * + * @param key key + * @param params params to be embedded + * @return formatted message + */ + def apply(key: String, params: Any*): String = { + + def format(formatString: String, params: Seq[Any]) = + params.zipWithIndex.foldLeft(formatString) { + case (res, (value, index)) => res.replace(s"{$index}", value.toString) + } + + val template = apply(key) + format(template, params) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/pubsub.scala b/jvm/src/main/scala/xyz/driver/core/pubsub.scala new file mode 100644 index 0000000..6d2667f --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/pubsub.scala @@ -0,0 +1,145 @@ +package xyz.driver.core + +import akka.http.scaladsl.marshalling._ +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} +import akka.stream.Materializer +import com.google.api.core.{ApiFutureCallback, ApiFutures} +import com.google.cloud.pubsub.v1._ +import com.google.protobuf.ByteString +import com.google.pubsub.v1._ +import com.typesafe.scalalogging.Logger + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.util.{Failure, Try} + +object pubsub { + + trait PubsubPublisher[Message] { + + type Result + + def publish(message: Message): Future[Result] + } + + class GooglePubsubPublisher[Message](projectId: String, topic: String, log: Logger, autoCreate: Boolean = true)( + implicit messageMarshaller: Marshaller[Message, String], + ex: ExecutionContext + ) extends PubsubPublisher[Message] { + + type Result = Id[PubsubMessage] + + private val topicName = ProjectTopicName.of(projectId, topic) + + private val publisher = { + if (autoCreate) { + val adminClient = TopicAdminClient.create() + val topicExists = Try(adminClient.getTopic(topicName)).isSuccess + if (!topicExists) { + adminClient.createTopic(topicName) + } + } + Publisher.newBuilder(topicName).build() + } + + override def publish(message: Message): Future[Id[PubsubMessage]] = { + + Marshal(message).to[String].flatMap { messageString => + val data = ByteString.copyFromUtf8(messageString) + val pubsubMessage = PubsubMessage.newBuilder().setData(data).build() + + val promise = Promise[Id[PubsubMessage]]() + + val messageIdFuture = publisher.publish(pubsubMessage) + + ApiFutures.addCallback( + messageIdFuture, + new ApiFutureCallback[String]() { + override def onSuccess(messageId: String): Unit = { + log.info(s"Published a message with topic $topic, message id $messageId: $messageString") + promise.complete(Try(Id[PubsubMessage](messageId))) + } + + override def onFailure(t: Throwable): Unit = { + log.warn(s"Failed to publish a message with topic $topic: $message", t) + promise.complete(Failure(t)) + } + } + ) + + promise.future + } + } + } + + class FakePubsubPublisher[Message](topicName: String, log: Logger)( + implicit messageMarshaller: Marshaller[Message, String], + ex: ExecutionContext) + extends PubsubPublisher[Message] { + + type Result = Id[PubsubMessage] + + def publish(message: Message): Future[Result] = + Marshal(message).to[String].map { messageString => + log.info(s"Published a message to a fake pubsub with topic $topicName: $messageString") + generators.nextId[PubsubMessage]() + } + } + + trait PubsubSubscriber { + + def stopListening(): Unit + } + + class GooglePubsubSubscriber[Message]( + projectId: String, + subscriptionId: String, + receiver: Message => Future[Unit], + log: Logger, + autoCreateSettings: Option[GooglePubsubSubscriber.SubscriptionSettings] = None + )(implicit messageMarshaller: Unmarshaller[String, Message], mat: Materializer, ex: ExecutionContext) + extends PubsubSubscriber { + + private val subscriptionName = ProjectSubscriptionName.of(projectId, subscriptionId) + + private val messageReceiver = new MessageReceiver() { + override def receiveMessage(message: PubsubMessage, consumer: AckReplyConsumer): Unit = { + val messageString = message.getData.toStringUtf8 + Unmarshal(messageString).to[Message].flatMap { messageBody => + log.info(s"Received a message ${message.getMessageId} for subscription $subscriptionId: $messageString") + receiver(messageBody).transform(v => { consumer.ack(); v }, t => { consumer.nack(); t }) + } + } + } + + private val subscriber = { + autoCreateSettings.foreach { subscriptionSettings => + val adminClient = SubscriptionAdminClient.create() + val subscriptionExists = Try(adminClient.getSubscription(subscriptionName)).isSuccess + if (!subscriptionExists) { + val topicName = ProjectTopicName.of(projectId, subscriptionSettings.topic) + adminClient.createSubscription( + subscriptionName, + topicName, + subscriptionSettings.pushConfig, + subscriptionSettings.ackDeadlineSeconds) + } + } + + Subscriber.newBuilder(subscriptionName, messageReceiver).build() + } + + subscriber.startAsync() + + override def stopListening(): Unit = { + subscriber.stopAsync() + } + } + + object GooglePubsubSubscriber { + final case class SubscriptionSettings(topic: String, pushConfig: PushConfig, ackDeadlineSeconds: Int) + } + + class FakePubsubSubscriber extends PubsubSubscriber { + def stopListening(): Unit = () + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/jvm/src/main/scala/xyz/driver/core/rest/DriverRoute.scala new file mode 100644 index 0000000..55f39ba --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -0,0 +1,111 @@ +package xyz.driver.core.rest + +import java.sql.SQLException + +import akka.http.scaladsl.model.{StatusCodes, _} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.rest +import xyz.driver.core.rest.errors._ + +import scala.compat.Platform.ConcurrentModificationException + +trait DriverRoute { + def log: Logger + + def route: Route + + def routeWithDefaults: Route = { + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route + } + } + + protected def defaultResponseHeaders: Directive0 = { + extractRequest flatMap { request => + // Needs to happen before any request processing, so all the log messages + // associated with processing of this request are having this `trackingId` + val trackingId = rest.extractTrackingId(request) + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + MDC.put("trackingId", trackingId) + + // This header will eliminate the risk of LB trying to reuse a connection + // that already timed out on the server side by completely rejecting keep-alive + val rejectKeepAlive = Connection("close") + + respondWithHeaders(tracingHeader, rejectKeepAlive) + } + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + case serviceException: ServiceException => + serviceExceptionHandler(serviceException) + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( + ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) + + case t: Exception => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) + } + + protected def serviceExceptionHandler(serviceException: ServiceException): Route = { + val statusCode = serviceException match { + case e: InvalidInputException => + log.info("Invalid client input error", e) + StatusCodes.BadRequest + case e: InvalidActionException => + log.info("Invalid client action error", e) + StatusCodes.Forbidden + case e: ResourceNotFoundException => + log.info("Resource not found error", e) + StatusCodes.NotFound + case e: ExternalServiceException => + log.error("Error while calling another service", e) + StatusCodes.InternalServerError + case e: ExternalServiceTimeoutException => + log.error("Service timeout error", e) + StatusCodes.GatewayTimeout + case e: DatabaseException => + log.error("Database error", e) + StatusCodes.InternalServerError + } + + { (ctx: RequestContext) => + import xyz.driver.core.json.serviceExceptionFormat + val entity = + HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) + errorResponse(statusCode, entity, serviceException)(ctx) + } + } + + protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = + errorResponse(statusCode, HttpEntity(message), exception) + + protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { + complete(HttpResponse(statusCode, entity = entity)) + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/jvm/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala new file mode 100644 index 0000000..788729a --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -0,0 +1,89 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.TcpIdleTimeoutException +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.Name +import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} +import xyz.driver.core.time.provider.TimeProvider + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +class HttpRestServiceTransport( + applicationName: Name[App], + applicationVersion: String, + actorSystem: ActorSystem, + executionContext: ExecutionContext, + log: Logger, + time: TimeProvider) + extends ServiceTransport { + + protected implicit val execution: ExecutionContext = executionContext + + protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + + val requestTime = time.currentTime() + + val request = requestStub + .withHeaders(context.contextHeaders.toSeq.map { + case (ContextHeaders.TrackingIdHeader, _) => + RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) + case (ContextHeaders.StacktraceHeader, _) => + RawHeader( + ContextHeaders.StacktraceHeader, + Option(MDC.get("stack")) + .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) + .getOrElse("")) + case (header, headerValue) => RawHeader(header, headerValue) + }: _*) + + log.debug(s"Sending request to ${request.method} ${request.uri}") + + val response = httpClient.makeRequest(request) + + response.onComplete { + case Success(r) => + val responseLatency = requestTime.durationTo(time.currentTime()) + log.debug(s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") + + case Failure(t: Throwable) => + val responseLatency = requestTime.durationTo(time.currentTime()) + log.warn(s"Failed to receive response from ${request.method} ${request.uri} in $responseLatency ms", t) + }(executionContext) + + response.recoverWith { + case _: TcpIdleTimeoutException => + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Future.failed(ExternalServiceTimeoutException(serviceCalled)) + case t: Throwable => Future.failed(t) + } + } + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { + + sendRequestGetResponse(context)(requestStub) flatMap { response => + if (response.status == StatusCodes.NotFound) { + Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) + } else if (response.status.isFailure()) { + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Unmarshal(response.entity).to[String] flatMap { errorString => + import spray.json._ + import xyz.driver.core.json._ + val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption + Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) + } + } else { + Future.successful(Unmarshal(response.entity)) + } + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/jvm/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..f33bf9d --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => + val oldValue = oldObj.fields.getOrElse(key, JsNull) + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/jvm/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/jvm/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala new file mode 100644 index 0000000..2854257 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala @@ -0,0 +1,67 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.scaladsl.{Keep, Sink, Source} +import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} +import xyz.driver.core.Name + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.concurrent.duration._ +import scala.util.{Failure, Success} + +class PooledHttpClient( + baseUri: Uri, + applicationName: Name[App], + applicationVersion: String, + requestRateLimit: Int = 64, + requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) + extends HttpClient { + + private val host = baseUri.authority.host.toString() + private val port = baseUri.effectivePort + private val scheme = baseUri.scheme + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + private val pool = if (scheme.equalsIgnoreCase("https")) { + Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } else { + Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } + + private val queue = Source + .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) + .via(pool) + .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) + .toMat(Sink.foreach({ + case ((Success(resp), p)) => p.success(resp) + case ((Failure(e), p)) => p.failure(e) + }))(Keep.left) + .run + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + val responsePromise = Promise[HttpResponse]() + + queue.offer(request -> responsePromise).flatMap { + case QueueOfferResult.Enqueued => + responsePromise.future + case QueueOfferResult.Dropped => + Future.failed(new Exception(s"Request queue to the host $host is overflown")) + case QueueOfferResult.Failure(ex) => + Future.failed(ex) + case QueueOfferResult.QueueClosed => + Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/jvm/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala new file mode 100644 index 0000000..c0e9f99 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala @@ -0,0 +1,26 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} +import com.typesafe.config.Config +import xyz.driver.core.Name + +import scala.concurrent.ExecutionContext + +trait ProxyRoute extends DriverRoute { + implicit val executionContext: ExecutionContext + val config: Config + val httpClient: HttpClient + + protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => + val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") + val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") + + val originalUri = ctx.request.uri + val originalRequest = ctx.request + + val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) + val newRequest = originalRequest.withUri(newUri) + + httpClient.makeRequest(newRequest).map(RouteResult.Complete) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/RestService.scala b/jvm/src/main/scala/xyz/driver/core/rest/RestService.scala new file mode 100644 index 0000000..8d46d72 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/RestService.scala @@ -0,0 +1,72 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} +import akka.stream.Materializer + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +trait RestService extends Service { + + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json._ + + protected implicit val exec: ExecutionContext + protected implicit val materializer: Materializer + + implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { + def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = + if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] + } + + protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = + OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) + + protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) + + protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) + + protected def jsonEntity(json: JsValue): RequestEntity = + HttpEntity(ContentTypes.`application/json`, json.compactPrint) + + protected def mergePatchJsonEntity(json: JsValue): RequestEntity = + HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) + + protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) + + protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) + + protected def postJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) + + protected def putJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) + + protected def patchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) + + protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) + + protected def endpointUri(baseUri: Uri, path: String): Uri = + baseUri.withPath(Uri.Path(path)) + + protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = + baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/jvm/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala new file mode 100644 index 0000000..964a5a2 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala @@ -0,0 +1,29 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.ActorMaterializer +import xyz.driver.core.Name + +import scala.concurrent.Future + +class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) + extends HttpClient { + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private val client = Http()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + client.singleRequest(request, settings = connectionPoolSettings) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/Swagger.scala b/jvm/src/main/scala/xyz/driver/core/rest/Swagger.scala new file mode 100644 index 0000000..a3d942c --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -0,0 +1,127 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} +import akka.http.scaladsl.server.Route +import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile +import akka.stream.ActorAttributes +import akka.stream.scaladsl.{Framing, StreamConverters} +import akka.util.ByteString +import com.github.swagger.akka.SwaggerHttpService +import com.github.swagger.akka.model._ +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import io.swagger.models.Scheme +import io.swagger.util.Json + +import scala.reflect.runtime.universe +import scala.reflect.runtime.universe.Type +import scala.util.control.NonFatal + +class Swagger( + override val host: String, + override val schemes: List[Scheme], + version: String, + val apiTypes: Seq[Type], + val config: Config, + val logger: Logger) + extends SwaggerHttpService { + + lazy val mirror = universe.runtimeMirror(getClass.getClassLoader) + + override val apiClasses = apiTypes.map { tpe => + mirror.runtimeClass(tpe.typeSymbol.asClass) + }.toSet + + // Note that the reason for overriding this is a subtle chain of causality: + // + // 1. Some of our endpoints require a single trailing slash and will not + // function if it is omitted + // 2. Swagger omits trailing slashes in its generated api doc + // 3. To work around that, a space is added after the trailing slash in the + // swagger Path annotations + // 4. This space is removed manually in the code below + // + // TODO: Ideally we'd like to drop this custom override and fix the issue in + // 1, by dropping the slash requirement and accepting api endpoints with and + // without trailing slashes. This will require inspecting and potentially + // fixing all service endpoints. + override def generateSwaggerJson: String = { + import io.swagger.models.{Swagger => JSwagger} + + import scala.collection.JavaConverters._ + try { + val swagger: JSwagger = reader.read(apiClasses.asJava) + + // Removing trailing spaces + swagger.setPaths( + swagger.getPaths.asScala + .map { + case (key, path) => + key.trim -> path + } + .toMap + .asJava) + + Json.pretty().writeValueAsString(swagger) + } catch { + case NonFatal(t) => + logger.error("Issue with creating swagger.json", t) + throw t + } + } + + override val basePath: String = config.getString("swagger.basePath") + override val apiDocsPath: String = config.getString("swagger.docsPath") + + override val info = Info( + config.getString("swagger.apiInfo.description"), + version, + config.getString("swagger.apiInfo.title"), + config.getString("swagger.apiInfo.termsOfServiceUrl"), + contact = Some( + Contact( + config.getString("swagger.apiInfo.contact.name"), + config.getString("swagger.apiInfo.contact.url"), + config.getString("swagger.apiInfo.contact.email") + )), + license = Some( + License( + config.getString("swagger.apiInfo.license"), + config.getString("swagger.apiInfo.licenseUrl") + )), + vendorExtensions = Map.empty[String, AnyRef] + ) + + /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ + private def getTemplatedResource( + resourceName: String, + contentType: ContentType, + substitution: (String, String)): Route = get { + Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { + case Some(ResourceFile(url, length @ _, _)) => + extractSettings { settings => + val stream = StreamConverters + .fromInputStream(() => url.openStream()) + .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) + .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) + .map { line => + line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) + } + .map(line => ByteString(line + "\n")) + complete( + HttpEntity(contentType, stream) + ) + } + case None => reject + } + } + + def swaggerUI: Route = + pathEndOrSingleSlash { + getTemplatedResource( + "swagger-ui/index.html", + ContentTypes.`text/html(UTF-8)`, + "title" -> config.getString("swagger.apiInfo.title")) + } ~ getFromResourceDirectory("swagger-ui") + +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala new file mode 100644 index 0000000..5007774 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala @@ -0,0 +1,14 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val permissionsMap = permissions.map(_ -> true).toMap + Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala new file mode 100644 index 0000000..82edcc7 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala @@ -0,0 +1,73 @@ +package xyz.driver.core.rest.auth + +import akka.http.scaladsl.model.headers.HttpChallenges +import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected +import com.typesafe.scalalogging.Logger +import xyz.driver.core._ +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ServiceRequestContext, serviceContext} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +import scalaz.Scalaz.futureInstance +import scalaz.OptionT + +abstract class AuthProvider[U <: User](val authorization: Authorization[U], log: Logger)( + implicit execution: ExecutionContext) { + + import akka.http.scaladsl.server._ + import Directives._ + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param ctx set of request values which can be relevant to authenticate user + * @return authenticated user + */ + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] + + /** + * Verifies if a service context is authenticated and authorized to have `permissions` + */ + def authorize( + context: ServiceRequestContext, + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + onComplete { + (for { + authToken <- OptionT.optionT(Future.successful(context.authToken)) + user <- authenticatedUser(context) + authCtx = context.withAuthenticatedUser(authToken, user) + authorizationResult <- authorization.userHasPermissions(user, permissions)(authCtx).toOptionT + + cachedPermissionsAuthCtx = authorizationResult.token.fold(authCtx)(authCtx.withPermissionsToken) + allAuthorized = permissions.forall(authorizationResult.authorized.getOrElse(_, false)) + } yield (cachedPermissionsAuthCtx, allAuthorized)).run + } flatMap { + case Success(Some((authCtx, true))) => provide(authCtx) + case Success(Some((authCtx, false))) => + val challenge = + HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") + log.warn( + s"User ${authCtx.authenticatedUser} does not have the required permissions: ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + case Success(None) => + val challenge = HttpChallenges.basic("Failed to authenticate user") + log.warn(s"Failed to authenticate user to verify ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + case Failure(t) => + log.warn(s"Wasn't able to verify token for authenticated user to verify ${permissions.mkString(", ")}", t) + reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) + } + } + + /** + * Verifies if request is authenticated and authorized to have `permissions` + */ + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + serviceContext flatMap { ctx => + authorize(ctx, permissions: _*) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala new file mode 100644 index 0000000..1a5e9be --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala @@ -0,0 +1,11 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +trait Authorization[U <: User] { + def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala new file mode 100644 index 0000000..efe28c9 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala @@ -0,0 +1,22 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, PermissionsToken} + +import scalaz.Scalaz.mapMonoid +import scalaz.Semigroup +import scalaz.syntax.semigroup._ + +final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) +object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) + + implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { + private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) + private implicit val permissionsTokenSemigroup = + Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) + + override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { + AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala new file mode 100644 index 0000000..66de4ef --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala @@ -0,0 +1,55 @@ +package xyz.driver.core.rest.auth + +import java.nio.file.{Files, Path} +import java.security.{KeyFactory, PublicKey} +import java.security.spec.X509EncodedKeySpec + +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future +import scalaz.syntax.std.boolean._ + +class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + import spray.json._ + + def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = + tokenObject.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) + + authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } +} + +object CachedTokenAuthorization { + def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { + lazy val publicKey: PublicKey = { + val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim + val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) + val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) + KeyFactory.getInstance("RSA").generatePublic(spec) + } + new CachedTokenAuthorization[U](publicKey, issuer) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/jvm/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala new file mode 100644 index 0000000..131e7fc --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.Scalaz.{futureInstance, listInstance} +import scalaz.syntax.semigroup._ +import scalaz.syntax.traverse._ + +class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = + permissions.forall(permissionsMap.getOrElse(_, false)) + + authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { + (authResult, authorization) => + if (allAuthorized(authResult.authorized)) Future.successful(authResult) + else { + authorization.userHasPermissions(user, permissions).map(authResult |+| _) + } + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/jvm/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala new file mode 100644 index 0000000..db289de --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala @@ -0,0 +1,23 @@ +package xyz.driver.core.rest.errors + +sealed abstract class ServiceException(val message: String) extends Exception(message) + +final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) + +final case class InvalidActionException(override val message: String = "This action is not allowed") + extends ServiceException(message) + +final case class ResourceNotFoundException(override val message: String = "Resource not found") + extends ServiceException(message) + +final case class ExternalServiceException( + serviceName: String, + serviceMessage: String, + serviceException: Option[ServiceException]) + extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") + +final case class ExternalServiceTimeoutException(serviceName: String) + extends ServiceException(s"$serviceName took too long to respond") + +final case class DatabaseException(override val message: String = "Database access error") + extends ServiceException(message) diff --git a/jvm/src/main/scala/xyz/driver/core/rest/package.scala b/jvm/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..f85c39a --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,286 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.Flow +import akka.util.ByteString +import xyz.driver.tracing.TracingDirectives + +import scala.concurrent.Future +import scala.util.Try +import scalaz.{Functor, OptionT} +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ + +trait Service + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] +} + +sealed trait SortingOrder +object SortingOrder { + case object Asc extends SortingOrder + case object Desc extends SortingOrder +} + +final case class SortingField(name: String, sortingOrder: SortingOrder) +final case class Sorting(sortingFields: Seq[SortingField]) + +final case class Pagination(pageSize: Int, pageNumber: Int) { + require(pageSize > 0, "Page size must be greater than zero") + require(pageNumber > 0, "Page number must be greater than zero") + + def offset: Int = pageSize * (pageNumber - 1) +} + +final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) + +object ListResponse { + + def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = + ListResponse( + items = items, + meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) + + final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) + + object Meta { + def apply(itemsCount: Int, pagination: Pagination): Meta = + Meta(itemsCount, pagination.pageNumber, pagination.pageSize) + } + +} + +object `package` { + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } + + object ContextHeaders { + val AuthenticationTokenHeader: String = "Authorization" + val PermissionsTokenHeader: String = "Permissions" + val AuthenticationHeaderPrefix: String = "Bearer" + val ClientFingerprintHeader: String = "X-Client-Fingerprint" + val TrackingIdHeader: String = "X-Trace" + val StacktraceHeader: String = "X-Stacktrace" + val OriginatingIpHeader: String = "X-Forwarded-For" + val ResourceCount: String = "X-Resource-Count" + val PageCount: String = "X-Page-Count" + val TraceHeaderName: String = TracingDirectives.TraceHeaderName + val SpanHeaderName: String = TracingDirectives.SpanHeaderName + } + + object AuthProvider { + val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader: String = "set-authorization" + val SetPermissionsTokenHeader: String = "set-permissions" + } + + val AllowedHeaders: Seq[String] = + Seq( + "Origin", + "X-Requested-With", + "Content-Type", + "Content-Length", + "Accept", + "X-Trace", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Server", + "Date", + ContextHeaders.ClientFingerprintHeader, + ContextHeaders.TrackingIdHeader, + ContextHeaders.TraceHeaderName, + ContextHeaders.SpanHeaderName, + ContextHeaders.StacktraceHeader, + ContextHeaders.AuthenticationTokenHeader, + ContextHeaders.OriginatingIpHeader, + ContextHeaders.ResourceCount, + ContextHeaders.PageCount, + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + AuthProvider.SetAuthenticationTokenHeader, + AuthProvider.SetPermissionsTokenHeader + ) + + def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + + def serviceContext: Directive1[ServiceRequestContext] = { + extractClientIP flatMap { remoteAddress => + extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + } + } + + def respondWithCorsAllowedHeaders: Directive0 = { + respondWithHeaders( + List[HttpHeader]( + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + )) + } + + def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { + respondWithHeader { + `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) + } + } + + def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { + respondWithHeaders( + List[HttpHeader]( + Allow(methods.to[collection.immutable.Seq]), + `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) + )) + } + + def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = + new ServiceRequestContext( + extractTrackingId(request), + extractOriginatingIP(request, remoteAddress), + extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name === ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractFingerprintHash(request: HttpRequest): Option[String] = { + request.headers + .find(_.name === ContextHeaders.ClientFingerprintHeader) + .map(_.value()) + } + + def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { + request.headers + .find(_.name === ContextHeaders.OriginatingIpHeader) + .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) + .orElse(remoteAddress.toOption) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || + h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader + } map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } + + val paginated: Directive1[Pagination] = + parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) + + private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = + (pageSizeOpt, pageNumberOpt) match { + case (Some(size), Some(number)) => Option(Pagination(size, number)) + case (None, None) => Option.empty[Pagination] + case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") + } + + val optionalPagination: Directive1[Option[Pagination]] = + parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) + + def paginationQuery(pagination: Pagination) = + Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) + + private def extractSorting(sortingString: Option[String]): Sorting = { + val sortingFields = sortingString.fold(Seq.empty[SortingField])( + _.split(",") + .filter(_.length > 0) + .map { sortingParam => + if (sortingParam.startsWith("-")) { + SortingField(sortingParam.substring(1), SortingOrder.Desc) + } else { + val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam + SortingField(fieldName, SortingOrder.Asc) + } + } + .toSeq) + + Sorting(sortingFields) + } + + val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) + + def sortingQuery(sorting: Sorting): Seq[(String, String)] = { + val sortingString = sorting.sortingFields + .map { sortingField => + sortingField.sortingOrder match { + case SortingOrder.Asc => sortingField.name + case SortingOrder.Desc => s"-${sortingField.name}" + } + } + .mkString(",") + Seq("sort" -> sortingString) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/jvm/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala new file mode 100644 index 0000000..55f1a2e --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala @@ -0,0 +1,24 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait ServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T +} + +trait SavingUsedServiceDiscovery { + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } +} + +class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") +} diff --git a/jvm/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/jvm/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala new file mode 100644 index 0000000..775106e --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -0,0 +1,74 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} +import xyz.driver.core.generators + +import scalaz.Scalaz.{mapEqual, stringInstance} +import scalaz.syntax.equal._ + +class ServiceRequestContext( + val trackingId: String = generators.nextUuid().toString, + val originatingIp: Option[InetAddress] = None, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + + def withAuthToken(authToken: AuthToken): ServiceRequestContext = + new ServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, originatingIp, contextHeaders) + .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => + trackingId === ctx.trackingId && + originatingIp == originatingIp && + contextHeaders === ctx.contextHeaders + case _ => false + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" +} + +class AuthorizedServiceRequestContext[U <: User]( + override val trackingId: String = generators.nextUuid().toString, + override val originatingIp: Option[InetAddress] = None, + override val contextHeaders: Map[String, String] = Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} diff --git a/jvm/src/main/scala/xyz/driver/core/stats.scala b/jvm/src/main/scala/xyz/driver/core/stats.scala new file mode 100644 index 0000000..dbcf6e4 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/stats.scala @@ -0,0 +1,58 @@ +package xyz.driver.core + +import java.io.File +import java.lang.management.ManagementFactory +import java.lang.reflect.Modifier + +object stats { + + final case class MemoryStats(free: Long, total: Long, max: Long) + + final case class GarbageCollectorStats(totalGarbageCollections: Long, garbageCollectionTime: Long) + + final case class FileRootSpace(path: String, totalSpace: Long, freeSpace: Long, usableSpace: Long) + + object SystemStats { + + def memoryUsage: MemoryStats = { + val runtime = Runtime.getRuntime + MemoryStats(runtime.freeMemory, runtime.totalMemory, runtime.maxMemory) + } + + def availableProcessors: Int = { + Runtime.getRuntime.availableProcessors() + } + + def garbageCollectorStats: GarbageCollectorStats = { + import scala.collection.JavaConverters._ + + val (totalGarbageCollections, garbageCollectionTime) = + ManagementFactory.getGarbageCollectorMXBeans.asScala.foldLeft(0L -> 0L) { + case ((total, collectionTime), gc) => + (total + math.max(0L, gc.getCollectionCount)) -> (collectionTime + math.max(0L, gc.getCollectionTime)) + } + + GarbageCollectorStats(totalGarbageCollections, garbageCollectionTime) + } + + def fileSystemSpace: Array[FileRootSpace] = { + File.listRoots() map { root => + FileRootSpace(root.getAbsolutePath, root.getTotalSpace, root.getFreeSpace, root.getUsableSpace) + } + } + + def operatingSystemStats: Map[String, String] = { + val operatingSystemMXBean = ManagementFactory.getOperatingSystemMXBean + operatingSystemMXBean.getClass.getDeclaredMethods + .map(method => { method.setAccessible(true); method }) + .filter(method => method.getName.startsWith("get") && Modifier.isPublic(method.getModifiers)) + .map { method => + try { + method.getName -> String.valueOf(method.invoke(operatingSystemMXBean)) + } catch { + case t: Throwable => method.getName -> t.getMessage + } + } toMap + } + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/storage/BlobStorage.scala b/jvm/src/main/scala/xyz/driver/core/storage/BlobStorage.scala new file mode 100644 index 0000000..ee6c5d7 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/storage/BlobStorage.scala @@ -0,0 +1,50 @@ +package xyz.driver.core.storage + +import java.net.URL +import java.nio.file.Path + +import akka.stream.scaladsl.{Sink, Source} +import akka.util.ByteString +import akka.{Done, NotUsed} + +import scala.concurrent.Future +import scala.concurrent.duration.Duration + +/** Binary key-value store, typically implemented by cloud storage. */ +trait BlobStorage { + + /** Upload data by value. */ + def uploadContent(name: String, content: Array[Byte]): Future[String] + + /** Upload data from an existing file. */ + def uploadFile(name: String, content: Path): Future[String] + + def exists(name: String): Future[Boolean] + + /** List available keys. The prefix determines which keys should be listed + * and depends on the implementation (for instance, a file system backed + * blob store will treat a prefix as a directory path). */ + def list(prefix: String): Future[Set[String]] + + /** Get all the content of a given object. */ + def content(name: String): Future[Option[Array[Byte]]] + + /** Stream data asynchronously and with backpressure. */ + def download(name: String): Future[Option[Source[ByteString, NotUsed]]] + + /** Get a sink to upload data. */ + def upload(name: String): Future[Sink[ByteString, Future[Done]]] + + /** Delete a stored value. */ + def delete(name: String): Future[String] + + /** + * Path to specified resource. Checks that the resource exists and returns None if + * it is not found. Depending on the implementation, may throw. + */ + def url(name: String): Future[Option[URL]] +} + +trait SignedBlobStorage extends BlobStorage { + def signedDownloadUrl(name: String, duration: Duration): Future[Option[URL]] +} diff --git a/jvm/src/main/scala/xyz/driver/core/storage/FileSystemBlobStorage.scala b/jvm/src/main/scala/xyz/driver/core/storage/FileSystemBlobStorage.scala new file mode 100644 index 0000000..e12c73d --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/storage/FileSystemBlobStorage.scala @@ -0,0 +1,82 @@ +package xyz.driver.core.storage + +import java.net.URL +import java.nio.file.{Files, Path, StandardCopyOption} + +import akka.stream.scaladsl.{FileIO, Sink, Source} +import akka.util.ByteString +import akka.{Done, NotUsed} + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} + +/** A blob store that is backed by a local filesystem. All objects are stored relative to the given + * root path. Slashes ('/') in blob names are treated as usual path separators and are converted + * to directories. */ +class FileSystemBlobStorage(root: Path)(implicit ec: ExecutionContext) extends BlobStorage { + + private def ensureParents(file: Path): Path = { + Files.createDirectories(file.getParent()) + file + } + + private def file(name: String) = root.resolve(name) + + override def uploadContent(name: String, content: Array[Byte]): Future[String] = Future { + Files.write(ensureParents(file(name)), content) + name + } + override def uploadFile(name: String, content: Path): Future[String] = Future { + Files.copy(content, ensureParents(file(name)), StandardCopyOption.REPLACE_EXISTING) + name + } + + override def exists(name: String): Future[Boolean] = Future { + val path = file(name) + Files.exists(path) && Files.isReadable(path) + } + + override def list(prefix: String): Future[Set[String]] = Future { + val dir = file(prefix) + Files + .list(dir) + .iterator() + .asScala + .map(p => root.relativize(p)) + .map(_.toString) + .toSet + } + + override def content(name: String): Future[Option[Array[Byte]]] = exists(name) map { + case true => + Some(Files.readAllBytes(file(name))) + case false => None + } + + override def download(name: String): Future[Option[Source[ByteString, NotUsed]]] = Future { + if (Files.exists(file(name))) { + Some(FileIO.fromPath(file(name)).mapMaterializedValue(_ => NotUsed)) + } else { + None + } + } + + override def upload(name: String): Future[Sink[ByteString, Future[Done]]] = Future { + val f = ensureParents(file(name)) + FileIO.toPath(f).mapMaterializedValue(_.map(_ => Done)) + } + + override def delete(name: String): Future[String] = exists(name).map { e => + if (e) { + Files.delete(file(name)) + } + name + } + + override def url(name: String): Future[Option[URL]] = exists(name) map { + case true => + Some(root.resolve(name).toUri.toURL) + case false => + None + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/storage/GcsBlobStorage.scala b/jvm/src/main/scala/xyz/driver/core/storage/GcsBlobStorage.scala new file mode 100644 index 0000000..95164c7 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/storage/GcsBlobStorage.scala @@ -0,0 +1,96 @@ +package xyz.driver.core.storage + +import java.io.{FileInputStream, InputStream} +import java.net.URL +import java.nio.file.Path + +import akka.Done +import akka.stream.scaladsl.Sink +import akka.util.ByteString +import com.google.api.gax.paging.Page +import com.google.auth.oauth2.ServiceAccountCredentials +import com.google.cloud.storage.Storage.BlobListOption +import com.google.cloud.storage.{Blob, BlobId, Bucket, Storage, StorageOptions} + +import scala.collection.JavaConverters._ +import scala.concurrent.duration.Duration +import scala.concurrent.{ExecutionContext, Future} + +class GcsBlobStorage(client: Storage, bucketId: String, chunkSize: Int = GcsBlobStorage.DefaultChunkSize)( + implicit ec: ExecutionContext) + extends BlobStorage with SignedBlobStorage { + + private val bucket: Bucket = client.get(bucketId) + require(bucket != null, s"Bucket $bucketId does not exist.") + + override def uploadContent(name: String, content: Array[Byte]): Future[String] = Future { + bucket.create(name, content).getBlobId.getName + } + + override def uploadFile(name: String, content: Path): Future[String] = Future { + bucket.create(name, new FileInputStream(content.toFile)).getBlobId.getName + } + + override def exists(name: String): Future[Boolean] = Future { + bucket.get(name) != null + } + + override def list(prefix: String): Future[Set[String]] = Future { + val page: Page[Blob] = bucket.list(BlobListOption.prefix(prefix)) + page + .iterateAll() + .asScala + .map(_.getName()) + .toSet + } + + override def content(name: String): Future[Option[Array[Byte]]] = Future { + Option(bucket.get(name)).map(blob => blob.getContent()) + } + + override def download(name: String) = Future { + Option(bucket.get(name)).map { blob => + ChannelStream.fromChannel(() => blob.reader(), chunkSize) + } + } + + override def upload(name: String): Future[Sink[ByteString, Future[Done]]] = Future { + val blob = bucket.create(name, Array.emptyByteArray) + ChannelStream.toChannel(() => blob.writer(), chunkSize) + } + + override def delete(name: String): Future[String] = Future { + client.delete(BlobId.of(bucketId, name)) + name + } + + override def signedDownloadUrl(name: String, duration: Duration): Future[Option[URL]] = Future { + Option(bucket.get(name)).map(blob => blob.signUrl(duration.length, duration.unit)) + } + + override def url(name: String): Future[Option[URL]] = Future { + val protocol: String = "https" + val resourcePath: String = s"storage.googleapis.com/${bucket.getName}/" + Option(bucket.get(name)).map { blob => + new URL(protocol, resourcePath, blob.getName) + } + } +} + +object GcsBlobStorage { + final val DefaultChunkSize = 8192 + + private def newClient(key: InputStream): Storage = + StorageOptions + .newBuilder() + .setCredentials(ServiceAccountCredentials.fromStream(key)) + .build() + .getService() + + def fromKeyfile(keyfile: Path, bucketId: String, chunkSize: Int = DefaultChunkSize)( + implicit ec: ExecutionContext): GcsBlobStorage = { + val client = newClient(new FileInputStream(keyfile.toFile)) + new GcsBlobStorage(client, bucketId, chunkSize) + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/storage/channelStreams.scala b/jvm/src/main/scala/xyz/driver/core/storage/channelStreams.scala new file mode 100644 index 0000000..fc652be --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/storage/channelStreams.scala @@ -0,0 +1,112 @@ +package xyz.driver.core.storage + +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, WritableByteChannel} + +import akka.stream._ +import akka.stream.scaladsl.{Sink, Source} +import akka.stream.stage._ +import akka.util.ByteString +import akka.{Done, NotUsed} + +import scala.concurrent.{Future, Promise} +import scala.util.control.NonFatal + +class ChannelSource(createChannel: () => ReadableByteChannel, chunkSize: Int) + extends GraphStage[SourceShape[ByteString]] { + + val out = Outlet[ByteString]("ChannelSource.out") + val shape = SourceShape(out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) { + val channel = createChannel() + + object Handler extends OutHandler { + override def onPull(): Unit = { + try { + val buffer = ByteBuffer.allocate(chunkSize) + if (channel.read(buffer) > 0) { + buffer.flip() + push(out, ByteString.fromByteBuffer(buffer)) + } else { + completeStage() + } + } catch { + case NonFatal(_) => + channel.close() + } + } + override def onDownstreamFinish(): Unit = { + channel.close() + } + } + + setHandler(out, Handler) + } + +} + +class ChannelSink(createChannel: () => WritableByteChannel, chunkSize: Int) + extends GraphStageWithMaterializedValue[SinkShape[ByteString], Future[Done]] { + + val in = Inlet[ByteString]("ChannelSink.in") + val shape = SinkShape(in) + + override def createLogicAndMaterializedValue(inheritedAttributes: Attributes): (GraphStageLogic, Future[Done]) = { + val promise = Promise[Done]() + val logic = new GraphStageLogic(shape) { + val channel = createChannel() + + object Handler extends InHandler { + override def onPush(): Unit = { + try { + val data = grab(in) + channel.write(data.asByteBuffer) + pull(in) + } catch { + case NonFatal(e) => + channel.close() + promise.failure(e) + } + } + + override def onUpstreamFinish(): Unit = { + channel.close() + completeStage() + promise.success(Done) + } + + override def onUpstreamFailure(ex: Throwable): Unit = { + channel.close() + promise.failure(ex) + } + } + + setHandler(in, Handler) + + override def preStart(): Unit = { + pull(in) + } + } + (logic, promise.future) + } + +} + +object ChannelStream { + + def fromChannel(channel: () => ReadableByteChannel, chunkSize: Int = 8192): Source[ByteString, NotUsed] = { + Source + .fromGraph(new ChannelSource(channel, chunkSize)) + .withAttributes(Attributes(ActorAttributes.IODispatcher)) + .async + } + + def toChannel(channel: () => WritableByteChannel, chunkSize: Int = 8192): Sink[ByteString, Future[Done]] = { + Sink + .fromGraph(new ChannelSink(channel, chunkSize)) + .withAttributes(Attributes(ActorAttributes.IODispatcher)) + .async + } + +} diff --git a/jvm/src/main/scala/xyz/driver/core/swagger.scala b/jvm/src/main/scala/xyz/driver/core/swagger.scala new file mode 100644 index 0000000..6567290 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/swagger.scala @@ -0,0 +1,161 @@ +package xyz.driver.core + +import java.lang.annotation.Annotation +import java.lang.reflect.Type +import java.util + +import com.fasterxml.jackson.databind.{BeanDescription, ObjectMapper} +import com.fasterxml.jackson.databind.`type`.ReferenceType +import io.swagger.converter._ +import io.swagger.jackson.AbstractModelConverter +import io.swagger.models.{Model, ModelImpl} +import io.swagger.models.properties._ +import io.swagger.util.{Json, PrimitiveType} +import spray.json._ + +object swagger { + + def configureCustomSwaggerModels( + customPropertiesExamples: Map[Class[_], Property], + customObjectsExamples: Map[Class[_], JsValue]) = { + ModelConverters + .getInstance() + .addConverter(new CustomSwaggerJsonConverter(Json.mapper(), customPropertiesExamples, customObjectsExamples)) + } + + object CustomSwaggerJsonConverter { + + def stringProperty(pattern: Option[String] = None, example: Option[String] = None): Property = { + make(new StringProperty()) { sp => + sp.required(true) + example.foreach(sp.example) + pattern.foreach(sp.pattern) + } + } + + def enumProperty[V](values: V*): Property = { + make(new StringProperty()) { sp => + for (v <- values) sp._enum(v.toString) + sp.setRequired(true) + } + } + + def numericProperty(example: Option[AnyRef] = None): Property = { + make(PrimitiveType.DECIMAL.createProperty()) { dp => + dp.setRequired(true) + example.foreach(dp.setExample) + } + } + + def booleanProperty(): Property = { + make(new BooleanProperty()) { bp => + bp.setRequired(true) + } + } + } + + @SuppressWarnings(Array("org.wartremover.warts.Null")) + class CustomSwaggerJsonConverter( + mapper: ObjectMapper, + customProperties: Map[Class[_], Property], + customObjects: Map[Class[_], JsValue]) + extends AbstractModelConverter(mapper) { + import CustomSwaggerJsonConverter._ + + override def resolveProperty( + `type`: Type, + context: ModelConverterContext, + annotations: Array[Annotation], + chain: util.Iterator[ModelConverter]): Property = { + val javaType = Json.mapper().constructType(`type`) + + Option(javaType.getRawClass) flatMap { cls => + customProperties.get(cls) + } orElse { + `type` match { + case rt: ReferenceType if isOption(javaType.getRawClass) && chain.hasNext => + val nextType = rt.getContentType + val nextResolved = Option(resolveProperty(nextType, context, annotations, chain)).getOrElse( + chain.next().resolveProperty(nextType, context, annotations, chain)) + nextResolved.setRequired(false) + Option(nextResolved) + case t if chain.hasNext => + val nextResolved = chain.next().resolveProperty(t, context, annotations, chain) + nextResolved.setRequired(true) + Option(nextResolved) + case _ => + Option.empty[Property] + } + } orNull + } + + @SuppressWarnings(Array("org.wartremover.warts.Null")) + override def resolve(`type`: Type, context: ModelConverterContext, chain: util.Iterator[ModelConverter]): Model = { + + val javaType = Json.mapper().constructType(`type`) + + (getEnumerationInstance(javaType.getRawClass) match { + case Some(_) => Option.empty[Model] // ignore scala enums + case None => + val customObjectModel = customObjects.get(javaType.getRawClass).map { objectExampleJson => + val properties = objectExampleJson.asJsObject.fields.mapValues(parseJsonValueToSwaggerProperty).flatMap { + case (key, value) => value.map(v => key -> v) + } + + val beanDesc = _mapper.getSerializationConfig.introspect[BeanDescription](javaType) + val name = _typeName(javaType, beanDesc) + + make(new ModelImpl()) { model => + model.name(name) + properties.foreach { case (field, property) => model.addProperty(field, property) } + } + } + + customObjectModel.orElse { + if (chain.hasNext) { + val next = chain.next() + Option(next.resolve(`type`, context, chain)) + } else { + Option.empty[Model] + } + } + }).orNull + } + + private def parseJsonValueToSwaggerProperty(jsValue: JsValue): Option[Property] = { + import scala.collection.JavaConverters._ + + jsValue match { + case JsArray(elements) => + elements.headOption.flatMap(parseJsonValueToSwaggerProperty).map { itemProperty => + new ArrayProperty(itemProperty) + } + case JsObject(subFields) => + val subProperties = subFields.mapValues(parseJsonValueToSwaggerProperty).flatMap { + case (key, value) => value.map(v => key -> v) + } + Option(new ObjectProperty(subProperties.asJava)) + case JsBoolean(_) => Option(booleanProperty()) + case JsNumber(value) => Option(numericProperty(example = Option(value))) + case JsString(value) => Option(stringProperty(example = Option(value))) + case _ => Option.empty[Property] + } + } + + private def getEnumerationInstance(cls: Class[_]): Option[Enumeration] = { + if (cls.getFields.map(_.getName).contains("MODULE$")) { + val javaUniverse = scala.reflect.runtime.universe + val m = javaUniverse.runtimeMirror(Thread.currentThread().getContextClassLoader) + val moduleMirror = m.reflectModule(m.staticModule(cls.getName)) + moduleMirror.instance match { + case enumInstance: Enumeration => Some(enumInstance) + case _ => None + } + } else { + None + } + } + + private def isOption(cls: Class[_]): Boolean = cls.equals(classOf[scala.Option[_]]) + } +} diff --git a/jvm/src/main/scala/xyz/driver/core/time.scala b/jvm/src/main/scala/xyz/driver/core/time.scala new file mode 100644 index 0000000..6dbd173 --- /dev/null +++ b/jvm/src/main/scala/xyz/driver/core/time.scala @@ -0,0 +1,175 @@ +package xyz.driver.core + +import java.text.SimpleDateFormat +import java.util._ +import java.util.concurrent.TimeUnit + +import xyz.driver.core.date.Month + +import scala.concurrent.duration._ +import scala.util.Try + +object time { + + // The most useful time units + val Second = 1000L + val Seconds = Second + val Minute = 60 * Seconds + val Minutes = Minute + val Hour = 60 * Minutes + val Hours = Hour + val Day = 24 * Hours + val Days = Day + val Week = 7 * Days + val Weeks = Week + + final case class Time(millis: Long) extends AnyVal { + + def isBefore(anotherTime: Time): Boolean = implicitly[Ordering[Time]].lt(this, anotherTime) + + def isAfter(anotherTime: Time): Boolean = implicitly[Ordering[Time]].gt(this, anotherTime) + + def advanceBy(duration: Duration): Time = Time(millis + duration.toMillis) + + def durationTo(anotherTime: Time): Duration = Duration.apply(anotherTime.millis - millis, TimeUnit.MILLISECONDS) + + def durationFrom(anotherTime: Time): Duration = Duration.apply(millis - anotherTime.millis, TimeUnit.MILLISECONDS) + + def toDate(timezone: TimeZone): date.Date = { + val cal = Calendar.getInstance(timezone) + cal.setTimeInMillis(millis) + date.Date(cal.get(Calendar.YEAR), date.Month(cal.get(Calendar.MONTH)), cal.get(Calendar.DAY_OF_MONTH)) + } + } + + /** + * Encapsulates a time and timezone without a specific date. + */ + final case class TimeOfDay(localTime: java.time.LocalTime, timeZone: TimeZone) { + + /** + * Is this time before another time on a specific day. Day light savings safe. These are zero-indexed + * for month/day. + */ + def isBefore(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).before(other.toCalendar(day, month, year)) + } + + /** + * Is this time after another time on a specific day. Day light savings safe. + */ + def isAfter(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).after(other.toCalendar(day, month, year)) + } + + def sameTimeAs(other: TimeOfDay, day: Int, month: Month, year: Int): Boolean = { + toCalendar(day, month, year).equals(other.toCalendar(day, month, year)) + } + + /** + * Enforces the same formatting as expected by [[java.sql.Time]] + * @return string formatted for `java.sql.Time` + */ + def timeString: String = { + localTime.format(TimeOfDay.getFormatter) + } + + /** + * @return a string parsable by [[java.util.TimeZone]] + */ + def timeZoneString: String = { + timeZone.getID + } + + /** + * @return this [[TimeOfDay]] as [[java.sql.Time]] object, [[java.sql.Time.valueOf]] will + * throw when the string is not valid, but this is protected by [[timeString]] method. + */ + def toTime: java.sql.Time = { + java.sql.Time.valueOf(timeString) + } + + private def toCalendar(day: Int, month: Int, year: Int): Calendar = { + val cal = Calendar.getInstance(timeZone) + cal.set(year, month, day, localTime.getHour, localTime.getMinute, localTime.getSecond) + cal.clear(Calendar.MILLISECOND) + cal + } + } + + object TimeOfDay { + def now(): TimeOfDay = { + TimeOfDay(java.time.LocalTime.now(), TimeZone.getDefault) + } + + /** + * Throws when [s] is not parsable by [[java.time.LocalTime.parse]], uses default [[java.util.TimeZone]] + */ + def parseTimeString(tz: TimeZone = TimeZone.getDefault)(s: String): TimeOfDay = { + TimeOfDay(java.time.LocalTime.parse(s), tz) + } + + def fromString(tz: TimeZone)(s: String): Option[TimeOfDay] = { + val op = Try(java.time.LocalTime.parse(s)).toOption + op.map(lt => TimeOfDay(lt, tz)) + } + + def fromStrings(zoneId: String)(s: String): Option[TimeOfDay] = { + val op = Try(TimeZone.getTimeZone(zoneId)).toOption + op.map(tz => TimeOfDay.parseTimeString(tz)(s)) + } + + /** + * Formatter that enforces `HH:mm:ss` which is expected by [[java.sql.Time]] + */ + def getFormatter: java.time.format.DateTimeFormatter = { + java.time.format.DateTimeFormatter.ofPattern("HH:mm:ss") + } + } + + object Time { + + implicit def timeOrdering: Ordering[Time] = Ordering.by(_.millis) + } + + final case class TimeRange(start: Time, end: Time) { + def duration: Duration = FiniteDuration(end.millis - start.millis, MILLISECONDS) + } + + def startOfMonth(time: Time) = { + Time(make(new GregorianCalendar()) { cal => + cal.setTime(new Date(time.millis)) + cal.set(Calendar.DAY_OF_MONTH, cal.getActualMinimum(Calendar.DAY_OF_MONTH)) + }.getTime.getTime) + } + + def textualDate(timezone: TimeZone)(time: Time): String = + make(new SimpleDateFormat("MMMM d, yyyy"))(_.setTimeZone(timezone)).format(new Date(time.millis)) + + def textualTime(timezone: TimeZone)(time: Time): String = + make(new SimpleDateFormat("MMM dd, yyyy hh:mm:ss a"))(_.setTimeZone(timezone)).format(new Date(time.millis)) + + object provider { + + /** + * Time providers are supplying code with current times + * and are extremely useful for testing to check how system is going + * to behave at specific moments in time. + * + * All the calls to receive current time must be made using time + * provider injected to the caller. + */ + trait TimeProvider { + def currentTime(): Time + } + + final class SystemTimeProvider extends TimeProvider { + def currentTime() = Time(System.currentTimeMillis()) + } + final val SystemTimeProvider = new SystemTimeProvider + + final class SpecificTimeProvider(time: Time) extends TimeProvider { + def currentTime() = time + } + } +} |