From 4d1197099ce4e721c18bf4cacbb2e1980e4210b5 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 12 Sep 2018 16:40:57 -0700 Subject: Move REST functionality to separate project --- build.sbt | 2 +- .../src/main/scala/xyz/driver/core/auth.scala | 43 ++ .../main/scala/xyz/driver/core/generators.scala | 143 ++++++ .../src/main/scala/xyz/driver/core/json.scala | 398 ++++++++++++++++ .../scala/xyz/driver/core/rest/DnsDiscovery.scala | 11 + .../scala/xyz/driver/core/rest/DriverRoute.scala | 122 +++++ .../core/rest/HttpRestServiceTransport.scala | 103 ++++ .../xyz/driver/core/rest/PatchDirectives.scala | 104 ++++ .../xyz/driver/core/rest/PooledHttpClient.scala | 67 +++ .../scala/xyz/driver/core/rest/ProxyRoute.scala | 26 + .../scala/xyz/driver/core/rest/RestService.scala | 86 ++++ .../xyz/driver/core/rest/ServiceDescriptor.scala | 16 + .../driver/core/rest/SingleRequestHttpClient.scala | 29 ++ .../main/scala/xyz/driver/core/rest/Swagger.scala | 144 ++++++ .../core/rest/auth/AlwaysAllowAuthorization.scala | 14 + .../xyz/driver/core/rest/auth/AuthProvider.scala | 75 +++ .../xyz/driver/core/rest/auth/Authorization.scala | 11 + .../core/rest/auth/AuthorizationResult.scala | 22 + .../core/rest/auth/CachedTokenAuthorization.scala | 55 +++ .../core/rest/auth/ChainedAuthorization.scala | 27 ++ .../core/rest/directives/AuthDirectives.scala | 19 + .../core/rest/directives/CorsDirectives.scala | 72 +++ .../driver/core/rest/directives/Directives.scala | 6 + .../driver/core/rest/directives/PathMatchers.scala | 85 ++++ .../core/rest/directives/Unmarshallers.scala | 40 ++ .../driver/core/rest/errors/serviceException.scala | 27 ++ .../xyz/driver/core/rest/headers/Traceparent.scala | 33 ++ .../main/scala/xyz/driver/core/rest/package.scala | 323 +++++++++++++ .../xyz/driver/core/rest/serviceDiscovery.scala | 24 + .../driver/core/rest/serviceRequestContext.scala | 87 ++++ .../src/test/scala/xyz/driver/core/AuthTest.scala | 165 +++++++ .../scala/xyz/driver/core/GeneratorsTest.scala | 264 +++++++++++ .../src/test/scala/xyz/driver/core/JsonTest.scala | 521 +++++++++++++++++++++ .../src/test/scala/xyz/driver/core/TestTypes.scala | 14 + .../scala/xyz/driver/core/rest/DriverAppTest.scala | 89 ++++ .../xyz/driver/core/rest/DriverRouteTest.scala | 121 +++++ .../xyz/driver/core/rest/PatchDirectivesTest.scala | 101 ++++ .../test/scala/xyz/driver/core/rest/RestTest.scala | 151 ++++++ src/main/scala/xyz/driver/core/auth.scala | 43 -- src/main/scala/xyz/driver/core/generators.scala | 143 ------ src/main/scala/xyz/driver/core/json.scala | 398 ---------------- .../scala/xyz/driver/core/rest/DnsDiscovery.scala | 11 - .../scala/xyz/driver/core/rest/DriverRoute.scala | 122 ----- .../core/rest/HttpRestServiceTransport.scala | 103 ---- .../xyz/driver/core/rest/PatchDirectives.scala | 104 ---- .../xyz/driver/core/rest/PooledHttpClient.scala | 67 --- .../scala/xyz/driver/core/rest/ProxyRoute.scala | 26 - .../scala/xyz/driver/core/rest/RestService.scala | 86 ---- .../xyz/driver/core/rest/ServiceDescriptor.scala | 16 - .../driver/core/rest/SingleRequestHttpClient.scala | 29 -- src/main/scala/xyz/driver/core/rest/Swagger.scala | 144 ------ .../core/rest/auth/AlwaysAllowAuthorization.scala | 14 - .../xyz/driver/core/rest/auth/AuthProvider.scala | 75 --- .../xyz/driver/core/rest/auth/Authorization.scala | 11 - .../core/rest/auth/AuthorizationResult.scala | 22 - .../core/rest/auth/CachedTokenAuthorization.scala | 55 --- .../core/rest/auth/ChainedAuthorization.scala | 27 -- .../core/rest/directives/AuthDirectives.scala | 19 - .../core/rest/directives/CorsDirectives.scala | 72 --- .../driver/core/rest/directives/Directives.scala | 6 - .../driver/core/rest/directives/PathMatchers.scala | 85 ---- .../core/rest/directives/Unmarshallers.scala | 40 -- .../driver/core/rest/errors/serviceException.scala | 27 -- .../xyz/driver/core/rest/headers/Traceparent.scala | 33 -- src/main/scala/xyz/driver/core/rest/package.scala | 323 ------------- .../xyz/driver/core/rest/serviceDiscovery.scala | 24 - .../driver/core/rest/serviceRequestContext.scala | 87 ---- src/test/scala/xyz/driver/core/AuthTest.scala | 165 ------- .../scala/xyz/driver/core/GeneratorsTest.scala | 264 ----------- src/test/scala/xyz/driver/core/JsonTest.scala | 521 --------------------- src/test/scala/xyz/driver/core/TestTypes.scala | 14 - .../scala/xyz/driver/core/rest/DriverAppTest.scala | 89 ---- .../xyz/driver/core/rest/DriverRouteTest.scala | 121 ----- .../xyz/driver/core/rest/PatchDirectivesTest.scala | 101 ---- src/test/scala/xyz/driver/core/rest/RestTest.scala | 151 ------ 75 files changed, 3639 insertions(+), 3639 deletions(-) create mode 100644 core-rest/src/main/scala/xyz/driver/core/auth.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/generators.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/json.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/package.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/AuthTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/JsonTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/TestTypes.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala delete mode 100644 src/main/scala/xyz/driver/core/auth.scala delete mode 100644 src/main/scala/xyz/driver/core/generators.scala delete mode 100644 src/main/scala/xyz/driver/core/json.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/DriverRoute.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/PatchDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ProxyRoute.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/RestService.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/Swagger.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/Authorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/Directives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/errors/serviceException.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/package.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala delete mode 100644 src/test/scala/xyz/driver/core/AuthTest.scala delete mode 100644 src/test/scala/xyz/driver/core/GeneratorsTest.scala delete mode 100644 src/test/scala/xyz/driver/core/JsonTest.scala delete mode 100644 src/test/scala/xyz/driver/core/TestTypes.scala delete mode 100644 src/test/scala/xyz/driver/core/rest/DriverAppTest.scala delete mode 100644 src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala delete mode 100644 src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala delete mode 100644 src/test/scala/xyz/driver/core/rest/RestTest.scala diff --git a/build.sbt b/build.sbt index 403a35f..4f32af7 100644 --- a/build.sbt +++ b/build.sbt @@ -54,7 +54,7 @@ lazy val `core-types` = project lazy val `core-rest` = project .enablePlugins(LibraryPlugin) - .dependsOn(`core-util`, `core-types`) + .dependsOn(`core-util`, `core-types`, `core-reporting`) .settings(testdeps) lazy val `core-reporting` = project diff --git a/core-rest/src/main/scala/xyz/driver/core/auth.scala b/core-rest/src/main/scala/xyz/driver/core/auth.scala new file mode 100644 index 0000000..896bd89 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/auth.scala @@ -0,0 +1,43 @@ +package xyz.driver.core + +import xyz.driver.core.domain.Email +import xyz.driver.core.time.Time +import scalaz.Equal + +object auth { + + trait Permission + + final case class Role(id: Id[Role], name: Name[Role]) { + + def oneOf(roles: Role*): Boolean = roles.contains(this) + + def oneOf(roles: Set[Role]): Boolean = roles.contains(this) + } + + object Role { + implicit def idEqual: Equal[Role] = Equal.equal[Role](_ == _) + } + + trait User { + def id: Id[User] + } + + final case class AuthToken(value: String) + + final case class AuthTokenUserInfo( + id: Id[User], + email: Email, + emailVerified: Boolean, + audience: String, + roles: Set[Role], + expirationTime: Time) + extends User + + final case class RefreshToken(value: String) + final case class PermissionsToken(value: String) + + final case class PasswordHash(value: String) + + final case class AuthCredentials(identifier: String, password: String) +} diff --git a/core-rest/src/main/scala/xyz/driver/core/generators.scala b/core-rest/src/main/scala/xyz/driver/core/generators.scala new file mode 100644 index 0000000..d00b6dd --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/generators.scala @@ -0,0 +1,143 @@ +package xyz.driver.core + +import enumeratum._ +import java.math.MathContext +import java.time.{Instant, LocalDate, ZoneOffset} +import java.util.UUID + +import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} +import xyz.driver.core.date.{Date, DayOfWeek} + +import scala.reflect.ClassTag +import scala.util.Random +import eu.timepit.refined.refineV +import eu.timepit.refined.api.Refined +import eu.timepit.refined.collection._ + +object generators { + + private val random = new Random + import random._ + private val secureRandom = new java.security.SecureRandom() + + private val DefaultMaxLength = 10 + private val StringLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ".toSet + private val NonAmbigiousCharacters = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789" + private val Numbers = "0123456789" + + private def nextTokenString(length: Int, chars: IndexedSeq[Char]): String = { + val builder = new StringBuilder + for (_ <- 0 until length) { + builder += chars(secureRandom.nextInt(chars.length)) + } + builder.result() + } + + /** Creates a random invitation token. + * + * This token is meant fo human input and avoids using ambiguous characters such as 'O' and '0'. It + * therefore contains less entropy and is not meant to be used as a cryptographic secret. */ + @deprecated( + "The term 'token' is too generic and security and readability conventions are not well defined. " + + "Services should implement their own version that suits their security requirements.", + "1.11.0" + ) + def nextToken(length: Int): String = nextTokenString(length, NonAmbigiousCharacters) + + @deprecated( + "The term 'token' is too generic and security and readability conventions are not well defined. " + + "Services should implement their own version that suits their security requirements.", + "1.11.0" + ) + def nextNumericToken(length: Int): String = nextTokenString(length, Numbers) + + def nextInt(maxValue: Int, minValue: Int = 0): Int = random.nextInt(maxValue - minValue) + minValue + + def nextBoolean(): Boolean = random.nextBoolean() + + def nextDouble(): Double = random.nextDouble() + + def nextId[T](): Id[T] = Id[T](nextUuid().toString) + + def nextId[T](maxLength: Int): Id[T] = Id[T](nextString(maxLength)) + + def nextNumericId[T](): Id[T] = Id[T](nextLong.abs.toString) + + def nextNumericId[T](maxValue: Int): Id[T] = Id[T](nextInt(maxValue).toString) + + def nextName[T](maxLength: Int = DefaultMaxLength): Name[T] = Name[T](nextString(maxLength)) + + def nextNonEmptyName[T](maxLength: Int = DefaultMaxLength): NonEmptyName[T] = + NonEmptyName[T](nextNonEmptyString(maxLength)) + + def nextUuid(): UUID = java.util.UUID.randomUUID + + def nextRevision[T](): Revision[T] = Revision[T](nextUuid().toString) + + def nextString(maxLength: Int = DefaultMaxLength): String = + (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString + + def nextNonEmptyString(maxLength: Int = DefaultMaxLength): String Refined NonEmpty = { + refineV[NonEmpty]( + (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString + ).right.get + } + + def nextOption[T](value: => T): Option[T] = if (nextBoolean()) Option(value) else None + + def nextPair[L, R](left: => L, right: => R): (L, R) = (left, right) + + def nextTriad[F, S, T](first: => F, second: => S, third: => T): (F, S, T) = (first, second, third) + + def nextInstant(): Instant = Instant.ofEpochMilli(math.abs(nextLong() % System.currentTimeMillis)) + + def nextTime(): Time = nextInstant() + + def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) + + def nextTimeRange(): TimeRange = { + val oneTime = nextTime() + val anotherTime = nextTime() + + TimeRange( + Time(scala.math.min(oneTime.millis, anotherTime.millis)), + Time(scala.math.max(oneTime.millis, anotherTime.millis))) + } + + def nextDate(): Date = nextTime().toDate(java.util.TimeZone.getTimeZone("UTC")) + + def nextLocalDate(): LocalDate = nextInstant().atZone(ZoneOffset.UTC).toLocalDate + + def nextDayOfWeek(): DayOfWeek = oneOf(DayOfWeek.All) + + def nextBigDecimal(multiplier: Double = 1000000.00, precision: Int = 2): BigDecimal = + BigDecimal(multiplier * nextDouble, new MathContext(precision)) + + def oneOf[T](items: T*): T = oneOf(items.toSet) + + def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) + + def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) + + def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = + Array.fill(nextInt(maxLength, minLength))(generator) + + def seqOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Seq[T] = + Seq.fill(nextInt(maxLength, minLength))(generator) + + def vectorOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Vector[T] = + Vector.fill(nextInt(maxLength, minLength))(generator) + + def listOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): List[T] = + List.fill(nextInt(maxLength, minLength))(generator) + + def setOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Set[T] = + seqOf(generator, maxLength, minLength).toSet + + def mapOf[K, V]( + keyGenerator: => K, + valueGenerator: => V, + maxLength: Int = DefaultMaxLength, + minLength: Int = 0): Map[K, V] = + seqOf(nextPair(keyGenerator, valueGenerator), maxLength, minLength).toMap +} diff --git a/core-rest/src/main/scala/xyz/driver/core/json.scala b/core-rest/src/main/scala/xyz/driver/core/json.scala new file mode 100644 index 0000000..edc2347 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/json.scala @@ -0,0 +1,398 @@ +package xyz.driver.core + +import java.net.InetAddress +import java.time.format.DateTimeFormatter +import java.time.{Instant, LocalDate} +import java.util.{TimeZone, UUID} + +import akka.http.scaladsl.unmarshalling.Unmarshaller +import com.neovisionaries.i18n.{CountryCode, CurrencyCode} +import enumeratum._ +import eu.timepit.refined.api.{Refined, Validate} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import spray.json._ +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.date.{Date, DayOfWeek, Month} +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.rest.directives.{PathMatchers, Unmarshallers} +import xyz.driver.core.rest.errors._ +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.reflect.runtime.universe._ +import scala.reflect.{ClassTag, classTag} +import scala.util.Try +import scala.util.control.NonFatal + +object json extends PathMatchers with Unmarshallers { + import DefaultJsonProtocol._ + + implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { + def write(id: Id[T]) = JsString(id.value) + + def read(value: JsValue): Id[T] = value match { + case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase) + case JsString(id) => Id[T](id) + case _ => throw DeserializationException("Id expects string") + } + } + + implicit def taggedFormat[F, T](implicit underlying: JsonFormat[F], convert: F => F @@ T = null): JsonFormat[F @@ T] = + new JsonFormat[F @@ T] { + import tagging._ + + private val transformReadValue = Option(convert).getOrElse((_: F).tagged[T]) + + override def write(obj: F @@ T): JsValue = underlying.write(obj) + + override def read(json: JsValue): F @@ T = transformReadValue(underlying.read(json)) + } + + implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { + def write(name: Name[T]) = JsString(name.value) + + def read(value: JsValue): Name[T] = value match { + case JsString(name) => Name[T](name) + case _ => throw DeserializationException("Name expects string") + } + } + + implicit val timeFormat: RootJsonFormat[Time] = new RootJsonFormat[Time] { + def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) + + def read(value: JsValue): Time = Time(instantFormat.read(value)) + } + + implicit val instantFormat: JsonFormat[Instant] = new JsonFormat[Instant] { + def write(instant: Instant): JsValue = JsString(instant.toString) + + def read(value: JsValue): Instant = value match { + case JsObject(fields) => + fields + .get("timestamp") + .flatMap { + case JsNumber(millis) => Some(Instant.ofEpochMilli(millis.longValue())) + case _ => None + } + .getOrElse(deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}")) + case JsNumber(millis) => Instant.ofEpochMilli(millis.longValue()) + case JsString(str) => + try Instant.parse(str) + catch { case NonFatal(_) => deserializationError(s"Instant expects ISO timestamp but got $str") } + case _ => deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}") + } + } + + implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { + private val formatter = TimeOfDay.getFormatter + def read(json: JsValue): java.time.LocalTime = json match { + case JsString(chars) => + java.time.LocalTime.parse(chars) + case _ => deserializationError(s"Expected time string got ${json.toString}") + } + + def write(obj: java.time.LocalTime): JsValue = { + JsString(obj.format(formatter)) + } + } + + implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { + override def write(obj: TimeZone): JsValue = { + JsString(obj.getID()) + } + + override def read(json: JsValue): TimeZone = json match { + case JsString(chars) => + java.util.TimeZone.getTimeZone(chars) + case _ => deserializationError(s"Expected time zone string got ${json.toString}") + } + } + + implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) + + implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) + + implicit val dateFormat = new RootJsonFormat[Date] { + def write(date: Date) = JsString(date.toString) + def read(value: JsValue): Date = value match { + case JsString(dateString) => + Date + .fromString(dateString) + .getOrElse( + throw DeserializationException(s"Misformated ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.")) + case _ => throw DeserializationException(s"Date expects a string, but got $value.") + } + } + + implicit val localDateFormat = new RootJsonFormat[LocalDate] { + val format = DateTimeFormatter.ISO_LOCAL_DATE + + def write(date: LocalDate): JsValue = JsString(date.format(format)) + def read(value: JsValue): LocalDate = value match { + case JsString(dateString) => + try LocalDate.parse(dateString, format) + catch { + case NonFatal(_) => + throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.") + } + case _ => + throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got ${value.compactPrint}.") + } + } + + implicit val monthFormat = new RootJsonFormat[Month] { + def write(month: Month) = JsNumber(month) + def read(value: JsValue): Month = value match { + case JsNumber(month) if 0 <= month && month <= 11 => Month(month.toInt) + case _ => throw DeserializationException("Expected a number from 0 to 11") + } + } + + implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { + def write(revision: Revision[T]) = JsString(revision.id.toString) + + def read(value: JsValue): Revision[T] = value match { + case JsString(revision) => Revision[T](revision) + case _ => throw DeserializationException("Revision expects uuid string") + } + } + + implicit val base64Format = new RootJsonFormat[Base64] { + def write(base64Value: Base64) = JsString(base64Value.value) + + def read(value: JsValue): Base64 = value match { + case JsString(base64Value) => Base64(base64Value) + case _ => throw DeserializationException("Base64 format expects string") + } + } + + implicit val emailFormat = new RootJsonFormat[Email] { + def write(email: Email) = JsString(email.username + "@" + email.domain) + def read(json: JsValue): Email = json match { + + case JsString(value) => + Email.parse(value).getOrElse { + deserializationError("Expected '@' symbol in email string as Email, but got " + json.toString) + } + + case _ => + deserializationError("Expected string as Email, but got " + json.toString) + } + } + + implicit object phoneNumberFormat extends RootJsonFormat[PhoneNumber] { + + private val basicFormat = jsonFormat3(PhoneNumber.apply) + + def write(obj: PhoneNumber): JsValue = basicFormat.write(obj) + + def read(json: JsValue): PhoneNumber = { + val maybePhone = json match { + case JsString(number) => PhoneNumber.parse(number) + case obj: JsObject => PhoneNumber.parse(basicFormat.read(obj).toString) + case _ => None + } + maybePhone.getOrElse(deserializationError("Invalid phone number")) + } + } + + implicit val authCredentialsFormat = new RootJsonFormat[AuthCredentials] { + override def read(json: JsValue): AuthCredentials = { + json match { + case JsObject(fields) => + val emailField = fields.get("email") + val identifierField = fields.get("identifier") + val passwordField = fields.get("password") + + (emailField, identifierField, passwordField) match { + case (_, _, None) => + deserializationError("password field must be set") + case (Some(JsString(em)), _, Some(JsString(pw))) => + val email = Email.parse(em).getOrElse(throw deserializationError(s"failed to parse email $em")) + AuthCredentials(email.toString, pw) + case (_, Some(JsString(id)), Some(JsString(pw))) => AuthCredentials(id.toString, pw.toString) + case (None, None, _) => deserializationError("identifier must be provided") + case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") + } + case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") + } + } + + override def write(obj: AuthCredentials): JsValue = JsObject( + "identifier" -> JsString(obj.identifier), + "password" -> JsString(obj.password) + ) + } + + implicit object inetAddressFormat extends JsonFormat[InetAddress] { + override def read(json: JsValue): InetAddress = json match { + case JsString(ipString) => + Try(InetAddress.getByName(ipString)) + .getOrElse(deserializationError(s"Invalid IP Address: $ipString")) + case _ => deserializationError(s"Expected string for IP Address, got $json") + } + + override def write(obj: InetAddress): JsValue = + JsString(obj.getHostAddress) + } + + implicit val countryCodeFormat: JsonFormat[CountryCode] = javaEnumFormat[CountryCode] + + implicit val currencyCodeFormat: JsonFormat[CurrencyCode] = javaEnumFormat[CurrencyCode] + + object enumeratum { + + def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + + trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => + + implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) + + implicit val unmarshaller: Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + } + + class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { + override def read(json: JsValue): T = json match { + case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + + override def write(obj: T): JsValue = JsString(obj.entryName) + } + + private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = + deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") + } + + class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { + private val map = mapping.toMap + + override def write(value: T): JsValue = { + map.find(_._2 == value).map(_._1) match { + case Some(name) => JsString(name) + case _ => serializationError(s"Value $value is not found in the mapping $map") + } + } + + override def read(json: JsValue): T = json match { + case JsString(name) => + map.getOrElse(name, throw DeserializationException(s"Value $name is not found in the mapping $map")) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + } + + def javaEnumFormat[T <: java.lang.Enum[_]: ClassTag]: JsonFormat[T] = { + val values = classTag[T].runtimeClass.asInstanceOf[Class[T]].getEnumConstants + new EnumJsonFormat[T](values.map(v => v.name() -> v): _*) + } + + class ValueClassFormat[T: TypeTag](writeValue: T => BigDecimal, create: BigDecimal => T) extends JsonFormat[T] { + def write(valueClass: T) = JsNumber(writeValue(valueClass)) + def read(json: JsValue): T = json match { + case JsNumber(value) => create(value) + case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json.toString) + } + } + + class GadtJsonFormat[T: TypeTag]( + typeField: String, + typeValue: PartialFunction[T, String], + jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) + extends RootJsonFormat[T] { + + def write(value: T): JsValue = { + + val valueType = typeValue.applyOrElse(value, { v: T => + deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v.toString) + }) + + val valueFormat = + jsonFormat.applyOrElse(valueType, { f: String => + deserializationError(s"No Json format for this type of $valueType") + }) + + valueFormat.asInstanceOf[JsonFormat[T]].write(value) match { + case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType))) + case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object") + } + } + + def read(json: JsValue): T = json match { + case JsObject(fields) => + val valueJson = JsObject(fields.filterNot(_._1 == typeField)) + fields(typeField) match { + case JsString(valueType) => + val valueFormat = jsonFormat.applyOrElse(valueType, { t: String => + deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") + }) + valueFormat.read(valueJson) + case _ => + deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") + } + case _ => + deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json.toString) + } + } + + object GadtJsonFormat { + + def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])( + jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = { + + new GadtJsonFormat[T](typeField, typeValue, jsonFormat) + } + } + + /** + * Provides the JsonFormat for the Refined types provided by the Refined library. + * + * @see https://github.com/fthomas/refined + */ + implicit def refinedJsonFormat[T, Predicate]( + implicit valueFormat: JsonFormat[T], + validate: Validate[T, Predicate]): JsonFormat[Refined[T, Predicate]] = + new JsonFormat[Refined[T, Predicate]] { + def write(x: T Refined Predicate): JsValue = valueFormat.write(x.value) + def read(value: JsValue): T Refined Predicate = { + refineV[Predicate](valueFormat.read(value))(validate) match { + case Right(refinedValue) => refinedValue + case Left(refinementError) => deserializationError(refinementError) + } + } + } + + implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = + new RootJsonFormat[NonEmptyName[T]] { + def write(name: NonEmptyName[T]) = JsString(name.value.value) + + def read(value: JsValue): NonEmptyName[T] = + NonEmptyName[T](nonEmptyStringFormat.read(value)) + } + + implicit val serviceExceptionFormat: RootJsonFormat[ServiceException] = + GadtJsonFormat.create[ServiceException]("type") { + case _: InvalidInputException => "InvalidInputException" + case _: InvalidActionException => "InvalidActionException" + case _: UnauthorizedException => "UnauthorizedException" + case _: ResourceNotFoundException => "ResourceNotFoundException" + case _: ExternalServiceException => "ExternalServiceException" + case _: ExternalServiceTimeoutException => "ExternalServiceTimeoutException" + case _: DatabaseException => "DatabaseException" + } { + case "InvalidInputException" => jsonFormat(InvalidInputException, "message") + case "InvalidActionException" => jsonFormat(InvalidActionException, "message") + case "UnauthorizedException" => jsonFormat(UnauthorizedException, "message") + case "ResourceNotFoundException" => jsonFormat(ResourceNotFoundException, "message") + case "ExternalServiceException" => + jsonFormat(ExternalServiceException, "serviceName", "serviceMessage", "serviceException") + case "ExternalServiceTimeoutException" => jsonFormat(ExternalServiceTimeoutException, "message") + case "DatabaseException" => jsonFormat(DatabaseException, "message") + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala b/core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala new file mode 100644 index 0000000..87946e4 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala @@ -0,0 +1,11 @@ +package xyz.driver.core +package rest + +class DnsDiscovery(transport: HttpRestServiceTransport, overrides: Map[String, String]) { + + def discover[A](implicit descriptor: ServiceDescriptor[A]): A = { + val url = overrides.getOrElse(descriptor.name, s"https://${descriptor.name}") + descriptor.connect(transport, url) + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala new file mode 100644 index 0000000..911e306 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -0,0 +1,122 @@ +package xyz.driver.core.rest + +import java.sql.SQLException + +import akka.http.scaladsl.model.headers.CacheDirectives.`no-cache` +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{StatusCodes, _} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.rest +import xyz.driver.core.rest.errors._ + +import scala.compat.Platform.ConcurrentModificationException + +trait DriverRoute { + def log: Logger + + def route: Route + + def routeWithDefaults: Route = { + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route + } + } + + protected def defaultResponseHeaders: Directive0 = { + extractRequest flatMap { request => + // Needs to happen before any request processing, so all the log messages + // associated with processing of this request are having this `trackingId` + val trackingId = rest.extractTrackingId(request) + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + MDC.put("trackingId", trackingId) + + respondWithHeaders(tracingHeader +: DriverRoute.DefaultHeaders: _*) + } + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + case serviceException: ServiceException => + serviceExceptionHandler(serviceException) + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( + ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) + + case t: Exception => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) + } + + protected def serviceExceptionHandler(serviceException: ServiceException): Route = { + val statusCode = serviceException match { + case e: InvalidInputException => + log.info("Invalid client input error", e) + StatusCodes.BadRequest + case e: InvalidActionException => + log.info("Invalid client action error", e) + StatusCodes.Forbidden + case e: UnauthorizedException => + log.info("Unauthorized user error", e) + StatusCodes.Unauthorized + case e: ResourceNotFoundException => + log.info("Resource not found error", e) + StatusCodes.NotFound + case e: ExternalServiceException => + log.error("Error while calling another service", e) + StatusCodes.InternalServerError + case e: ExternalServiceTimeoutException => + log.error("Service timeout error", e) + StatusCodes.GatewayTimeout + case e: DatabaseException => + log.error("Database error", e) + StatusCodes.InternalServerError + } + + { (ctx: RequestContext) => + import xyz.driver.core.json.serviceExceptionFormat + val entity = + HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) + errorResponse(statusCode, entity, serviceException)(ctx) + } + } + + protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = + errorResponse(statusCode, HttpEntity(message), exception) + + protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { + complete(HttpResponse(statusCode, entity = entity)) + } + +} + +object DriverRoute { + val DefaultHeaders: List[HttpHeader] = List( + // This header will eliminate the risk of envoy trying to reuse a connection + // that already timed out on the server side by completely rejecting keep-alive + Connection("close"), + // These 2 headers are the simplest way to prevent IE from caching GET requests + RawHeader("Pragma", "no-cache"), + `Cache-Control`(List(`no-cache`(Nil))) + ) +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala new file mode 100644 index 0000000..e31635b --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -0,0 +1,103 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.TcpIdleTimeoutException +import org.slf4j.MDC +import xyz.driver.core.Name +import xyz.driver.core.reporting.Reporter +import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +class HttpRestServiceTransport( + applicationName: Name[App], + applicationVersion: String, + val actorSystem: ActorSystem, + val executionContext: ExecutionContext, + reporter: Reporter) + extends ServiceTransport { + + protected implicit val execution: ExecutionContext = executionContext + + protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + val tags = Map( + // open tracing semantic tags + "span.kind" -> "client", + "service" -> applicationName.value, + "http.url" -> requestStub.uri.toString, + "http.method" -> requestStub.method.value, + "peer.hostname" -> requestStub.uri.authority.host.toString, + // google's tracing console provides extra search features if we define these tags + "/http/path" -> requestStub.uri.path.toString, + "/http/method" -> requestStub.method.value.toString, + "/http/url" -> requestStub.uri.toString + ) + reporter.traceAsync(s"http_call_rpc", tags) { implicit span => + val requestTime = System.currentTimeMillis() + + val request = requestStub + .withHeaders(context.contextHeaders.toSeq.map { + case (ContextHeaders.TrackingIdHeader, _) => + RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) + case (ContextHeaders.StacktraceHeader, _) => + RawHeader( + ContextHeaders.StacktraceHeader, + Option(MDC.get("stack")) + .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) + .getOrElse("")) + case (header, headerValue) => RawHeader(header, headerValue) + }: _*) + + reporter.debug(s"Sending request to ${request.method} ${request.uri}") + + val response = httpClient.makeRequest(request) + + response.onComplete { + case Success(r) => + val responseLatency = System.currentTimeMillis() - requestTime + reporter.debug( + s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") + + case Failure(t: Throwable) => + val responseLatency = System.currentTimeMillis() - requestTime + reporter.warn( + s"Failed to receive response from ${request.method.value} ${request.uri} in $responseLatency ms", + t) + }(executionContext) + + response.recoverWith { + case _: TcpIdleTimeoutException => + val serviceCalled = s"${requestStub.method.value} ${requestStub.uri}" + Future.failed(ExternalServiceTimeoutException(serviceCalled)) + case t: Throwable => Future.failed(t) + } + }(context.spanContext) + } + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { + + sendRequestGetResponse(context)(requestStub) flatMap { response => + if (response.status == StatusCodes.NotFound) { + Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) + } else if (response.status.isFailure()) { + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Unmarshal(response.entity).to[String] flatMap { errorString => + import spray.json._ + import xyz.driver.core.json._ + val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption + Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) + } + } else { + Future.successful(Unmarshal(response.entity)) + } + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..f33bf9d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => + val oldValue = oldObj.fields.getOrElse(key, JsNull) + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala new file mode 100644 index 0000000..2854257 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala @@ -0,0 +1,67 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.scaladsl.{Keep, Sink, Source} +import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} +import xyz.driver.core.Name + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.concurrent.duration._ +import scala.util.{Failure, Success} + +class PooledHttpClient( + baseUri: Uri, + applicationName: Name[App], + applicationVersion: String, + requestRateLimit: Int = 64, + requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) + extends HttpClient { + + private val host = baseUri.authority.host.toString() + private val port = baseUri.effectivePort + private val scheme = baseUri.scheme + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + private val pool = if (scheme.equalsIgnoreCase("https")) { + Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } else { + Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } + + private val queue = Source + .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) + .via(pool) + .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) + .toMat(Sink.foreach({ + case ((Success(resp), p)) => p.success(resp) + case ((Failure(e), p)) => p.failure(e) + }))(Keep.left) + .run + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + val responsePromise = Promise[HttpResponse]() + + queue.offer(request -> responsePromise).flatMap { + case QueueOfferResult.Enqueued => + responsePromise.future + case QueueOfferResult.Dropped => + Future.failed(new Exception(s"Request queue to the host $host is overflown")) + case QueueOfferResult.Failure(ex) => + Future.failed(ex) + case QueueOfferResult.QueueClosed => + Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala new file mode 100644 index 0000000..c0e9f99 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala @@ -0,0 +1,26 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} +import com.typesafe.config.Config +import xyz.driver.core.Name + +import scala.concurrent.ExecutionContext + +trait ProxyRoute extends DriverRoute { + implicit val executionContext: ExecutionContext + val config: Config + val httpClient: HttpClient + + protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => + val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") + val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") + + val originalUri = ctx.request.uri + val originalRequest = ctx.request + + val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) + val newRequest = originalRequest.withUri(newUri) + + httpClient.makeRequest(newRequest).map(RouteResult.Complete) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala b/core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala new file mode 100644 index 0000000..09d98b8 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala @@ -0,0 +1,86 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} +import akka.stream.Materializer + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +trait RestService extends Service { + + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json._ + + protected implicit val exec: ExecutionContext + protected implicit val materializer: Materializer + + implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { + def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = + if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] + } + + protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = + OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) + + protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) + + protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) + + protected def jsonEntity(json: JsValue): RequestEntity = + HttpEntity(ContentTypes.`application/json`, json.compactPrint) + + protected def mergePatchJsonEntity(json: JsValue): RequestEntity = + HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) + + protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) + + protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) + + protected def postJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) + + protected def putJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) + + protected def patchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) + + protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) + + protected def endpointUri(baseUri: Uri, path: String): Uri = + baseUri.withPath(Uri.Path(path)) + + protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = + baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) + + protected def responseToListResponse[T: JsonFormat](pagination: Option[Pagination])( + response: HttpResponse): Future[ListResponse[T]] = { + import DefaultJsonProtocol._ + val resourceCount = response.headers + .find(_.name() equalsIgnoreCase ContextHeaders.ResourceCount) + .map(_.value().toInt) + .getOrElse(0) + val meta = ListResponse.Meta(resourceCount, pagination.getOrElse(Pagination(resourceCount max 1, 1))) + Unmarshal(response.entity).to[List[T]].map(ListResponse(_, meta)) + } + + protected def responseToListResponse[T: JsonFormat](pagination: Pagination)( + response: HttpResponse): Future[ListResponse[T]] = responseToListResponse(Some(pagination))(response) +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala b/core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala new file mode 100644 index 0000000..646fae8 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala @@ -0,0 +1,16 @@ +package xyz.driver.core +package rest +import scala.annotation.implicitNotFound + +@implicitNotFound( + "Don't know how to communicate with service ${S}. Make sure an implicit ServiceDescriptor is" + + "available. A good place to put one is in the service's companion object.") +trait ServiceDescriptor[S] { + + /** The service's name. Must be unique among all services. */ + def name: String + + /** Get an instance of the service. */ + def connect(transport: HttpRestServiceTransport, url: String): S + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala new file mode 100644 index 0000000..964a5a2 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala @@ -0,0 +1,29 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.ActorMaterializer +import xyz.driver.core.Name + +import scala.concurrent.Future + +class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) + extends HttpClient { + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private val client = Http()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + client.singleRequest(request, settings = connectionPoolSettings) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala b/core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala new file mode 100644 index 0000000..5ceac54 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -0,0 +1,144 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} +import akka.http.scaladsl.server.Route +import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile +import akka.stream.ActorAttributes +import akka.stream.scaladsl.{Framing, StreamConverters} +import akka.util.ByteString +import com.github.swagger.akka.SwaggerHttpService +import com.github.swagger.akka.model._ +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import io.swagger.models.Scheme +import io.swagger.models.auth.{ApiKeyAuthDefinition, In} +import io.swagger.util.Json + +import scala.util.control.NonFatal + +class Swagger( + override val host: String, + accessSchemes: List[String], + version: String, + override val apiClasses: Set[Class[_]], + val config: Config, + val logger: Logger) + extends SwaggerHttpService { + + override val schemes = accessSchemes.map { s => + Scheme.forValue(s) + } + + // Note that the reason for overriding this is a subtle chain of causality: + // + // 1. Some of our endpoints require a single trailing slash and will not + // function if it is omitted + // 2. Swagger omits trailing slashes in its generated api doc + // 3. To work around that, a space is added after the trailing slash in the + // swagger Path annotations + // 4. This space is removed manually in the code below + // + // TODO: Ideally we'd like to drop this custom override and fix the issue in + // 1, by dropping the slash requirement and accepting api endpoints with and + // without trailing slashes. This will require inspecting and potentially + // fixing all service endpoints. + override def generateSwaggerJson: String = { + import io.swagger.models.{Swagger => JSwagger} + + import scala.collection.JavaConverters._ + try { + val swagger: JSwagger = reader.read(apiClasses.asJava) + + val paths = if (swagger.getPaths == null) { + Map.empty + } else { + swagger.getPaths.asScala + } + + // Removing trailing spaces + val fixedPaths = paths.map { + case (key, path) => + key.trim -> path + } + + swagger.setPaths(fixedPaths.asJava) + + Json.pretty().writeValueAsString(swagger) + } catch { + case NonFatal(t) => + logger.error("Issue with creating swagger.json", t) + throw t + } + } + + override val securitySchemeDefinitions = Map( + "token" -> { + val definition = new ApiKeyAuthDefinition("Authorization", In.HEADER) + definition.setDescription("Authentication token") + definition + } + ) + + override val basePath: String = config.getString("swagger.basePath") + override val apiDocsPath: String = config.getString("swagger.docsPath") + + override val info = Info( + config.getString("swagger.apiInfo.description"), + version, + config.getString("swagger.apiInfo.title"), + config.getString("swagger.apiInfo.termsOfServiceUrl"), + contact = Some( + Contact( + config.getString("swagger.apiInfo.contact.name"), + config.getString("swagger.apiInfo.contact.url"), + config.getString("swagger.apiInfo.contact.email") + )), + license = Some( + License( + config.getString("swagger.apiInfo.license"), + config.getString("swagger.apiInfo.licenseUrl") + )), + vendorExtensions = Map.empty[String, AnyRef] + ) + + /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ + private def getTemplatedResource( + resourceName: String, + contentType: ContentType, + substitution: (String, String)): Route = get { + Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { + case Some(ResourceFile(url, length @ _, _)) => + extractSettings { settings => + val stream = StreamConverters + .fromInputStream(() => url.openStream()) + .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) + .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) + .map { line => + line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) + } + .map(line => ByteString(line + "\n")) + complete( + HttpEntity(contentType, stream) + ) + } + case None => reject + } + } + + def swaggerUI: Route = + pathEndOrSingleSlash { + getTemplatedResource( + "swagger-ui/index.html", + ContentTypes.`text/html(UTF-8)`, + "title" -> config.getString("swagger.apiInfo.title")) + } ~ getFromResourceDirectory("swagger-ui") + + def swaggerUINew: Route = + pathEndOrSingleSlash { + getTemplatedResource( + "swagger-ui-dist/index.html", + ContentTypes.`text/html(UTF-8)`, + "title" -> config.getString("swagger.apiInfo.title")) + } ~ getFromResourceDirectory("swagger-ui-dist") + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala new file mode 100644 index 0000000..5007774 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala @@ -0,0 +1,14 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val permissionsMap = permissions.map(_ -> true).toMap + Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala new file mode 100644 index 0000000..e1a94e1 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala @@ -0,0 +1,75 @@ +package xyz.driver.core.rest.auth + +import akka.http.scaladsl.server.directives.Credentials +import com.typesafe.scalalogging.Logger +import scalaz.OptionT +import xyz.driver.core.auth.{AuthToken, Permission, User} +import xyz.driver.core.rest.errors.{ExternalServiceException, UnauthorizedException} +import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ContextHeaders, ServiceRequestContext, serviceContext} + +import scala.concurrent.{ExecutionContext, Future} + +abstract class AuthProvider[U <: User]( + val authorization: Authorization[U], + log: Logger, + val realm: String +)(implicit execution: ExecutionContext) { + + import akka.http.scaladsl.server._ + import Directives.{authorize => akkaAuthorize, _} + + def this(authorization: Authorization[U], log: Logger)(implicit executionContext: ExecutionContext) = + this(authorization, log, "driver.xyz") + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param ctx set of request values which can be relevant to authenticate user + * @return authenticated user + */ + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] + + protected def authenticator(context: ServiceRequestContext): AsyncAuthenticator[U] = { + case Credentials.Missing => + log.info(s"Request (${context.trackingId}) missing authentication credentials") + Future.successful(None) + case Credentials.Provided(authToken) => + authenticatedUser(context.withAuthToken(AuthToken(authToken))).run.recover({ + case ExternalServiceException(_, _, Some(UnauthorizedException(_))) => None + }) + } + + /** + * Verifies that a user agent is properly authenticated, and (optionally) authorized with the specified permissions + */ + def authorize( + context: ServiceRequestContext, + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + authenticateOAuth2Async[U](realm, authenticator(context)) flatMap { authenticatedUser => + val authCtx = context.withAuthenticatedUser(context.authToken.get, authenticatedUser) + onSuccess(authorization.userHasPermissions(authenticatedUser, permissions)(authCtx)) flatMap { + case AuthorizationResult(authorized, token) => + val allAuthorized = permissions.forall(authorized.getOrElse(_, false)) + akkaAuthorize(allAuthorized) tflatMap { _ => + val cachedPermissionsCtx = token.fold(authCtx)(authCtx.withPermissionsToken) + provide(cachedPermissionsCtx) + } + } + } + } + + /** + * Verifies if request is authenticated and authorized to have `permissions` + */ + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + serviceContext flatMap (authorize(_, permissions: _*)) + } +} + +object AuthProvider { + val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader: String = "set-authorization" + val SetPermissionsTokenHeader: String = "set-permissions" +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala new file mode 100644 index 0000000..1a5e9be --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala @@ -0,0 +1,11 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +trait Authorization[U <: User] { + def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala new file mode 100644 index 0000000..efe28c9 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala @@ -0,0 +1,22 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, PermissionsToken} + +import scalaz.Scalaz.mapMonoid +import scalaz.Semigroup +import scalaz.syntax.semigroup._ + +final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) +object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) + + implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { + private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) + private implicit val permissionsTokenSemigroup = + Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) + + override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { + AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala new file mode 100644 index 0000000..66de4ef --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala @@ -0,0 +1,55 @@ +package xyz.driver.core.rest.auth + +import java.nio.file.{Files, Path} +import java.security.{KeyFactory, PublicKey} +import java.security.spec.X509EncodedKeySpec + +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future +import scalaz.syntax.std.boolean._ + +class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + import spray.json._ + + def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = + tokenObject.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) + + authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } +} + +object CachedTokenAuthorization { + def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { + lazy val publicKey: PublicKey = { + val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim + val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) + val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) + KeyFactory.getInstance("RSA").generatePublic(spec) + } + new CachedTokenAuthorization[U](publicKey, issuer) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala new file mode 100644 index 0000000..131e7fc --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.Scalaz.{futureInstance, listInstance} +import scalaz.syntax.semigroup._ +import scalaz.syntax.traverse._ + +class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = + permissions.forall(permissionsMap.getOrElse(_, false)) + + authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { + (authResult, authorization) => + if (allAuthorized(authResult.authorized)) Future.successful(authResult) + else { + authorization.userHasPermissions(user, permissions).map(authResult |+| _) + } + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala new file mode 100644 index 0000000..ff3424d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala @@ -0,0 +1,19 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.auth.AuthProvider + +/** Authentication and authorization directives. */ +trait AuthDirectives extends AkkaDirectives { + + /** Authenticate a user based on service request headers and check if they have all given permissions. */ + def authenticateAndAuthorize[U <: User]( + authProvider: AuthProvider[U], + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + authProvider.authorize(permissions: _*) + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala new file mode 100644 index 0000000..5a6bbfd --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala @@ -0,0 +1,72 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.model.HttpMethods._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpResponse, StatusCodes} +import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} + +/** Directives to handle Cross-Origin Resource Sharing (CORS). */ +trait CorsDirectives extends AkkaDirectives { + + /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request + * origin. + * + * In a microservice environment, it can be difficult to know in advance the exact origin + * from which requests may be issued [1]. For example, the request may come from a web page served from + * any of the services, on any namespace or from other documentation sites. In general, only a set + * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that + * implement CORS require exact specification of allowed origins, including full host name and scheme, + * in order to send credentials and headers with requests to other origins. + * + * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by + * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. + * + * Note that the simplicity of this implementation comes with two notable drawbacks: + * + * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. + * + * - Allowed methods and headers can not be customized on a per-request basis. All standard + * HTTP methods are allowed, and allowed headers are specified for all inner routes. + * + * This handler is not suited for cases where more fine-grained control of responses is required. + * + * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through + * a common gateway. + * + * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed + * origins. + * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. + * @param inner Route into which CORS headers will be injected. + */ + def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { + optionalHeaderValueByType[Origin](()) { maybeOrigin => + val allowedOrigins: HttpOriginRange = maybeOrigin match { + // Note that this is not a security issue: the client will never send credentials if the allowed + // origin is set to *. This case allows us to deal with clients that do not send an origin header. + case None => HttpOriginRange.* + case Some(requestOrigin) => + val allowedOrigin = requestOrigin.origins.find(origin => + allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) + allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) + } + + respondWithHeaders( + `Access-Control-Allow-Origin`.forRange(allowedOrigins), + `Access-Control-Allow-Credentials`(true), + `Access-Control-Allow-Headers`(allowedHeaders: _*), + `Access-Control-Expose-Headers`(allowedHeaders: _*) + ) { + options { // options is used during preflight check + complete( + HttpResponse(StatusCodes.OK) + .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) + } ~ inner // in case of non-preflight check we don't do anything special + } + } + } + +} + +object CorsDirectives extends CorsDirectives diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala new file mode 100644 index 0000000..0cd4ef1 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala @@ -0,0 +1,6 @@ +package xyz.driver.core +package rest +package directives + +trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers +object Directives extends Directives diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala new file mode 100644 index 0000000..218c9ae --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala @@ -0,0 +1,85 @@ +package xyz.driver.core +package rest +package directives + +import java.time.Instant +import java.util.UUID + +import akka.http.scaladsl.model.Uri.Path +import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} +import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import xyz.driver.core.domain.PhoneNumber +import xyz.driver.core.time.Time + +import scala.util.control.NonFatal + +/** Akka-HTTP path matchers for custom core types. */ +trait PathMatchers { + + private def UuidInPath[T]: PathMatcher1[Id[T]] = + AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) + + def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) + case _ => Unmatched + } + } + + def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) + case _ => Unmatched + } + } + + private def timestampInPath: PathMatcher1[Long] = + PathMatcher("""[+-]?\d*""".r) flatMap { string => + try Some(string.toLong) + catch { case _: IllegalArgumentException => None } + } + + def InstantInPath: PathMatcher1[Instant] = + new PathMatcher1[Instant] { + def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { + case Path.Segment(head, tail) => + try Matched(tail, Tuple1(Instant.parse(head))) + catch { + case NonFatal(_) => Unmatched + } + case _ => Unmatched + } + } | timestampInPath.map(Instant.ofEpochMilli) + + def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) + + def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + refineV[NonEmpty](segment) match { + case Left(_) => Unmatched + case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) + } + case _ => Unmatched + } + } + + def RevisionInPath[T]: PathMatcher1[Revision[T]] = + PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => + Some(Revision[T](string)) + } + + def PhoneInPath: PathMatcher1[PhoneNumber] = new PathMatcher1[PhoneNumber] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + PhoneNumber + .parse(segment) + .map(parsed => Matched(tail, Tuple1(parsed))) + .getOrElse(Unmatched) + case _ => Unmatched + } + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala new file mode 100644 index 0000000..6c45d15 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala @@ -0,0 +1,40 @@ +package xyz.driver.core +package rest +package directives + +import java.util.UUID + +import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} +import akka.http.scaladsl.unmarshalling.Unmarshaller +import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} + +/** Akka-HTTP unmarshallers for custom core types. */ +trait Unmarshallers { + + implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = + Unmarshaller.strict[String, Id[A]] { str => + Id[A](UUID.fromString(str).toString) + } + + implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = + Unmarshaller.firstOf( + Unmarshaller.strict((JsString(_: String)) andThen reader.read), + stringToValueUnmarshaller[T] + ) + + implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = + Unmarshaller.strict[String, Revision[T]](Revision[T]) + + val jsValueToStringMarshaller: Marshaller[JsValue, String] = + Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) + + def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = + jsValueToStringMarshaller.compose[T](jsonFormat.write) + + val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = + Unmarshaller.strict[String, JsValue](value => JsonParser(value)) + + def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = + stringToJsValueUnmarshaller.map[T](jsonFormat.read) + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala new file mode 100644 index 0000000..f2962c9 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.errors + +sealed abstract class ServiceException(val message: String) extends Exception(message) + +final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) + +final case class InvalidActionException(override val message: String = "This action is not allowed") + extends ServiceException(message) + +final case class UnauthorizedException( + override val message: String = "The user's authentication credentials are invalid or missing") + extends ServiceException(message) + +final case class ResourceNotFoundException(override val message: String = "Resource not found") + extends ServiceException(message) + +final case class ExternalServiceException( + serviceName: String, + serviceMessage: String, + serviceException: Option[ServiceException]) + extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") + +final case class ExternalServiceTimeoutException(serviceName: String) + extends ServiceException(s"$serviceName took too long to respond") + +final case class DatabaseException(override val message: String = "Database access error") + extends ServiceException(message) diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala new file mode 100644 index 0000000..866476d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala @@ -0,0 +1,33 @@ +package xyz.driver.core +package rest +package headers + +import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} +import xyz.driver.core.reporting.SpanContext + +import scala.util.Try + +/** Encapsulates a trace context in an HTTP header for propagation across services. + * + * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) + * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is + * ignored. + */ +final case class Traceparent(spanContext: SpanContext) extends ModeledCustomHeader[Traceparent] { + override def renderInRequests = true + override def renderInResponses = true + override val companion: Traceparent.type = Traceparent + override def value: String = f"01-${spanContext.traceId}-${spanContext.spanId}-00" +} +object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { + override val name = "traceparent" + override def parse(value: String) = Try { + val Array(version, traceId, spanId, _) = value.split("-") + require( + version == "01", + s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") + new Traceparent( + new SpanContext(traceId, spanId) + ) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/package.scala b/core-rest/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..34a4a9d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,323 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.Flow +import akka.util.ByteString +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ +import scalaz.{Functor, OptionT} +import xyz.driver.core.rest.auth.AuthProvider +import xyz.driver.core.rest.errors.ExternalServiceException +import xyz.driver.core.rest.headers.Traceparent +import xyz.driver.tracing.TracingDirectives + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.Try + +trait Service + +object Service + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] +} + +sealed trait SortingOrder +object SortingOrder { + case object Asc extends SortingOrder + case object Desc extends SortingOrder +} + +final case class SortingField(name: String, sortingOrder: SortingOrder) +final case class Sorting(sortingFields: Seq[SortingField]) + +final case class Pagination(pageSize: Int, pageNumber: Int) { + require(pageSize > 0, "Page size must be greater than zero") + require(pageNumber > 0, "Page number must be greater than zero") + + def offset: Int = pageSize * (pageNumber - 1) +} + +final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) + +object ListResponse { + + def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = + ListResponse( + items = items, + meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) + + final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) + + object Meta { + def apply(itemsCount: Int, pagination: Pagination): Meta = + Meta(itemsCount, pagination.pageNumber, pagination.pageSize) + } + +} + +object `package` { + + implicit class FutureExtensions[T](future: Future[T]) { + def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = + future.transform(identity, { + case ExternalServiceException(_, _, Some(e)) => e + case t: Throwable => t + }) + } + + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } + + object ContextHeaders { + val AuthenticationTokenHeader: String = "Authorization" + val PermissionsTokenHeader: String = "Permissions" + val AuthenticationHeaderPrefix: String = "Bearer" + val ClientFingerprintHeader: String = "X-Client-Fingerprint" + val TrackingIdHeader: String = "X-Trace" + val StacktraceHeader: String = "X-Stacktrace" + val OriginatingIpHeader: String = "X-Forwarded-For" + val ResourceCount: String = "X-Resource-Count" + val PageCount: String = "X-Page-Count" + val TraceHeaderName: String = TracingDirectives.TraceHeaderName + val SpanHeaderName: String = TracingDirectives.SpanHeaderName + } + + val AllowedHeaders: Seq[String] = + Seq( + "Origin", + "X-Requested-With", + "Content-Type", + "Content-Length", + "Accept", + "X-Trace", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Server", + "Date", + ContextHeaders.ClientFingerprintHeader, + ContextHeaders.TrackingIdHeader, + ContextHeaders.TraceHeaderName, + ContextHeaders.SpanHeaderName, + ContextHeaders.StacktraceHeader, + ContextHeaders.AuthenticationTokenHeader, + ContextHeaders.OriginatingIpHeader, + ContextHeaders.ResourceCount, + ContextHeaders.PageCount, + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + AuthProvider.SetAuthenticationTokenHeader, + AuthProvider.SetPermissionsTokenHeader, + "Traceparent" + ) + + def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + + def serviceContext: Directive1[ServiceRequestContext] = { + def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { + headers.map({ header => + if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( + ContextHeaders.AuthenticationHeaderPrefix)) { + Authorization(OAuth2BearerToken(header.value)) + } else header + })(collection.breakOut) + } + extractClientIP flatMap { remoteAddress => + mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => + extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + } + } + } + + def respondWithCorsAllowedHeaders: Directive0 = { + respondWithHeaders( + List[HttpHeader]( + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + )) + } + + def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { + respondWithHeader { + `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) + } + } + + def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { + respondWithHeaders( + List[HttpHeader]( + Allow(methods.to[collection.immutable.Seq]), + `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) + )) + } + + def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = + new ServiceRequestContext( + extractTrackingId(request), + extractOriginatingIP(request, remoteAddress), + extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name === ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractFingerprintHash(request: HttpRequest): Option[String] = { + request.headers + .find(_.name === ContextHeaders.ClientFingerprintHeader) + .map(_.value()) + } + + def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { + request.headers + .find(_.name === ContextHeaders.OriginatingIpHeader) + .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) + .orElse(remoteAddress.toOption) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers + .filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || + h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader || + h.name === Traceparent.name + } + .map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } + .toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } + + val paginated: Directive1[Pagination] = + parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) + + private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = + (pageSizeOpt, pageNumberOpt) match { + case (Some(size), Some(number)) => Option(Pagination(size, number)) + case (None, None) => Option.empty[Pagination] + case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") + } + + val optionalPagination: Directive1[Option[Pagination]] = + parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) + + def paginationQuery(pagination: Pagination) = + Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) + + def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])( + implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = { + optionalPagination { pagination => + onSuccess(handler(pagination)) { + case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) => + val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1) + val headers = List( + RawHeader(ContextHeaders.ResourceCount, count.toString), + RawHeader(ContextHeaders.PageCount, pageCount.toString)) + + respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart))) + } + } + } + + private def extractSorting(sortingString: Option[String]): Sorting = { + val sortingFields = sortingString.fold(Seq.empty[SortingField])( + _.split(",") + .filter(_.length > 0) + .map { sortingParam => + if (sortingParam.startsWith("-")) { + SortingField(sortingParam.substring(1), SortingOrder.Desc) + } else { + val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam + SortingField(fieldName, SortingOrder.Asc) + } + } + .toSeq) + + Sorting(sortingFields) + } + + val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) + + def sortingQuery(sorting: Sorting): Seq[(String, String)] = { + val sortingString = sorting.sortingFields + .map { sortingField => + sortingField.sortingOrder match { + case SortingOrder.Asc => sortingField.name + case SortingOrder.Desc => s"-${sortingField.name}" + } + } + .mkString(",") + Seq("sort" -> sortingString) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala new file mode 100644 index 0000000..55f1a2e --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala @@ -0,0 +1,24 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait ServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T +} + +trait SavingUsedServiceDiscovery { + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } +} + +class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala new file mode 100644 index 0000000..d2e4bc3 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -0,0 +1,87 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} +import xyz.driver.core.generators +import scalaz.Scalaz.{mapEqual, stringInstance} +import scalaz.syntax.equal._ +import xyz.driver.core.reporting.SpanContext +import xyz.driver.core.rest.auth.AuthProvider +import xyz.driver.core.rest.headers.Traceparent + +import scala.util.Try + +class ServiceRequestContext( + val trackingId: String = generators.nextUuid().toString, + val originatingIp: Option[InetAddress] = None, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + + def withAuthToken(authToken: AuthToken): ServiceRequestContext = + new ServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, originatingIp, contextHeaders) + .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => + trackingId === ctx.trackingId && + originatingIp == originatingIp && + contextHeaders === ctx.contextHeaders + case _ => false + } + + def spanContext: SpanContext = { + val validHeader = Try { + contextHeaders(Traceparent.name) + }.flatMap { value => + Traceparent.parse(value) + } + validHeader.map(_.spanContext).getOrElse(SpanContext.fresh()) + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" +} + +class AuthorizedServiceRequestContext[U <: User]( + override val trackingId: String = generators.nextUuid().toString, + override val originatingIp: Option[InetAddress] = None, + override val contextHeaders: Map[String, String] = Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} diff --git a/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala new file mode 100644 index 0000000..2e772fb --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala @@ -0,0 +1,165 @@ +package xyz.driver.core + +import akka.http.scaladsl.model.headers.{ + HttpChallenges, + OAuth2BearerToken, + RawHeader, + Authorization => AkkaAuthorization +} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth._ +import xyz.driver.core.domain.Email +import xyz.driver.core.logging._ +import xyz.driver.core.rest._ +import xyz.driver.core.rest.auth._ +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import scalaz.OptionT + +class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest { + + case object TestRoleAllowedPermission extends Permission + case object TestRoleAllowedByTokenPermission extends Permission + case object TestRoleNotAllowedPermission extends Permission + + val TestRole = Role(Id("1"), Name("testRole")) + + val (publicKey, privateKey) = { + import java.security.KeyPairGenerator + + val keygen = KeyPairGenerator.getInstance("RSA") + keygen.initialize(2048) + + val keyPair = keygen.generateKeyPair() + (keyPair.getPublic, keyPair.getPrivate) + } + + val basicAuthorization: Authorization[User] = new Authorization[User] { + + override def userHasPermissions(user: User, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap + Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) + } + } + + val tokenIssuer = "users" + val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) + + val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) + + val authStatusService = new AuthProvider[User](authorization, NoLogger) { + override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = + OptionT.optionT[Future] { + if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { + Future.successful( + Some( + AuthTokenUserInfo( + Id[User]("1"), + Email("foo", "bar"), + emailVerified = true, + audience = "driver", + roles = Set(TestRole), + expirationTime = Time(1000000L) + ))) + } else { + Future.successful(Option.empty[User]) + } + } + } + + import authStatusService._ + + "'authorize' directive" should "throw error if auth token is not in the request" in { + + Get("/naive/attempt") ~> + authorize(TestRoleAllowedPermission) { user => + complete("Never going to be here") + } ~> + check { + // handled shouldBe false + rejections should contain( + AuthenticationFailedRejection( + AuthenticationFailedRejection.CredentialsMissing, + HttpChallenges.oAuth2(authStatusService.realm))) + } + } + + it should "throw error if authorized user does not have the requested permission" in { + + val referenceAuthToken = AuthToken("I am a test role's token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Post("/administration/attempt").addHeader( + referenceAuthHeader + ) ~> + authorize(TestRoleNotAllowedPermission) { user => + complete("Never going to get here") + } ~> + check { + handled shouldBe false + rejections should contain(AuthorizationFailedRejection) + } + } + + it should "pass and retrieve the token to client code, if token is in request and user has permission" in { + val referenceAuthToken = AuthToken("I am token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Get("/valid/attempt/?a=2&b=5").addHeader( + referenceAuthHeader + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authenticate correctly even without the 'Bearer' prefix on the Authorization header" in { + val referenceAuthToken = AuthToken("unprefixed_token") + + Get("/valid/attempt/?a=2&b=5").addHeader( + RawHeader(ContextHeaders.AuthenticationTokenHeader, referenceAuthToken.value) + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authorize permission found in permissions token" in { + import spray.json._ + + val claim = JsObject( + Map( + "iss" -> JsString(tokenIssuer), + "sub" -> JsString("1"), + "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) + )).prettyPrint + val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) + val referenceAuthToken = AuthToken("I am token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Get("/alic/attempt/?a=2&b=5") + .addHeader(referenceAuthHeader) + .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> + authorize(TestRoleAllowedByTokenPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala new file mode 100644 index 0000000..7e740a4 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala @@ -0,0 +1,264 @@ +package xyz.driver.core + +import org.scalatest.{Assertions, FlatSpec, Matchers} + +import scala.collection.immutable.IndexedSeq + +class GeneratorsTest extends FlatSpec with Matchers with Assertions { + import generators._ + + "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in { + + val generatedId1 = nextId[String]() + val generatedId2 = nextId[String]() + val generatedId3 = nextId[Long]() + + generatedId1.length should be >= 0 + generatedId2.length should be >= 0 + generatedId3.length should be >= 0 + generatedId1 should not be generatedId2 + generatedId2 should !==(generatedId3) + } + + it should "be able to generate com.drivergrp.core.Id identifiers with max value" in { + + val generatedLimitedId1 = nextId[String](5) + val generatedLimitedId2 = nextId[String](4) + val generatedLimitedId3 = nextId[Long](3) + + generatedLimitedId1.length should be >= 0 + generatedLimitedId1.length should be < 6 + generatedLimitedId2.length should be >= 0 + generatedLimitedId2.length should be < 5 + generatedLimitedId3.length should be >= 0 + generatedLimitedId3.length should be < 4 + generatedLimitedId1 should not be generatedLimitedId2 + generatedLimitedId2 should !==(generatedLimitedId3) + } + + it should "be able to generate com.drivergrp.core.Name names" in { + + Seq.fill(10)(nextName[String]()).distinct.size should be > 1 + nextName[String]().value.length should be >= 0 + + val fixedLengthName = nextName[String](10) + fixedLengthName.length should be <= 10 + assert(!fixedLengthName.value.exists(_.isControl)) + } + + it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in { + + assert(nextNonEmptyName[String]().value.value.nonEmpty) + } + + it should "be able to generate proper UUIDs" in { + + nextUuid() should not be nextUuid() + nextUuid().toString.length should be(36) + } + + it should "be able to generate new Revisions" in { + + nextRevision[String]() should not be nextRevision[String]() + nextRevision[String]().id.length should be > 0 + } + + it should "be able to generate strings" in { + + nextString() should not be nextString() + nextString().length should be >= 0 + + val fixedLengthString = nextString(20) + fixedLengthString.length should be <= 20 + assert(!fixedLengthString.exists(_.isControl)) + } + + it should "be able to generate strings non-empty strings whic are non empty" in { + + assert(nextNonEmptyString().value.nonEmpty) + } + + it should "be able to generate options which are sometimes have values and sometimes not" in { + + val generatedOption = nextOption("2") + + generatedOption should not contain "1" + assert(generatedOption === Some("2") || generatedOption === None) + } + + it should "be able to generate a pair of two generated values" in { + + val constantPair = nextPair("foo", 1L) + constantPair._1 should be("foo") + constantPair._2 should be(1L) + + val generatedPair = nextPair(nextId[Int](), nextName[Int]()) + + generatedPair._1.length should be > 0 + generatedPair._2.length should be > 0 + + nextPair(nextId[Int](), nextName[Int]()) should not be + nextPair(nextId[Int](), nextName[Int]()) + } + + it should "be able to generate a triad of two generated values" in { + + val constantTriad = nextTriad("foo", "bar", 1L) + constantTriad._1 should be("foo") + constantTriad._2 should be("bar") + constantTriad._3 should be(1L) + + val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + + generatedTriad._1.length should be > 0 + generatedTriad._2.length should be > 0 + generatedTriad._3 should be >= BigDecimal(0.00) + + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + } + + it should "be able to generate a time value" in { + + val generatedTime = nextTime() + val currentTime = System.currentTimeMillis() + + generatedTime.millis should be >= 0L + generatedTime.millis should be <= currentTime + } + + it should "be able to generate a time range value" in { + + val generatedTimeRange = nextTimeRange() + val currentTime = System.currentTimeMillis() + + generatedTimeRange.start.millis should be >= 0L + generatedTimeRange.start.millis should be <= currentTime + generatedTimeRange.end.millis should be >= 0L + generatedTimeRange.end.millis should be <= currentTime + generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis + } + + it should "be able to generate a BigDecimal value" in { + + val defaultGeneratedBigDecimal = nextBigDecimal() + + defaultGeneratedBigDecimal should be >= BigDecimal(0.00) + defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00) + defaultGeneratedBigDecimal.precision should be(2) + + val unitIntervalBigDecimal = nextBigDecimal(1.00, 8) + + unitIntervalBigDecimal should be >= BigDecimal(0.00) + unitIntervalBigDecimal should be <= BigDecimal(1.00) + unitIntervalBigDecimal.precision should be(8) + } + + it should "be able to generate a specific value from a set of values" in { + + val possibleOptions = Set(1, 3, 5, 123, 0, 9) + + val pick1 = generators.oneOf(possibleOptions) + val pick2 = generators.oneOf(possibleOptions) + val pick3 = generators.oneOf(possibleOptions) + + possibleOptions should contain(pick1) + possibleOptions should contain(pick2) + possibleOptions should contain(pick3) + + val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9) + + possibleOptions should contain(pick4) + possibleOptions should contain(pick5) + possibleOptions should contain(pick6) + + Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1 + } + + it should "be able to generate a specific value from an enumeratum enum" in { + + import enumeratum._ + sealed trait TestEnumValue extends EnumEntry + object TestEnum extends Enum[TestEnumValue] { + case object Value1 extends TestEnumValue + case object Value2 extends TestEnumValue + case object Value3 extends TestEnumValue + case object Value4 extends TestEnumValue + val values: IndexedSeq[TestEnumValue] = findValues + } + + val picks = (1 to 100).map(_ => generators.oneOf(TestEnum)) + + TestEnum.values should contain allElementsOf picks + picks.toSet.size should be >= 1 + } + + it should "be able to generate array with values generated by generators" in { + + val arrayOfTimes = arrayOf(nextTime(), 16) + arrayOfTimes.length should be <= 16 + + val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8) + arrayOfBigDecimals.length should be <= 8 + } + + it should "be able to generate seq with values generated by generators" in { + + val seqOfTimes = seqOf(nextTime(), 16) + seqOfTimes.size should be <= 16 + + val seqOfBigDecimals = seqOf(nextBigDecimal(), 8) + seqOfBigDecimals.size should be <= 8 + } + + it should "be able to generate vector with values generated by generators" in { + + val vectorOfTimes = vectorOf(nextTime(), 16) + vectorOfTimes.size should be <= 16 + + val vectorOfStrings = seqOf(nextString(), 8) + vectorOfStrings.size should be <= 8 + } + + it should "be able to generate list with values generated by generators" in { + + val listOfTimes = listOf(nextTime(), 16) + listOfTimes.size should be <= 16 + + val listOfBigDecimals = seqOf(nextBigDecimal(), 8) + listOfBigDecimals.size should be <= 8 + } + + it should "be able to generate set with values generated by generators" in { + + val setOfTimes = vectorOf(nextTime(), 16) + setOfTimes.size should be <= 16 + + val setOfBigDecimals = seqOf(nextBigDecimal(), 8) + setOfBigDecimals.size should be <= 8 + } + + it should "be able to generate maps with keys and values generated by generators" in { + + val generatedConstantMap = mapOf("key", 123, 10) + generatedConstantMap.size should be <= 1 + assert(generatedConstantMap.keys.forall(_ == "key")) + assert(generatedConstantMap.values.forall(_ == 123)) + + val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10) + assert(generatedMap.keys.forall(_.length <= 10)) + assert(generatedMap.values.forall(_ >= BigDecimal(0.00))) + } + + it should "compose deeply" in { + + val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10) + + generatedNestedMap.size should be <= 10 + generatedNestedMap.keySet.size should be <= 10 + generatedNestedMap.values.size should be <= 10 + assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123))) + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala new file mode 100644 index 0000000..fd693f9 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala @@ -0,0 +1,521 @@ +package xyz.driver.core + +import java.net.InetAddress +import java.time.{Instant, LocalDate} + +import akka.http.scaladsl.model.Uri +import akka.http.scaladsl.server.PathMatcher +import akka.http.scaladsl.server.PathMatcher.Matched +import com.neovisionaries.i18n.{CountryCode, CurrencyCode} +import enumeratum._ +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.numeric.Positive +import eu.timepit.refined.refineMV +import org.scalatest.{Inspectors, Matchers, WordSpec} +import spray.json._ +import xyz.driver.core.TestTypes.CustomGADT +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.json._ +import xyz.driver.core.json.enumeratum.HasJsonFormat +import xyz.driver.core.tagging._ +import xyz.driver.core.time.provider.SystemTimeProvider +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.collection.immutable.IndexedSeq +import scala.language.postfixOps + +class JsonTest extends WordSpec with Matchers with Inspectors { + import DefaultJsonProtocol._ + + "Json format for Id" should { + "read and write correct JSON" in { + + val referenceId = Id[String]("1312-34A") + + val writtenJson = json.idFormat.write(referenceId) + writtenJson.prettyPrint should be("\"1312-34A\"") + + val parsedId = json.idFormat.read(writtenJson) + parsedId should be(referenceId) + } + } + + "Json format for @@" should { + "read and write correct JSON" in { + trait Irrelevant + val reference = Id[JsonTest]("SomeID").tagged[Irrelevant] + + val format = json.taggedFormat[Id[JsonTest], Irrelevant] + + val writtenJson = format.write(reference) + writtenJson shouldBe JsString("SomeID") + + val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson) + parsedId shouldBe reference + } + + "read and write correct JSON when there's an implicit conversion defined" in { + val input = " some string " + + JsString(input).convertTo[String @@ Trimmed] shouldBe input.trim() + + val trimmed: String @@ Trimmed = input + trimmed.toJson shouldBe JsString(trimmed) + } + } + + "Json format for Name" should { + "read and write correct JSON" in { + + val referenceName = Name[String]("Homer") + + val writtenJson = json.nameFormat.write(referenceName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedName = json.nameFormat.read(writtenJson) + parsedName should be(referenceName) + } + + "read and write correct JSON for Name @@ Trimmed" in { + trait Irrelevant + JsString(" some name ").convertTo[Name[Irrelevant] @@ Trimmed] shouldBe Name[Irrelevant]("some name") + + val trimmed: Name[Irrelevant] @@ Trimmed = Name(" some name ") + trimmed.toJson shouldBe JsString("some name") + } + } + + "Json format for NonEmptyName" should { + "read and write correct JSON" in { + + val jsonFormat = json.nonEmptyNameFormat[String] + + val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer")) + + val writtenJson = jsonFormat.write(referenceNonEmptyName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedNonEmptyName = jsonFormat.read(writtenJson) + parsedNonEmptyName should be(referenceNonEmptyName) + } + } + + "Json format for Time" should { + "read and write correct JSON" in { + + val referenceTime = new SystemTimeProvider().currentTime() + + val writtenJson = json.timeFormat.write(referenceTime) + writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}") + + val parsedTime = json.timeFormat.read(writtenJson) + parsedTime should be(referenceTime) + } + + "read from inputs compatible with Instant" in { + val referenceTime = new SystemTimeProvider().currentTime() + + val jsons = Seq(JsNumber(referenceTime.millis), JsString(Instant.ofEpochMilli(referenceTime.millis).toString)) + + forAll(jsons) { json => + json.convertTo[Time] shouldBe referenceTime + } + } + } + + "Json format for TimeOfDay" should { + "read and write correct JSON" in { + val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") + val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") + val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) + writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) + val parsed = json.timeOfDayFormat.read(writtenJson) + parsed should be(referenceTimeOfDay) + } + } + + "Json format for Date" should { + "read and write correct JSON" in { + import date._ + + val referenceDate = Date(1941, Month.DECEMBER, 7) + + val writtenJson = json.dateFormat.write(referenceDate) + writtenJson.prettyPrint should be("\"1941-12-07\"") + + val parsedDate = json.dateFormat.read(writtenJson) + parsedDate should be(referenceDate) + } + } + + "Json format for java.time.Instant" should { + + val isoString = "2018-08-08T08:08:08.888Z" + val instant = Instant.parse(isoString) + + "read correct JSON when value is an epoch milli number" in { + JsNumber(instant.toEpochMilli).convertTo[Instant] shouldBe instant + } + + "read correct JSON when value is an ISO timestamp string" in { + JsString(isoString).convertTo[Instant] shouldBe instant + } + + "read correct JSON when value is an object with nested 'timestamp'/millis field" in { + val json = JsObject( + "timestamp" -> JsNumber(instant.toEpochMilli) + ) + + json.convertTo[Instant] shouldBe instant + } + + "write correct JSON" in { + instant.toJson shouldBe JsString(isoString) + } + } + + "Path matcher for Instant" should { + + val isoString = "2018-08-08T08:08:08.888Z" + val instant = Instant.parse(isoString) + + val matcher = PathMatcher("foo") / InstantInPath / + + "read instant from millis" in { + matcher(Uri.Path("foo") / ("+" + instant.toEpochMilli) / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) + } + + "read instant from ISO timestamp string" in { + matcher(Uri.Path("foo") / isoString / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) + } + } + + "Json format for java.time.LocalDate" should { + + "read and write correct JSON" in { + val dateString = "2018-08-08" + val date = LocalDate.parse(dateString) + + date.toJson shouldBe JsString(dateString) + JsString(dateString).convertTo[LocalDate] shouldBe date + } + } + + "Json format for Revision" should { + "read and write correct JSON" in { + + val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4") + + val writtenJson = json.revisionFormat.write(referenceRevision) + writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"") + + val parsedRevision = json.revisionFormat.read(writtenJson) + parsedRevision should be(referenceRevision) + } + } + + "Json format for Email" should { + "read and write correct JSON" in { + + val referenceEmail = Email("test", "drivergrp.com") + + val writtenJson = json.emailFormat.write(referenceEmail) + writtenJson should be("\"test@drivergrp.com\"".parseJson) + + val parsedEmail = json.emailFormat.read(writtenJson) + parsedEmail should be(referenceEmail) + } + } + + "Json format for PhoneNumber" should { + "read and write correct JSON" in { + + val referencePhoneNumber = PhoneNumber("1", "4243039608") + + val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber) + writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson) + + val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson) + parsedPhoneNumber should be(referencePhoneNumber) + } + + "reject an invalid phone number" in { + val phoneJson = """{"countryCode":"1","number":"111-111-1113"}""".parseJson + + intercept[DeserializationException] { + json.phoneNumberFormat.read(phoneJson) + }.getMessage shouldBe "Invalid phone number" + } + + "parse phone number from string" in { + JsString("+14243039608").convertTo[PhoneNumber] shouldBe PhoneNumber("1", "4243039608") + } + } + + "Path matcher for PhoneNumber" should { + "read valid phone number" in { + val string = "+14243039608x23" + val phone = PhoneNumber("1", "4243039608", Some("23")) + + val matcher = PathMatcher("foo") / PhoneInPath + + matcher(Uri.Path("foo") / string / "bar") shouldBe Matched(Uri.Path./("bar"), Tuple1(phone)) + } + } + + "Json format for ADT mappings" should { + "read and write correct JSON" in { + + sealed trait EnumVal + case object Val1 extends EnumVal + case object Val2 extends EnumVal + case object Val3 extends EnumVal + + val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3) + + val referenceEnumValue1 = Val2 + val referenceEnumValue2 = Val3 + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1.prettyPrint should be("\"b\"") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2.prettyPrint should be("\"c\"") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + } + } + + "Json format for Enums (external)" should { + "read and write correct JSON" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val format = new enumeratum.EnumJsonFormat(MyEnum) + + val referenceEnumValue1 = MyEnum.`Val 2` + val referenceEnumValue2 = MyEnum.`Val/3` + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2 shouldBe JsString("Val/3") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 shouldBe referenceEnumValue1 + parsedEnumValue2 shouldBe referenceEnumValue2 + + intercept[DeserializationException] { + format.read(JsString("Val4")) + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + } + + "Json format for Enums (automatic)" should { + "read and write correct JSON and not require import" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val referenceEnumValue1: MyEnum = MyEnum.`Val 2` + val referenceEnumValue2: MyEnum = MyEnum.`Val/3` + + val writtenJson1 = referenceEnumValue1.toJson + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = referenceEnumValue2.toJson + writtenJson2 shouldBe JsString("Val/3") + + import spray.json._ + + val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum] + val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum] + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + + intercept[DeserializationException] { + JsString("Val4").convertTo[MyEnum] + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + } + + // Should be defined outside of case to have a TypeTag + case class CustomWrapperClass(value: Int) + + "Json format for Value classes" should { + "read and write correct JSON" in { + + val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt)) + + val referenceValue1 = CustomWrapperClass(-2) + val referenceValue2 = CustomWrapperClass(10) + + val writtenJson1 = format.write(referenceValue1) + writtenJson1.prettyPrint should be("-2") + + val writtenJson2 = format.write(referenceValue2) + writtenJson2.prettyPrint should be("10") + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + } + + "Json format for classes GADT" should { + "read and write correct JSON" in { + + import CustomGADT._ + import DefaultJsonProtocol._ + implicit val case1Format = jsonFormat1(GadtCase1) + implicit val case2Format = jsonFormat1(GadtCase2) + implicit val case3Format = jsonFormat1(GadtCase3) + + val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") { + case _: CustomGADT.GadtCase1 => "case1" + case _: CustomGADT.GadtCase2 => "case2" + case _: CustomGADT.GadtCase3 => "case3" + } { + case "case1" => case1Format + case "case2" => case2Format + case "case3" => case3Format + } + + val referenceValue1 = CustomGADT.GadtCase1("4") + val referenceValue2 = CustomGADT.GadtCase2("Hi!") + + val writtenJson1 = format.write(referenceValue1) + writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson) + + val writtenJson2 = format.write(referenceValue2) + writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson) + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + } + + "Json format for a Refined value" should { + "read and write correct JSON" in { + + val jsonFormat = json.refinedJsonFormat[Int, Positive] + + val referenceRefinedNumber = refineMV[Positive](42) + + val writtenJson = jsonFormat.write(referenceRefinedNumber) + writtenJson should be("42".parseJson) + + val parsedRefinedNumber = jsonFormat.read(writtenJson) + parsedRefinedNumber should be(referenceRefinedNumber) + } + } + + "InetAddress format" should { + "read and write correct JSON" in { + val address = InetAddress.getByName("127.0.0.1") + val json = inetAddressFormat.write(address) + + json shouldBe JsString("127.0.0.1") + + val parsed = inetAddressFormat.read(json) + parsed shouldBe address + } + + "throw a DeserializationException for an invalid IP Address" in { + assertThrows[DeserializationException] { + val invalidAddress = JsString("foobar:") + inetAddressFormat.read(invalidAddress) + } + } + } + + "AuthCredentials format" should { + "read and write correct JSON" in { + val email = Email("someone", "noehere.com") + val phoneId = PhoneNumber.parse("1 207 8675309") + val password = "nopassword" + + phoneId.isDefined should be(true) // test this real quick + + val emailAuth = AuthCredentials(email.toString, password) + val pnAuth = AuthCredentials(phoneId.get.toString, password) + + val emailWritten = authCredentialsFormat.write(emailAuth) + emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson) + + val phoneWritten = authCredentialsFormat.write(pnAuth) + phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson) + + val identifierEmailParsed = + authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson) + var written = authCredentialsFormat.write(identifierEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + val emailEmailParsed = + authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson) + written = authCredentialsFormat.write(emailEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + } + } + + "CountryCode format" should { + "read and write correct JSON" in { + val samples = Seq( + "US" -> CountryCode.US, + "CN" -> CountryCode.CN, + "AT" -> CountryCode.AT + ) + + forAll(samples) { + case (serialized, enumValue) => + countryCodeFormat.write(enumValue) shouldBe JsString(serialized) + countryCodeFormat.read(JsString(serialized)) shouldBe enumValue + } + } + } + + "CurrencyCode format" should { + "read and write correct JSON" in { + val samples = Seq( + "USD" -> CurrencyCode.USD, + "CNY" -> CurrencyCode.CNY, + "EUR" -> CurrencyCode.EUR + ) + + forAll(samples) { + case (serialized, enumValue) => + currencyCodeFormat.write(enumValue) shouldBe JsString(serialized) + currencyCodeFormat.read(JsString(serialized)) shouldBe enumValue + } + } + } + +} diff --git a/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala new file mode 100644 index 0000000..bb25deb --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala @@ -0,0 +1,14 @@ +package xyz.driver.core + +object TestTypes { + + sealed trait CustomGADT { + val field: String + } + + object CustomGADT { + final case class GadtCase1(field: String) extends CustomGADT + final case class GadtCase2(field: String) extends CustomGADT + final case class GadtCase3(field: String) extends CustomGADT + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala new file mode 100644 index 0000000..324c8d8 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -0,0 +1,89 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpMethod, StatusCodes} +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.config.ConfigFactory +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.app.{DriverApp, SimpleModule} + +class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { + val config = ConfigFactory.parseString(""" + |application { + | cors { + | allowedOrigins: ["example.com"] + | } + |} + """.stripMargin).withFallback(ConfigFactory.load) + + val origin = Origin(HttpOrigin("https", Host("example.com"))) + val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) + val allowedMethods: collection.immutable.Seq[HttpMethod] = { + import akka.http.scaladsl.model.HttpMethods._ + collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE) + } + + import scala.reflect.runtime.universe.typeOf + class TestApp(testRoute: Route) + extends DriverApp( + appName = "test-app", + version = "0.0.1", + gitHash = "deadb33f", + modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), + config = config, + log = xyz.driver.core.logging.NoLogger + ) + + it should "respond with the correct CORS headers for the swagger OPTIONS route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for the test route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + } + } + + it should "respond with the correct CORS headers for a concatenated route" in { + val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) + Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + } + } + + it should "allow subdomains of allowed origin suffixes" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) + } + } + + it should "respond with default domains for invalid origins" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) + } + } + + it should "respond with Pragma and Cache-Control (no-cache) headers" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + header("Pragma").map(_.value()) should contain("no-cache") + header[`Cache-Control`].map(_.value()) should contain("no-cache") + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala new file mode 100644 index 0000000..cc0019a --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala @@ -0,0 +1,121 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers.Connection +import akka.http.scaladsl.server.Directives.{complete => akkaComplete} +import akka.http.scaladsl.server.{Directives, RejectionHandler, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.scalalogging.Logger +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.json.serviceExceptionFormat +import xyz.driver.core.logging.NoLogger +import xyz.driver.core.rest.errors._ + +import scala.concurrent.Future + +class DriverRouteTest + extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives { + class TestRoute(override val route: Route) extends DriverRoute { + override def log: Logger = NoLogger + } + + "DriverRoute" should "respond with 200 OK for a basic route" in { + val route = new TestRoute(akkaComplete(StatusCodes.OK)) + + Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.OK + } + } + + it should "respond with a 401 for an InvalidInputException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe InvalidInputException() + } + } + + it should "respond with a 403 for InvalidActionException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.Forbidden + responseAs[ServiceException] shouldBe InvalidActionException() + } + } + + it should "respond with a 404 for ResourceNotFoundException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + responseAs[ServiceException] shouldBe ResourceNotFoundException() + } + } + + it should "respond with a 500 for ExternalServiceException" in { + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None) + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe error + } + } + + it should "allow pass-through of external service exceptions" in { + val innerError = InvalidInputException() + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError)) + val future = Future.failed[String](error) + val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException)) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe innerError + } + } + + it should "respond with a 503 for ExternalServiceTimeoutException" in { + val error = ExternalServiceTimeoutException("GET /api/v1/users/") + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.GatewayTimeout + responseAs[ServiceException] shouldBe error + } + } + + it should "respond with a 500 for DatabaseException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe DatabaseException() + } + } + + it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in { + val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result() + val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK")))) + + Get("/foo") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.OK + headers should contain(Connection("close")) + } + + Get("/bar") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.NotFound + headers should contain(Connection("close")) + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala new file mode 100644 index 0000000..987717d --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala @@ -0,0 +1,101 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchDirectivesTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchDirectives { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(retrieve: => Future[Option[Foo]]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + entity(as[Patchable[Foo]]) { fooPatchable => + mergePatch(fooPatchable, retrieve) { updatedFoo => + complete(updatedFoo) + } + } + }) + + val MergePatchContentType = ContentType(`application/merge-patch+json`) + val ContentTypeHeader = `Content-Type`(MergePatchContentType) + def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = + HttpEntity(contentType, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + val fooRetrieve = Future.successful(None) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "handle optional values correctly when old value is null" in { + val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None))) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 400 for nulls on non-optional values" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } + + it should "return a 415 for incorrect Content-Type" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { + status shouldBe StatusCodes.UnsupportedMediaType + responseAs[String] should include("application/merge-patch+json") + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala new file mode 100644 index 0000000..19e4ed1 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala @@ -0,0 +1,151 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.{Directives, Route, ValidationRejection} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import akka.util.ByteString +import org.scalatest.{Matchers, WordSpec} +import xyz.driver.core.rest + +import scala.concurrent.Future +import scala.util.Random + +class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives { + "`escapeScriptTags` function" should { + "escape script tags properly" in { + val dirtyString = " + complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}") + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "provide a default pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,100") + } + } + "provide default values for a partial pagination" in { + Get("/?pageSize=2") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,2") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=-1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "optional paginated directive" should { + val route: Route = rest.optionalPagination { paginated => + complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination")) + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "without pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "no pagination") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "completeWithPagination directive" when { + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json.DefaultJsonProtocol._ + + val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString) + val route: Route = + parameter('empty.as[Boolean] ? false) { isEmpty => + completeWithPagination[String] { + case Some(pagination) if isEmpty => + Future.successful(ListResponse(Seq(), 0, Some(pagination))) + case Some(pagination) => + val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize) + Future.successful(ListResponse(filtered, data.size, Some(pagination))) + case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None)) + case None => Future.successful(ListResponse(data, data.size, None)) + } + } + + "pagination is defined" should { + "return a response with pagination headers" in { + Get("/?pageNumber=2&pageSize=10") ~> route ~> check { + responseAs[Seq[String]] shouldBe data.slice(10, 20) + header(ContextHeaders.ResourceCount).map(_.value) should contain("103") + header(ContextHeaders.PageCount).map(_.value) should contain("11") + } + } + + "disallow pageSize <= 0" in { + Get("/?pageNumber=2&pageSize=0") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + + Get("/?pageNumber=2&pageSize=-1") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + } + + "disallow pageNumber <= 0" in { + Get("/?pageNumber=0&pageSize=10") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + + Get("/?pageNumber=-1&pageSize=10") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + } + + "return PageCount == 0 if returning an empty list" in { + Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check { + responseAs[Seq[String]] shouldBe empty + header(ContextHeaders.ResourceCount).map(_.value) should contain("0") + header(ContextHeaders.PageCount).map(_.value) should contain("0") + } + } + } + + "pagination is not defined" should { + "return a response with pagination headers and PageCount == 1" in { + Get("/") ~> route ~> check { + responseAs[Seq[String]] shouldBe data + header(ContextHeaders.ResourceCount).map(_.value) should contain("103") + header(ContextHeaders.PageCount).map(_.value) should contain("1") + } + } + + "return PageCount == 0 if returning an empty list" in { + Get("/?empty=true") ~> route ~> check { + responseAs[Seq[String]] shouldBe empty + header(ContextHeaders.ResourceCount).map(_.value) should contain("0") + header(ContextHeaders.PageCount).map(_.value) should contain("0") + } + } + } + } +} diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala deleted file mode 100644 index 896bd89..0000000 --- a/src/main/scala/xyz/driver/core/auth.scala +++ /dev/null @@ -1,43 +0,0 @@ -package xyz.driver.core - -import xyz.driver.core.domain.Email -import xyz.driver.core.time.Time -import scalaz.Equal - -object auth { - - trait Permission - - final case class Role(id: Id[Role], name: Name[Role]) { - - def oneOf(roles: Role*): Boolean = roles.contains(this) - - def oneOf(roles: Set[Role]): Boolean = roles.contains(this) - } - - object Role { - implicit def idEqual: Equal[Role] = Equal.equal[Role](_ == _) - } - - trait User { - def id: Id[User] - } - - final case class AuthToken(value: String) - - final case class AuthTokenUserInfo( - id: Id[User], - email: Email, - emailVerified: Boolean, - audience: String, - roles: Set[Role], - expirationTime: Time) - extends User - - final case class RefreshToken(value: String) - final case class PermissionsToken(value: String) - - final case class PasswordHash(value: String) - - final case class AuthCredentials(identifier: String, password: String) -} diff --git a/src/main/scala/xyz/driver/core/generators.scala b/src/main/scala/xyz/driver/core/generators.scala deleted file mode 100644 index d00b6dd..0000000 --- a/src/main/scala/xyz/driver/core/generators.scala +++ /dev/null @@ -1,143 +0,0 @@ -package xyz.driver.core - -import enumeratum._ -import java.math.MathContext -import java.time.{Instant, LocalDate, ZoneOffset} -import java.util.UUID - -import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} -import xyz.driver.core.date.{Date, DayOfWeek} - -import scala.reflect.ClassTag -import scala.util.Random -import eu.timepit.refined.refineV -import eu.timepit.refined.api.Refined -import eu.timepit.refined.collection._ - -object generators { - - private val random = new Random - import random._ - private val secureRandom = new java.security.SecureRandom() - - private val DefaultMaxLength = 10 - private val StringLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ".toSet - private val NonAmbigiousCharacters = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789" - private val Numbers = "0123456789" - - private def nextTokenString(length: Int, chars: IndexedSeq[Char]): String = { - val builder = new StringBuilder - for (_ <- 0 until length) { - builder += chars(secureRandom.nextInt(chars.length)) - } - builder.result() - } - - /** Creates a random invitation token. - * - * This token is meant fo human input and avoids using ambiguous characters such as 'O' and '0'. It - * therefore contains less entropy and is not meant to be used as a cryptographic secret. */ - @deprecated( - "The term 'token' is too generic and security and readability conventions are not well defined. " + - "Services should implement their own version that suits their security requirements.", - "1.11.0" - ) - def nextToken(length: Int): String = nextTokenString(length, NonAmbigiousCharacters) - - @deprecated( - "The term 'token' is too generic and security and readability conventions are not well defined. " + - "Services should implement their own version that suits their security requirements.", - "1.11.0" - ) - def nextNumericToken(length: Int): String = nextTokenString(length, Numbers) - - def nextInt(maxValue: Int, minValue: Int = 0): Int = random.nextInt(maxValue - minValue) + minValue - - def nextBoolean(): Boolean = random.nextBoolean() - - def nextDouble(): Double = random.nextDouble() - - def nextId[T](): Id[T] = Id[T](nextUuid().toString) - - def nextId[T](maxLength: Int): Id[T] = Id[T](nextString(maxLength)) - - def nextNumericId[T](): Id[T] = Id[T](nextLong.abs.toString) - - def nextNumericId[T](maxValue: Int): Id[T] = Id[T](nextInt(maxValue).toString) - - def nextName[T](maxLength: Int = DefaultMaxLength): Name[T] = Name[T](nextString(maxLength)) - - def nextNonEmptyName[T](maxLength: Int = DefaultMaxLength): NonEmptyName[T] = - NonEmptyName[T](nextNonEmptyString(maxLength)) - - def nextUuid(): UUID = java.util.UUID.randomUUID - - def nextRevision[T](): Revision[T] = Revision[T](nextUuid().toString) - - def nextString(maxLength: Int = DefaultMaxLength): String = - (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString - - def nextNonEmptyString(maxLength: Int = DefaultMaxLength): String Refined NonEmpty = { - refineV[NonEmpty]( - (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString - ).right.get - } - - def nextOption[T](value: => T): Option[T] = if (nextBoolean()) Option(value) else None - - def nextPair[L, R](left: => L, right: => R): (L, R) = (left, right) - - def nextTriad[F, S, T](first: => F, second: => S, third: => T): (F, S, T) = (first, second, third) - - def nextInstant(): Instant = Instant.ofEpochMilli(math.abs(nextLong() % System.currentTimeMillis)) - - def nextTime(): Time = nextInstant() - - def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) - - def nextTimeRange(): TimeRange = { - val oneTime = nextTime() - val anotherTime = nextTime() - - TimeRange( - Time(scala.math.min(oneTime.millis, anotherTime.millis)), - Time(scala.math.max(oneTime.millis, anotherTime.millis))) - } - - def nextDate(): Date = nextTime().toDate(java.util.TimeZone.getTimeZone("UTC")) - - def nextLocalDate(): LocalDate = nextInstant().atZone(ZoneOffset.UTC).toLocalDate - - def nextDayOfWeek(): DayOfWeek = oneOf(DayOfWeek.All) - - def nextBigDecimal(multiplier: Double = 1000000.00, precision: Int = 2): BigDecimal = - BigDecimal(multiplier * nextDouble, new MathContext(precision)) - - def oneOf[T](items: T*): T = oneOf(items.toSet) - - def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) - - def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) - - def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = - Array.fill(nextInt(maxLength, minLength))(generator) - - def seqOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Seq[T] = - Seq.fill(nextInt(maxLength, minLength))(generator) - - def vectorOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Vector[T] = - Vector.fill(nextInt(maxLength, minLength))(generator) - - def listOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): List[T] = - List.fill(nextInt(maxLength, minLength))(generator) - - def setOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Set[T] = - seqOf(generator, maxLength, minLength).toSet - - def mapOf[K, V]( - keyGenerator: => K, - valueGenerator: => V, - maxLength: Int = DefaultMaxLength, - minLength: Int = 0): Map[K, V] = - seqOf(nextPair(keyGenerator, valueGenerator), maxLength, minLength).toMap -} diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala deleted file mode 100644 index edc2347..0000000 --- a/src/main/scala/xyz/driver/core/json.scala +++ /dev/null @@ -1,398 +0,0 @@ -package xyz.driver.core - -import java.net.InetAddress -import java.time.format.DateTimeFormatter -import java.time.{Instant, LocalDate} -import java.util.{TimeZone, UUID} - -import akka.http.scaladsl.unmarshalling.Unmarshaller -import com.neovisionaries.i18n.{CountryCode, CurrencyCode} -import enumeratum._ -import eu.timepit.refined.api.{Refined, Validate} -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.refineV -import spray.json._ -import xyz.driver.core.auth.AuthCredentials -import xyz.driver.core.date.{Date, DayOfWeek, Month} -import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.rest.directives.{PathMatchers, Unmarshallers} -import xyz.driver.core.rest.errors._ -import xyz.driver.core.time.{Time, TimeOfDay} - -import scala.reflect.runtime.universe._ -import scala.reflect.{ClassTag, classTag} -import scala.util.Try -import scala.util.control.NonFatal - -object json extends PathMatchers with Unmarshallers { - import DefaultJsonProtocol._ - - implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { - def write(id: Id[T]) = JsString(id.value) - - def read(value: JsValue): Id[T] = value match { - case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase) - case JsString(id) => Id[T](id) - case _ => throw DeserializationException("Id expects string") - } - } - - implicit def taggedFormat[F, T](implicit underlying: JsonFormat[F], convert: F => F @@ T = null): JsonFormat[F @@ T] = - new JsonFormat[F @@ T] { - import tagging._ - - private val transformReadValue = Option(convert).getOrElse((_: F).tagged[T]) - - override def write(obj: F @@ T): JsValue = underlying.write(obj) - - override def read(json: JsValue): F @@ T = transformReadValue(underlying.read(json)) - } - - implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { - def write(name: Name[T]) = JsString(name.value) - - def read(value: JsValue): Name[T] = value match { - case JsString(name) => Name[T](name) - case _ => throw DeserializationException("Name expects string") - } - } - - implicit val timeFormat: RootJsonFormat[Time] = new RootJsonFormat[Time] { - def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) - - def read(value: JsValue): Time = Time(instantFormat.read(value)) - } - - implicit val instantFormat: JsonFormat[Instant] = new JsonFormat[Instant] { - def write(instant: Instant): JsValue = JsString(instant.toString) - - def read(value: JsValue): Instant = value match { - case JsObject(fields) => - fields - .get("timestamp") - .flatMap { - case JsNumber(millis) => Some(Instant.ofEpochMilli(millis.longValue())) - case _ => None - } - .getOrElse(deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}")) - case JsNumber(millis) => Instant.ofEpochMilli(millis.longValue()) - case JsString(str) => - try Instant.parse(str) - catch { case NonFatal(_) => deserializationError(s"Instant expects ISO timestamp but got $str") } - case _ => deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}") - } - } - - implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { - private val formatter = TimeOfDay.getFormatter - def read(json: JsValue): java.time.LocalTime = json match { - case JsString(chars) => - java.time.LocalTime.parse(chars) - case _ => deserializationError(s"Expected time string got ${json.toString}") - } - - def write(obj: java.time.LocalTime): JsValue = { - JsString(obj.format(formatter)) - } - } - - implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { - override def write(obj: TimeZone): JsValue = { - JsString(obj.getID()) - } - - override def read(json: JsValue): TimeZone = json match { - case JsString(chars) => - java.util.TimeZone.getTimeZone(chars) - case _ => deserializationError(s"Expected time zone string got ${json.toString}") - } - } - - implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) - - implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) - - implicit val dateFormat = new RootJsonFormat[Date] { - def write(date: Date) = JsString(date.toString) - def read(value: JsValue): Date = value match { - case JsString(dateString) => - Date - .fromString(dateString) - .getOrElse( - throw DeserializationException(s"Misformated ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.")) - case _ => throw DeserializationException(s"Date expects a string, but got $value.") - } - } - - implicit val localDateFormat = new RootJsonFormat[LocalDate] { - val format = DateTimeFormatter.ISO_LOCAL_DATE - - def write(date: LocalDate): JsValue = JsString(date.format(format)) - def read(value: JsValue): LocalDate = value match { - case JsString(dateString) => - try LocalDate.parse(dateString, format) - catch { - case NonFatal(_) => - throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.") - } - case _ => - throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got ${value.compactPrint}.") - } - } - - implicit val monthFormat = new RootJsonFormat[Month] { - def write(month: Month) = JsNumber(month) - def read(value: JsValue): Month = value match { - case JsNumber(month) if 0 <= month && month <= 11 => Month(month.toInt) - case _ => throw DeserializationException("Expected a number from 0 to 11") - } - } - - implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { - def write(revision: Revision[T]) = JsString(revision.id.toString) - - def read(value: JsValue): Revision[T] = value match { - case JsString(revision) => Revision[T](revision) - case _ => throw DeserializationException("Revision expects uuid string") - } - } - - implicit val base64Format = new RootJsonFormat[Base64] { - def write(base64Value: Base64) = JsString(base64Value.value) - - def read(value: JsValue): Base64 = value match { - case JsString(base64Value) => Base64(base64Value) - case _ => throw DeserializationException("Base64 format expects string") - } - } - - implicit val emailFormat = new RootJsonFormat[Email] { - def write(email: Email) = JsString(email.username + "@" + email.domain) - def read(json: JsValue): Email = json match { - - case JsString(value) => - Email.parse(value).getOrElse { - deserializationError("Expected '@' symbol in email string as Email, but got " + json.toString) - } - - case _ => - deserializationError("Expected string as Email, but got " + json.toString) - } - } - - implicit object phoneNumberFormat extends RootJsonFormat[PhoneNumber] { - - private val basicFormat = jsonFormat3(PhoneNumber.apply) - - def write(obj: PhoneNumber): JsValue = basicFormat.write(obj) - - def read(json: JsValue): PhoneNumber = { - val maybePhone = json match { - case JsString(number) => PhoneNumber.parse(number) - case obj: JsObject => PhoneNumber.parse(basicFormat.read(obj).toString) - case _ => None - } - maybePhone.getOrElse(deserializationError("Invalid phone number")) - } - } - - implicit val authCredentialsFormat = new RootJsonFormat[AuthCredentials] { - override def read(json: JsValue): AuthCredentials = { - json match { - case JsObject(fields) => - val emailField = fields.get("email") - val identifierField = fields.get("identifier") - val passwordField = fields.get("password") - - (emailField, identifierField, passwordField) match { - case (_, _, None) => - deserializationError("password field must be set") - case (Some(JsString(em)), _, Some(JsString(pw))) => - val email = Email.parse(em).getOrElse(throw deserializationError(s"failed to parse email $em")) - AuthCredentials(email.toString, pw) - case (_, Some(JsString(id)), Some(JsString(pw))) => AuthCredentials(id.toString, pw.toString) - case (None, None, _) => deserializationError("identifier must be provided") - case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") - } - case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") - } - } - - override def write(obj: AuthCredentials): JsValue = JsObject( - "identifier" -> JsString(obj.identifier), - "password" -> JsString(obj.password) - ) - } - - implicit object inetAddressFormat extends JsonFormat[InetAddress] { - override def read(json: JsValue): InetAddress = json match { - case JsString(ipString) => - Try(InetAddress.getByName(ipString)) - .getOrElse(deserializationError(s"Invalid IP Address: $ipString")) - case _ => deserializationError(s"Expected string for IP Address, got $json") - } - - override def write(obj: InetAddress): JsValue = - JsString(obj.getHostAddress) - } - - implicit val countryCodeFormat: JsonFormat[CountryCode] = javaEnumFormat[CountryCode] - - implicit val currencyCodeFormat: JsonFormat[CurrencyCode] = javaEnumFormat[CurrencyCode] - - object enumeratum { - - def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = - Unmarshaller.strict { value => - enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) - } - - trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => - - implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) - - implicit val unmarshaller: Unmarshaller[String, T] = - Unmarshaller.strict { value => - enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) - } - } - - class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { - override def read(json: JsValue): T = json match { - case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) - case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) - } - - override def write(obj: T): JsValue = JsString(obj.entryName) - } - - private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = - deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") - } - - class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { - private val map = mapping.toMap - - override def write(value: T): JsValue = { - map.find(_._2 == value).map(_._1) match { - case Some(name) => JsString(name) - case _ => serializationError(s"Value $value is not found in the mapping $map") - } - } - - override def read(json: JsValue): T = json match { - case JsString(name) => - map.getOrElse(name, throw DeserializationException(s"Value $name is not found in the mapping $map")) - case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) - } - } - - def javaEnumFormat[T <: java.lang.Enum[_]: ClassTag]: JsonFormat[T] = { - val values = classTag[T].runtimeClass.asInstanceOf[Class[T]].getEnumConstants - new EnumJsonFormat[T](values.map(v => v.name() -> v): _*) - } - - class ValueClassFormat[T: TypeTag](writeValue: T => BigDecimal, create: BigDecimal => T) extends JsonFormat[T] { - def write(valueClass: T) = JsNumber(writeValue(valueClass)) - def read(json: JsValue): T = json match { - case JsNumber(value) => create(value) - case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json.toString) - } - } - - class GadtJsonFormat[T: TypeTag]( - typeField: String, - typeValue: PartialFunction[T, String], - jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) - extends RootJsonFormat[T] { - - def write(value: T): JsValue = { - - val valueType = typeValue.applyOrElse(value, { v: T => - deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v.toString) - }) - - val valueFormat = - jsonFormat.applyOrElse(valueType, { f: String => - deserializationError(s"No Json format for this type of $valueType") - }) - - valueFormat.asInstanceOf[JsonFormat[T]].write(value) match { - case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType))) - case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object") - } - } - - def read(json: JsValue): T = json match { - case JsObject(fields) => - val valueJson = JsObject(fields.filterNot(_._1 == typeField)) - fields(typeField) match { - case JsString(valueType) => - val valueFormat = jsonFormat.applyOrElse(valueType, { t: String => - deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") - }) - valueFormat.read(valueJson) - case _ => - deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") - } - case _ => - deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json.toString) - } - } - - object GadtJsonFormat { - - def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])( - jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = { - - new GadtJsonFormat[T](typeField, typeValue, jsonFormat) - } - } - - /** - * Provides the JsonFormat for the Refined types provided by the Refined library. - * - * @see https://github.com/fthomas/refined - */ - implicit def refinedJsonFormat[T, Predicate]( - implicit valueFormat: JsonFormat[T], - validate: Validate[T, Predicate]): JsonFormat[Refined[T, Predicate]] = - new JsonFormat[Refined[T, Predicate]] { - def write(x: T Refined Predicate): JsValue = valueFormat.write(x.value) - def read(value: JsValue): T Refined Predicate = { - refineV[Predicate](valueFormat.read(value))(validate) match { - case Right(refinedValue) => refinedValue - case Left(refinementError) => deserializationError(refinementError) - } - } - } - - implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = - new RootJsonFormat[NonEmptyName[T]] { - def write(name: NonEmptyName[T]) = JsString(name.value.value) - - def read(value: JsValue): NonEmptyName[T] = - NonEmptyName[T](nonEmptyStringFormat.read(value)) - } - - implicit val serviceExceptionFormat: RootJsonFormat[ServiceException] = - GadtJsonFormat.create[ServiceException]("type") { - case _: InvalidInputException => "InvalidInputException" - case _: InvalidActionException => "InvalidActionException" - case _: UnauthorizedException => "UnauthorizedException" - case _: ResourceNotFoundException => "ResourceNotFoundException" - case _: ExternalServiceException => "ExternalServiceException" - case _: ExternalServiceTimeoutException => "ExternalServiceTimeoutException" - case _: DatabaseException => "DatabaseException" - } { - case "InvalidInputException" => jsonFormat(InvalidInputException, "message") - case "InvalidActionException" => jsonFormat(InvalidActionException, "message") - case "UnauthorizedException" => jsonFormat(UnauthorizedException, "message") - case "ResourceNotFoundException" => jsonFormat(ResourceNotFoundException, "message") - case "ExternalServiceException" => - jsonFormat(ExternalServiceException, "serviceName", "serviceMessage", "serviceException") - case "ExternalServiceTimeoutException" => jsonFormat(ExternalServiceTimeoutException, "message") - case "DatabaseException" => jsonFormat(DatabaseException, "message") - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala b/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala deleted file mode 100644 index 87946e4..0000000 --- a/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala +++ /dev/null @@ -1,11 +0,0 @@ -package xyz.driver.core -package rest - -class DnsDiscovery(transport: HttpRestServiceTransport, overrides: Map[String, String]) { - - def discover[A](implicit descriptor: ServiceDescriptor[A]): A = { - val url = overrides.getOrElse(descriptor.name, s"https://${descriptor.name}") - descriptor.connect(transport, url) - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala deleted file mode 100644 index 911e306..0000000 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ /dev/null @@ -1,122 +0,0 @@ -package xyz.driver.core.rest - -import java.sql.SQLException - -import akka.http.scaladsl.model.headers.CacheDirectives.`no-cache` -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{StatusCodes, _} -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import com.typesafe.scalalogging.Logger -import org.slf4j.MDC -import xyz.driver.core.rest -import xyz.driver.core.rest.errors._ - -import scala.compat.Platform.ConcurrentModificationException - -trait DriverRoute { - def log: Logger - - def route: Route - - def routeWithDefaults: Route = { - (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { - route - } - } - - protected def defaultResponseHeaders: Directive0 = { - extractRequest flatMap { request => - // Needs to happen before any request processing, so all the log messages - // associated with processing of this request are having this `trackingId` - val trackingId = rest.extractTrackingId(request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - MDC.put("trackingId", trackingId) - - respondWithHeaders(tracingHeader +: DriverRoute.DefaultHeaders: _*) - } - } - - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler: PartialFunction[Throwable, Route] = { - case serviceException: ServiceException => - serviceExceptionHandler(serviceException) - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( - ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) - - case t: Exception => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) - } - - protected def serviceExceptionHandler(serviceException: ServiceException): Route = { - val statusCode = serviceException match { - case e: InvalidInputException => - log.info("Invalid client input error", e) - StatusCodes.BadRequest - case e: InvalidActionException => - log.info("Invalid client action error", e) - StatusCodes.Forbidden - case e: UnauthorizedException => - log.info("Unauthorized user error", e) - StatusCodes.Unauthorized - case e: ResourceNotFoundException => - log.info("Resource not found error", e) - StatusCodes.NotFound - case e: ExternalServiceException => - log.error("Error while calling another service", e) - StatusCodes.InternalServerError - case e: ExternalServiceTimeoutException => - log.error("Service timeout error", e) - StatusCodes.GatewayTimeout - case e: DatabaseException => - log.error("Database error", e) - StatusCodes.InternalServerError - } - - { (ctx: RequestContext) => - import xyz.driver.core.json.serviceExceptionFormat - val entity = - HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) - errorResponse(statusCode, entity, serviceException)(ctx) - } - } - - protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = - errorResponse(statusCode, HttpEntity(message), exception) - - protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { - complete(HttpResponse(statusCode, entity = entity)) - } - -} - -object DriverRoute { - val DefaultHeaders: List[HttpHeader] = List( - // This header will eliminate the risk of envoy trying to reuse a connection - // that already timed out on the server side by completely rejecting keep-alive - Connection("close"), - // These 2 headers are the simplest way to prevent IE from caching GET requests - RawHeader("Pragma", "no-cache"), - `Cache-Control`(List(`no-cache`(Nil))) - ) -} diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala deleted file mode 100644 index e31635b..0000000 --- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala +++ /dev/null @@ -1,103 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.RawHeader -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.TcpIdleTimeoutException -import org.slf4j.MDC -import xyz.driver.core.Name -import xyz.driver.core.reporting.Reporter -import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -class HttpRestServiceTransport( - applicationName: Name[App], - applicationVersion: String, - val actorSystem: ActorSystem, - val executionContext: ExecutionContext, - reporter: Reporter) - extends ServiceTransport { - - protected implicit val execution: ExecutionContext = executionContext - - protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { - val tags = Map( - // open tracing semantic tags - "span.kind" -> "client", - "service" -> applicationName.value, - "http.url" -> requestStub.uri.toString, - "http.method" -> requestStub.method.value, - "peer.hostname" -> requestStub.uri.authority.host.toString, - // google's tracing console provides extra search features if we define these tags - "/http/path" -> requestStub.uri.path.toString, - "/http/method" -> requestStub.method.value.toString, - "/http/url" -> requestStub.uri.toString - ) - reporter.traceAsync(s"http_call_rpc", tags) { implicit span => - val requestTime = System.currentTimeMillis() - - val request = requestStub - .withHeaders(context.contextHeaders.toSeq.map { - case (ContextHeaders.TrackingIdHeader, _) => - RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) - case (ContextHeaders.StacktraceHeader, _) => - RawHeader( - ContextHeaders.StacktraceHeader, - Option(MDC.get("stack")) - .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) - .getOrElse("")) - case (header, headerValue) => RawHeader(header, headerValue) - }: _*) - - reporter.debug(s"Sending request to ${request.method} ${request.uri}") - - val response = httpClient.makeRequest(request) - - response.onComplete { - case Success(r) => - val responseLatency = System.currentTimeMillis() - requestTime - reporter.debug( - s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") - - case Failure(t: Throwable) => - val responseLatency = System.currentTimeMillis() - requestTime - reporter.warn( - s"Failed to receive response from ${request.method.value} ${request.uri} in $responseLatency ms", - t) - }(executionContext) - - response.recoverWith { - case _: TcpIdleTimeoutException => - val serviceCalled = s"${requestStub.method.value} ${requestStub.uri}" - Future.failed(ExternalServiceTimeoutException(serviceCalled)) - case t: Throwable => Future.failed(t) - } - }(context.spanContext) - } - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { - - sendRequestGetResponse(context)(requestStub) flatMap { response => - if (response.status == StatusCodes.NotFound) { - Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) - } else if (response.status.isFailure()) { - val serviceCalled = s"${requestStub.method} ${requestStub.uri}" - Unmarshal(response.entity).to[String] flatMap { errorString => - import spray.json._ - import xyz.driver.core.json._ - val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption - Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) - } - } else { - Future.successful(Unmarshal(response.entity)) - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala deleted file mode 100644 index f33bf9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala +++ /dev/null @@ -1,104 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.javadsl.server.Rejections -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} -import spray.json._ - -import scala.concurrent.Future -import scala.util.{Failure, Success, Try} - -trait PatchDirectives extends Directives with SprayJsonSupport { - - /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ - val `application/merge-patch+json`: MediaType.WithFixedCharset = - MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) - - /** Wraps a JSON value that represents a patch. - * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ - case class PatchValue(value: JsValue) { - - /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ - def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) - } - - /** Witness that the given patch may be applied to an original domain value. - * @tparam A type of the domain value - * @param patch the patch that may be applied to a domain value - * @param format a JSON format that enables serialization and deserialization of a domain value */ - case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { - - /** Applies the patch to a given domain object. The result will be a combination - * of the original value, updates with the fields specified in this witness' patch. */ - def applyTo(original: A): A = { - val serialized = format.write(original) - val merged = patch.applyTo(serialized) - val deserialized = format.read(merged) - deserialized - } - } - - implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = - Unmarshaller.byteStringUnmarshaller - .andThen(sprayJsValueByteStringUnmarshaller) - .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) - .map(js => PatchValue(js)) - - implicit def patchableUnmarshaller[A]( - implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], - format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { - patchUnmarshaller.map(patch => Patchable[A](patch, format)) - } - - protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { - JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => - val oldValue = oldObj.fields.getOrElse(key, JsNull) - val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) - key -> newValue - })(collection.breakOut): _*) - } - - protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { - def mergeError(typ: String): Nothing = - deserializationError(s"Expected $typ value, got $newValue") - - if (maxLevels.exists(_ < 0)) oldValue - else { - (oldValue, newValue) match { - case (_: JsString, newString @ (JsString(_) | JsNull)) => newString - case (_: JsString, _) => mergeError("string") - case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber - case (_: JsNumber, _) => mergeError("number") - case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool - case (_: JsBoolean, _) => mergeError("boolean") - case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr - case (_: JsArray, _) => mergeError("array") - case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) - case (_: JsObject, JsNull) => JsNull - case (_: JsObject, _) => mergeError("object") - case (JsNull, _) => newValue - } - } - } - - def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = - Directive { inner => requestCtx => - onSuccess(retrieve)({ - case Some(oldT) => - Try(patchable.applyTo(oldT)) - .transform[Route]( - mergedT => scala.util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => reject() - })(requestCtx) - } -} - -object PatchDirectives extends PatchDirectives diff --git a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala deleted file mode 100644 index 2854257..0000000 --- a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala +++ /dev/null @@ -1,67 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.model.headers.`User-Agent` -import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.stream.scaladsl.{Keep, Sink, Source} -import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} -import xyz.driver.core.Name - -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.concurrent.duration._ -import scala.util.{Failure, Success} - -class PooledHttpClient( - baseUri: Uri, - applicationName: Name[App], - applicationVersion: String, - requestRateLimit: Int = 64, - requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) - extends HttpClient { - - private val host = baseUri.authority.host.toString() - private val port = baseUri.effectivePort - private val scheme = baseUri.scheme - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - private val pool = if (scheme.equalsIgnoreCase("https")) { - Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } else { - Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } - - private val queue = Source - .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) - .via(pool) - .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) - .toMat(Sink.foreach({ - case ((Success(resp), p)) => p.success(resp) - case ((Failure(e), p)) => p.failure(e) - }))(Keep.left) - .run - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - val responsePromise = Promise[HttpResponse]() - - queue.offer(request -> responsePromise).flatMap { - case QueueOfferResult.Enqueued => - responsePromise.future - case QueueOfferResult.Dropped => - Future.failed(new Exception(s"Request queue to the host $host is overflown")) - case QueueOfferResult.Failure(ex) => - Future.failed(ex) - case QueueOfferResult.QueueClosed => - Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala deleted file mode 100644 index c0e9f99..0000000 --- a/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala +++ /dev/null @@ -1,26 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} -import com.typesafe.config.Config -import xyz.driver.core.Name - -import scala.concurrent.ExecutionContext - -trait ProxyRoute extends DriverRoute { - implicit val executionContext: ExecutionContext - val config: Config - val httpClient: HttpClient - - protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => - val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") - val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") - - val originalUri = ctx.request.uri - val originalRequest = ctx.request - - val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) - val newRequest = originalRequest.withUri(newUri) - - httpClient.makeRequest(newRequest).map(RouteResult.Complete) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/RestService.scala b/src/main/scala/xyz/driver/core/rest/RestService.scala deleted file mode 100644 index 09d98b8..0000000 --- a/src/main/scala/xyz/driver/core/rest/RestService.scala +++ /dev/null @@ -1,86 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model._ -import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} -import akka.stream.Materializer - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.{ListT, OptionT} - -trait RestService extends Service { - - import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ - import spray.json._ - - protected implicit val exec: ExecutionContext - protected implicit val materializer: Materializer - - implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { - def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = - if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] - } - - protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = - OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) - - protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = - OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) - - protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = - ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) - - protected def jsonEntity(json: JsValue): RequestEntity = - HttpEntity(ContentTypes.`application/json`, json.compactPrint) - - protected def mergePatchJsonEntity(json: JsValue): RequestEntity = - HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) - - protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) - - protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) - - protected def postJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) - - protected def putJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) - - protected def patchJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) - - protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) - - protected def endpointUri(baseUri: Uri, path: String): Uri = - baseUri.withPath(Uri.Path(path)) - - protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = - baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) - - protected def responseToListResponse[T: JsonFormat](pagination: Option[Pagination])( - response: HttpResponse): Future[ListResponse[T]] = { - import DefaultJsonProtocol._ - val resourceCount = response.headers - .find(_.name() equalsIgnoreCase ContextHeaders.ResourceCount) - .map(_.value().toInt) - .getOrElse(0) - val meta = ListResponse.Meta(resourceCount, pagination.getOrElse(Pagination(resourceCount max 1, 1))) - Unmarshal(response.entity).to[List[T]].map(ListResponse(_, meta)) - } - - protected def responseToListResponse[T: JsonFormat](pagination: Pagination)( - response: HttpResponse): Future[ListResponse[T]] = responseToListResponse(Some(pagination))(response) -} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala b/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala deleted file mode 100644 index 646fae8..0000000 --- a/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala +++ /dev/null @@ -1,16 +0,0 @@ -package xyz.driver.core -package rest -import scala.annotation.implicitNotFound - -@implicitNotFound( - "Don't know how to communicate with service ${S}. Make sure an implicit ServiceDescriptor is" + - "available. A good place to put one is in the service's companion object.") -trait ServiceDescriptor[S] { - - /** The service's name. Must be unique among all services. */ - def name: String - - /** Get an instance of the service. */ - def connect(transport: HttpRestServiceTransport, url: String): S - -} diff --git a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala deleted file mode 100644 index 964a5a2..0000000 --- a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala +++ /dev/null @@ -1,29 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.model.headers.`User-Agent` -import akka.http.scaladsl.model.{HttpRequest, HttpResponse} -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.stream.ActorMaterializer -import xyz.driver.core.Name - -import scala.concurrent.Future - -class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) - extends HttpClient { - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - private val client = Http()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - client.singleRequest(request, settings = connectionPoolSettings) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/Swagger.scala b/src/main/scala/xyz/driver/core/rest/Swagger.scala deleted file mode 100644 index 5ceac54..0000000 --- a/src/main/scala/xyz/driver/core/rest/Swagger.scala +++ /dev/null @@ -1,144 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} -import akka.http.scaladsl.server.Route -import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile -import akka.stream.ActorAttributes -import akka.stream.scaladsl.{Framing, StreamConverters} -import akka.util.ByteString -import com.github.swagger.akka.SwaggerHttpService -import com.github.swagger.akka.model._ -import com.typesafe.config.Config -import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme -import io.swagger.models.auth.{ApiKeyAuthDefinition, In} -import io.swagger.util.Json - -import scala.util.control.NonFatal - -class Swagger( - override val host: String, - accessSchemes: List[String], - version: String, - override val apiClasses: Set[Class[_]], - val config: Config, - val logger: Logger) - extends SwaggerHttpService { - - override val schemes = accessSchemes.map { s => - Scheme.forValue(s) - } - - // Note that the reason for overriding this is a subtle chain of causality: - // - // 1. Some of our endpoints require a single trailing slash and will not - // function if it is omitted - // 2. Swagger omits trailing slashes in its generated api doc - // 3. To work around that, a space is added after the trailing slash in the - // swagger Path annotations - // 4. This space is removed manually in the code below - // - // TODO: Ideally we'd like to drop this custom override and fix the issue in - // 1, by dropping the slash requirement and accepting api endpoints with and - // without trailing slashes. This will require inspecting and potentially - // fixing all service endpoints. - override def generateSwaggerJson: String = { - import io.swagger.models.{Swagger => JSwagger} - - import scala.collection.JavaConverters._ - try { - val swagger: JSwagger = reader.read(apiClasses.asJava) - - val paths = if (swagger.getPaths == null) { - Map.empty - } else { - swagger.getPaths.asScala - } - - // Removing trailing spaces - val fixedPaths = paths.map { - case (key, path) => - key.trim -> path - } - - swagger.setPaths(fixedPaths.asJava) - - Json.pretty().writeValueAsString(swagger) - } catch { - case NonFatal(t) => - logger.error("Issue with creating swagger.json", t) - throw t - } - } - - override val securitySchemeDefinitions = Map( - "token" -> { - val definition = new ApiKeyAuthDefinition("Authorization", In.HEADER) - definition.setDescription("Authentication token") - definition - } - ) - - override val basePath: String = config.getString("swagger.basePath") - override val apiDocsPath: String = config.getString("swagger.docsPath") - - override val info = Info( - config.getString("swagger.apiInfo.description"), - version, - config.getString("swagger.apiInfo.title"), - config.getString("swagger.apiInfo.termsOfServiceUrl"), - contact = Some( - Contact( - config.getString("swagger.apiInfo.contact.name"), - config.getString("swagger.apiInfo.contact.url"), - config.getString("swagger.apiInfo.contact.email") - )), - license = Some( - License( - config.getString("swagger.apiInfo.license"), - config.getString("swagger.apiInfo.licenseUrl") - )), - vendorExtensions = Map.empty[String, AnyRef] - ) - - /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ - private def getTemplatedResource( - resourceName: String, - contentType: ContentType, - substitution: (String, String)): Route = get { - Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { - case Some(ResourceFile(url, length @ _, _)) => - extractSettings { settings => - val stream = StreamConverters - .fromInputStream(() => url.openStream()) - .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) - .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) - .map { line => - line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) - } - .map(line => ByteString(line + "\n")) - complete( - HttpEntity(contentType, stream) - ) - } - case None => reject - } - } - - def swaggerUI: Route = - pathEndOrSingleSlash { - getTemplatedResource( - "swagger-ui/index.html", - ContentTypes.`text/html(UTF-8)`, - "title" -> config.getString("swagger.apiInfo.title")) - } ~ getFromResourceDirectory("swagger-ui") - - def swaggerUINew: Route = - pathEndOrSingleSlash { - getTemplatedResource( - "swagger-ui-dist/index.html", - ContentTypes.`text/html(UTF-8)`, - "title" -> config.getString("swagger.apiInfo.title")) - } ~ getFromResourceDirectory("swagger-ui-dist") - -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala deleted file mode 100644 index 5007774..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala +++ /dev/null @@ -1,14 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future - -class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - val permissionsMap = permissions.map(_ -> true).toMap - Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala deleted file mode 100644 index e1a94e1..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala +++ /dev/null @@ -1,75 +0,0 @@ -package xyz.driver.core.rest.auth - -import akka.http.scaladsl.server.directives.Credentials -import com.typesafe.scalalogging.Logger -import scalaz.OptionT -import xyz.driver.core.auth.{AuthToken, Permission, User} -import xyz.driver.core.rest.errors.{ExternalServiceException, UnauthorizedException} -import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ContextHeaders, ServiceRequestContext, serviceContext} - -import scala.concurrent.{ExecutionContext, Future} - -abstract class AuthProvider[U <: User]( - val authorization: Authorization[U], - log: Logger, - val realm: String -)(implicit execution: ExecutionContext) { - - import akka.http.scaladsl.server._ - import Directives.{authorize => akkaAuthorize, _} - - def this(authorization: Authorization[U], log: Logger)(implicit executionContext: ExecutionContext) = - this(authorization, log, "driver.xyz") - - /** - * Specific implementation on how to extract user from request context, - * can either need to do a network call to auth server or extract everything from self-contained token - * - * @param ctx set of request values which can be relevant to authenticate user - * @return authenticated user - */ - def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] - - protected def authenticator(context: ServiceRequestContext): AsyncAuthenticator[U] = { - case Credentials.Missing => - log.info(s"Request (${context.trackingId}) missing authentication credentials") - Future.successful(None) - case Credentials.Provided(authToken) => - authenticatedUser(context.withAuthToken(AuthToken(authToken))).run.recover({ - case ExternalServiceException(_, _, Some(UnauthorizedException(_))) => None - }) - } - - /** - * Verifies that a user agent is properly authenticated, and (optionally) authorized with the specified permissions - */ - def authorize( - context: ServiceRequestContext, - permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - authenticateOAuth2Async[U](realm, authenticator(context)) flatMap { authenticatedUser => - val authCtx = context.withAuthenticatedUser(context.authToken.get, authenticatedUser) - onSuccess(authorization.userHasPermissions(authenticatedUser, permissions)(authCtx)) flatMap { - case AuthorizationResult(authorized, token) => - val allAuthorized = permissions.forall(authorized.getOrElse(_, false)) - akkaAuthorize(allAuthorized) tflatMap { _ => - val cachedPermissionsCtx = token.fold(authCtx)(authCtx.withPermissionsToken) - provide(cachedPermissionsCtx) - } - } - } - } - - /** - * Verifies if request is authenticated and authorized to have `permissions` - */ - def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - serviceContext flatMap (authorize(_, permissions: _*)) - } -} - -object AuthProvider { - val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader - val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader - val SetAuthenticationTokenHeader: String = "set-authorization" - val SetPermissionsTokenHeader: String = "set-permissions" -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala deleted file mode 100644 index 1a5e9be..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala +++ /dev/null @@ -1,11 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future - -trait Authorization[U <: User] { - def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala deleted file mode 100644 index efe28c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala +++ /dev/null @@ -1,22 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, PermissionsToken} - -import scalaz.Scalaz.mapMonoid -import scalaz.Semigroup -import scalaz.syntax.semigroup._ - -final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) -object AuthorizationResult { - val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) - - implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { - private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) - private implicit val permissionsTokenSemigroup = - Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) - - override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { - AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala deleted file mode 100644 index 66de4ef..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala +++ /dev/null @@ -1,55 +0,0 @@ -package xyz.driver.core.rest.auth - -import java.nio.file.{Files, Path} -import java.security.{KeyFactory, PublicKey} -import java.security.spec.X509EncodedKeySpec - -import pdi.jwt.{Jwt, JwtAlgorithm} -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future -import scalaz.syntax.std.boolean._ - -class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - import spray.json._ - - def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = - tokenObject.fields.get("permissions").collect { - case JsObject(fields) => - fields.collect { - case (key, JsBoolean(value)) => key -> value - } - } - - val result = for { - token <- ctx.permissionsToken - jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption - jwtJson = jwt.parseJson.asJsObject - - // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None - _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) - _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) - - permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) - - authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap - } yield AuthorizationResult(authorized, Some(token)) - - Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) - } -} - -object CachedTokenAuthorization { - def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { - lazy val publicKey: PublicKey = { - val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim - val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) - val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) - KeyFactory.getInstance("RSA").generatePublic(spec) - } - new CachedTokenAuthorization[U](publicKey, issuer) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala deleted file mode 100644 index 131e7fc..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.Scalaz.{futureInstance, listInstance} -import scalaz.syntax.semigroup._ -import scalaz.syntax.traverse._ - -class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) - extends Authorization[U] { - - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = - permissions.forall(permissionsMap.getOrElse(_, false)) - - authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { - (authResult, authorization) => - if (allAuthorized(authResult.authorized)) Future.successful(authResult) - else { - authorization.userHasPermissions(user, permissions).map(authResult |+| _) - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala deleted file mode 100644 index ff3424d..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala +++ /dev/null @@ -1,19 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.auth.AuthProvider - -/** Authentication and authorization directives. */ -trait AuthDirectives extends AkkaDirectives { - - /** Authenticate a user based on service request headers and check if they have all given permissions. */ - def authenticateAndAuthorize[U <: User]( - authProvider: AuthProvider[U], - permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - authProvider.authorize(permissions: _*) - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala deleted file mode 100644 index 5a6bbfd..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala +++ /dev/null @@ -1,72 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import akka.http.scaladsl.model.HttpMethods._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{HttpResponse, StatusCodes} -import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} - -/** Directives to handle Cross-Origin Resource Sharing (CORS). */ -trait CorsDirectives extends AkkaDirectives { - - /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request - * origin. - * - * In a microservice environment, it can be difficult to know in advance the exact origin - * from which requests may be issued [1]. For example, the request may come from a web page served from - * any of the services, on any namespace or from other documentation sites. In general, only a set - * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that - * implement CORS require exact specification of allowed origins, including full host name and scheme, - * in order to send credentials and headers with requests to other origins. - * - * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by - * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. - * - * Note that the simplicity of this implementation comes with two notable drawbacks: - * - * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. - * - * - Allowed methods and headers can not be customized on a per-request basis. All standard - * HTTP methods are allowed, and allowed headers are specified for all inner routes. - * - * This handler is not suited for cases where more fine-grained control of responses is required. - * - * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through - * a common gateway. - * - * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed - * origins. - * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. - * @param inner Route into which CORS headers will be injected. - */ - def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { - optionalHeaderValueByType[Origin](()) { maybeOrigin => - val allowedOrigins: HttpOriginRange = maybeOrigin match { - // Note that this is not a security issue: the client will never send credentials if the allowed - // origin is set to *. This case allows us to deal with clients that do not send an origin header. - case None => HttpOriginRange.* - case Some(requestOrigin) => - val allowedOrigin = requestOrigin.origins.find(origin => - allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) - allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) - } - - respondWithHeaders( - `Access-Control-Allow-Origin`.forRange(allowedOrigins), - `Access-Control-Allow-Credentials`(true), - `Access-Control-Allow-Headers`(allowedHeaders: _*), - `Access-Control-Expose-Headers`(allowedHeaders: _*) - ) { - options { // options is used during preflight check - complete( - HttpResponse(StatusCodes.OK) - .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) - } ~ inner // in case of non-preflight check we don't do anything special - } - } - } - -} - -object CorsDirectives extends CorsDirectives diff --git a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala deleted file mode 100644 index 0cd4ef1..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala +++ /dev/null @@ -1,6 +0,0 @@ -package xyz.driver.core -package rest -package directives - -trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers -object Directives extends Directives diff --git a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala deleted file mode 100644 index 218c9ae..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala +++ /dev/null @@ -1,85 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import java.time.Instant -import java.util.UUID - -import akka.http.scaladsl.model.Uri.Path -import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} -import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.refineV -import xyz.driver.core.domain.PhoneNumber -import xyz.driver.core.time.Time - -import scala.util.control.NonFatal - -/** Akka-HTTP path matchers for custom core types. */ -trait PathMatchers { - - private def UuidInPath[T]: PathMatcher1[Id[T]] = - AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) - - def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) - case _ => Unmatched - } - } - - def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) - case _ => Unmatched - } - } - - private def timestampInPath: PathMatcher1[Long] = - PathMatcher("""[+-]?\d*""".r) flatMap { string => - try Some(string.toLong) - catch { case _: IllegalArgumentException => None } - } - - def InstantInPath: PathMatcher1[Instant] = - new PathMatcher1[Instant] { - def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { - case Path.Segment(head, tail) => - try Matched(tail, Tuple1(Instant.parse(head))) - catch { - case NonFatal(_) => Unmatched - } - case _ => Unmatched - } - } | timestampInPath.map(Instant.ofEpochMilli) - - def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) - - def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - refineV[NonEmpty](segment) match { - case Left(_) => Unmatched - case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) - } - case _ => Unmatched - } - } - - def RevisionInPath[T]: PathMatcher1[Revision[T]] = - PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => - Some(Revision[T](string)) - } - - def PhoneInPath: PathMatcher1[PhoneNumber] = new PathMatcher1[PhoneNumber] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - PhoneNumber - .parse(segment) - .map(parsed => Matched(tail, Tuple1(parsed))) - .getOrElse(Unmatched) - case _ => Unmatched - } - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala deleted file mode 100644 index 6c45d15..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala +++ /dev/null @@ -1,40 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import java.util.UUID - -import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} -import akka.http.scaladsl.unmarshalling.Unmarshaller -import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} - -/** Akka-HTTP unmarshallers for custom core types. */ -trait Unmarshallers { - - implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = - Unmarshaller.strict[String, Id[A]] { str => - Id[A](UUID.fromString(str).toString) - } - - implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = - Unmarshaller.firstOf( - Unmarshaller.strict((JsString(_: String)) andThen reader.read), - stringToValueUnmarshaller[T] - ) - - implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = - Unmarshaller.strict[String, Revision[T]](Revision[T]) - - val jsValueToStringMarshaller: Marshaller[JsValue, String] = - Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) - - def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = - jsValueToStringMarshaller.compose[T](jsonFormat.write) - - val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = - Unmarshaller.strict[String, JsValue](value => JsonParser(value)) - - def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = - stringToJsValueUnmarshaller.map[T](jsonFormat.read) - -} diff --git a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala deleted file mode 100644 index f2962c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest.errors - -sealed abstract class ServiceException(val message: String) extends Exception(message) - -final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) - -final case class InvalidActionException(override val message: String = "This action is not allowed") - extends ServiceException(message) - -final case class UnauthorizedException( - override val message: String = "The user's authentication credentials are invalid or missing") - extends ServiceException(message) - -final case class ResourceNotFoundException(override val message: String = "Resource not found") - extends ServiceException(message) - -final case class ExternalServiceException( - serviceName: String, - serviceMessage: String, - serviceException: Option[ServiceException]) - extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") - -final case class ExternalServiceTimeoutException(serviceName: String) - extends ServiceException(s"$serviceName took too long to respond") - -final case class DatabaseException(override val message: String = "Database access error") - extends ServiceException(message) diff --git a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala deleted file mode 100644 index 866476d..0000000 --- a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala +++ /dev/null @@ -1,33 +0,0 @@ -package xyz.driver.core -package rest -package headers - -import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} -import xyz.driver.core.reporting.SpanContext - -import scala.util.Try - -/** Encapsulates a trace context in an HTTP header for propagation across services. - * - * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) - * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is - * ignored. - */ -final case class Traceparent(spanContext: SpanContext) extends ModeledCustomHeader[Traceparent] { - override def renderInRequests = true - override def renderInResponses = true - override val companion: Traceparent.type = Traceparent - override def value: String = f"01-${spanContext.traceId}-${spanContext.spanId}-00" -} -object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { - override val name = "traceparent" - override def parse(value: String) = Try { - val Array(version, traceId, spanId, _) = value.split("-") - require( - version == "01", - s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") - new Traceparent( - new SpanContext(traceId, spanId) - ) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala deleted file mode 100644 index 34a4a9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ /dev/null @@ -1,323 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.Flow -import akka.util.ByteString -import scalaz.Scalaz.{intInstance, stringInstance} -import scalaz.syntax.equal._ -import scalaz.{Functor, OptionT} -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.errors.ExternalServiceException -import xyz.driver.core.rest.headers.Traceparent -import xyz.driver.tracing.TracingDirectives - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.Try - -trait Service - -object Service - -trait HttpClient { - def makeRequest(request: HttpRequest): Future[HttpResponse] -} - -trait ServiceTransport { - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] -} - -sealed trait SortingOrder -object SortingOrder { - case object Asc extends SortingOrder - case object Desc extends SortingOrder -} - -final case class SortingField(name: String, sortingOrder: SortingOrder) -final case class Sorting(sortingFields: Seq[SortingField]) - -final case class Pagination(pageSize: Int, pageNumber: Int) { - require(pageSize > 0, "Page size must be greater than zero") - require(pageNumber > 0, "Page number must be greater than zero") - - def offset: Int = pageSize * (pageNumber - 1) -} - -final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) - -object ListResponse { - - def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = - ListResponse( - items = items, - meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) - - final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) - - object Meta { - def apply(itemsCount: Int, pagination: Pagination): Meta = - Meta(itemsCount, pagination.pageNumber, pagination.pageSize) - } - -} - -object `package` { - - implicit class FutureExtensions[T](future: Future[T]) { - def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = - future.transform(identity, { - case ExternalServiceException(_, _, Some(e)) => e - case t: Throwable => t - }) - } - - implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { - def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( - implicit F: Functor[Future], - em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { - optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) - } - } - - object ContextHeaders { - val AuthenticationTokenHeader: String = "Authorization" - val PermissionsTokenHeader: String = "Permissions" - val AuthenticationHeaderPrefix: String = "Bearer" - val ClientFingerprintHeader: String = "X-Client-Fingerprint" - val TrackingIdHeader: String = "X-Trace" - val StacktraceHeader: String = "X-Stacktrace" - val OriginatingIpHeader: String = "X-Forwarded-For" - val ResourceCount: String = "X-Resource-Count" - val PageCount: String = "X-Page-Count" - val TraceHeaderName: String = TracingDirectives.TraceHeaderName - val SpanHeaderName: String = TracingDirectives.SpanHeaderName - } - - val AllowedHeaders: Seq[String] = - Seq( - "Origin", - "X-Requested-With", - "Content-Type", - "Content-Length", - "Accept", - "X-Trace", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Origin", - "Access-Control-Allow-Headers", - "Server", - "Date", - ContextHeaders.ClientFingerprintHeader, - ContextHeaders.TrackingIdHeader, - ContextHeaders.TraceHeaderName, - ContextHeaders.SpanHeaderName, - ContextHeaders.StacktraceHeader, - ContextHeaders.AuthenticationTokenHeader, - ContextHeaders.OriginatingIpHeader, - ContextHeaders.ResourceCount, - ContextHeaders.PageCount, - "X-Frame-Options", - "X-Content-Type-Options", - "Strict-Transport-Security", - AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader, - "Traceparent" - ) - - def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - - def serviceContext: Directive1[ServiceRequestContext] = { - def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { - headers.map({ header => - if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( - ContextHeaders.AuthenticationHeaderPrefix)) { - Authorization(OAuth2BearerToken(header.value)) - } else header - })(collection.breakOut) - } - extractClientIP flatMap { remoteAddress => - mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => - extract(ctx => extractServiceContext(ctx.request, remoteAddress)) - } - } - } - - def respondWithCorsAllowedHeaders: Directive0 = { - respondWithHeaders( - List[HttpHeader]( - `Access-Control-Allow-Headers`(AllowedHeaders: _*), - `Access-Control-Expose-Headers`(AllowedHeaders: _*) - )) - } - - def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { - respondWithHeader { - `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) - } - } - - def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { - respondWithHeaders( - List[HttpHeader]( - Allow(methods.to[collection.immutable.Seq]), - `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) - )) - } - - def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = - new ServiceRequestContext( - extractTrackingId(request), - extractOriginatingIP(request, remoteAddress), - extractContextHeaders(request)) - - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name === ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } - - def extractFingerprintHash(request: HttpRequest): Option[String] = { - request.headers - .find(_.name === ContextHeaders.ClientFingerprintHeader) - .map(_.value()) - } - - def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { - request.headers - .find(_.name === ContextHeaders.OriginatingIpHeader) - .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) - .orElse(remoteAddress.toOption) - } - - def extractStacktrace(request: HttpRequest): Array[String] = - request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") - - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers - .filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || - h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || - h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader || - h.name === Traceparent.name - } - .map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value - } - } - .toMap - } - - private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { - @annotation.tailrec - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse - else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) - } - } - } - - val indices = dirtyIndices(0, Nil) - - indices.headOption.fold(byteString) { head => - val builder = ByteString.newBuilder - builder ++= byteString.take(head) - - (indices :+ byteString.length).sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - case Seq(_) => // Should not match; sliding on at least 2 elements - assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") - } - builder.result - } - } - - val sanitizeRequestEntity: Directive0 = { - mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) - } - - val paginated: Directive1[Pagination] = - parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) - - private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = - (pageSizeOpt, pageNumberOpt) match { - case (Some(size), Some(number)) => Option(Pagination(size, number)) - case (None, None) => Option.empty[Pagination] - case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") - } - - val optionalPagination: Directive1[Option[Pagination]] = - parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) - - def paginationQuery(pagination: Pagination) = - Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) - - def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])( - implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = { - optionalPagination { pagination => - onSuccess(handler(pagination)) { - case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) => - val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1) - val headers = List( - RawHeader(ContextHeaders.ResourceCount, count.toString), - RawHeader(ContextHeaders.PageCount, pageCount.toString)) - - respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart))) - } - } - } - - private def extractSorting(sortingString: Option[String]): Sorting = { - val sortingFields = sortingString.fold(Seq.empty[SortingField])( - _.split(",") - .filter(_.length > 0) - .map { sortingParam => - if (sortingParam.startsWith("-")) { - SortingField(sortingParam.substring(1), SortingOrder.Desc) - } else { - val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam - SortingField(fieldName, SortingOrder.Asc) - } - } - .toSeq) - - Sorting(sortingFields) - } - - val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) - - def sortingQuery(sorting: Sorting): Seq[(String, String)] = { - val sortingString = sorting.sortingFields - .map { sortingField => - sortingField.sortingOrder match { - case SortingOrder.Asc => sortingField.name - case SortingOrder.Desc => s"-${sortingField.name}" - } - } - .mkString(",") - Seq("sort" -> sortingString) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala deleted file mode 100644 index 55f1a2e..0000000 --- a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala +++ /dev/null @@ -1,24 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.Name - -trait ServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T -} - -trait SavingUsedServiceDiscovery { - private val usedServices = new scala.collection.mutable.HashSet[String]() - - def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { - usedServices += serviceName.value - } - - def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } -} - -class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T = - throw new IllegalArgumentException(s"Service with name $serviceName is unknown") -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala deleted file mode 100644 index d2e4bc3..0000000 --- a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala +++ /dev/null @@ -1,87 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} -import xyz.driver.core.generators -import scalaz.Scalaz.{mapEqual, stringInstance} -import scalaz.syntax.equal._ -import xyz.driver.core.reporting.SpanContext -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.headers.Traceparent - -import scala.util.Try - -class ServiceRequestContext( - val trackingId: String = generators.nextUuid().toString, - val originatingIp: Option[InetAddress] = None, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - - def permissionsToken: Option[PermissionsToken] = - contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - - def withAuthToken(authToken: AuthToken): ServiceRequestContext = - new ServiceRequestContext( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) - ) - - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user - ) - - override def hashCode(): Int = - Seq[Any](trackingId, originatingIp, contextHeaders) - .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) - - override def equals(obj: Any): Boolean = obj match { - case ctx: ServiceRequestContext => - trackingId === ctx.trackingId && - originatingIp == originatingIp && - contextHeaders === ctx.contextHeaders - case _ => false - } - - def spanContext: SpanContext = { - val validHeader = Try { - contextHeaders(Traceparent.name) - }.flatMap { value => - Traceparent.parse(value) - } - validHeader.map(_.spanContext).getOrElse(SpanContext.fresh()) - } - - override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" -} - -class AuthorizedServiceRequestContext[U <: User]( - override val trackingId: String = generators.nextUuid().toString, - override val originatingIp: Option[InetAddress] = None, - override val contextHeaders: Map[String, String] = Map.empty[String, String], - val authenticatedUser: U) - extends ServiceRequestContext { - - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext[U]( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), - authenticatedUser) - - override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false - } - - override def toString: String = - s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" -} diff --git a/src/test/scala/xyz/driver/core/AuthTest.scala b/src/test/scala/xyz/driver/core/AuthTest.scala deleted file mode 100644 index 2e772fb..0000000 --- a/src/test/scala/xyz/driver/core/AuthTest.scala +++ /dev/null @@ -1,165 +0,0 @@ -package xyz.driver.core - -import akka.http.scaladsl.model.headers.{ - HttpChallenges, - OAuth2BearerToken, - RawHeader, - Authorization => AkkaAuthorization -} -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import akka.http.scaladsl.testkit.ScalatestRouteTest -import org.scalatest.{FlatSpec, Matchers} -import pdi.jwt.{Jwt, JwtAlgorithm} -import xyz.driver.core.auth._ -import xyz.driver.core.domain.Email -import xyz.driver.core.logging._ -import xyz.driver.core.rest._ -import xyz.driver.core.rest.auth._ -import xyz.driver.core.time.Time - -import scala.concurrent.Future -import scalaz.OptionT - -class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest { - - case object TestRoleAllowedPermission extends Permission - case object TestRoleAllowedByTokenPermission extends Permission - case object TestRoleNotAllowedPermission extends Permission - - val TestRole = Role(Id("1"), Name("testRole")) - - val (publicKey, privateKey) = { - import java.security.KeyPairGenerator - - val keygen = KeyPairGenerator.getInstance("RSA") - keygen.initialize(2048) - - val keyPair = keygen.generateKeyPair() - (keyPair.getPublic, keyPair.getPrivate) - } - - val basicAuthorization: Authorization[User] = new Authorization[User] { - - override def userHasPermissions(user: User, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap - Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) - } - } - - val tokenIssuer = "users" - val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) - - val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) - - val authStatusService = new AuthProvider[User](authorization, NoLogger) { - override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = - OptionT.optionT[Future] { - if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { - Future.successful( - Some( - AuthTokenUserInfo( - Id[User]("1"), - Email("foo", "bar"), - emailVerified = true, - audience = "driver", - roles = Set(TestRole), - expirationTime = Time(1000000L) - ))) - } else { - Future.successful(Option.empty[User]) - } - } - } - - import authStatusService._ - - "'authorize' directive" should "throw error if auth token is not in the request" in { - - Get("/naive/attempt") ~> - authorize(TestRoleAllowedPermission) { user => - complete("Never going to be here") - } ~> - check { - // handled shouldBe false - rejections should contain( - AuthenticationFailedRejection( - AuthenticationFailedRejection.CredentialsMissing, - HttpChallenges.oAuth2(authStatusService.realm))) - } - } - - it should "throw error if authorized user does not have the requested permission" in { - - val referenceAuthToken = AuthToken("I am a test role's token") - val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) - - Post("/administration/attempt").addHeader( - referenceAuthHeader - ) ~> - authorize(TestRoleNotAllowedPermission) { user => - complete("Never going to get here") - } ~> - check { - handled shouldBe false - rejections should contain(AuthorizationFailedRejection) - } - } - - it should "pass and retrieve the token to client code, if token is in request and user has permission" in { - val referenceAuthToken = AuthToken("I am token") - val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) - - Get("/valid/attempt/?a=2&b=5").addHeader( - referenceAuthHeader - ) ~> - authorize(TestRoleAllowedPermission) { ctx => - complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") - } ~> - check { - handled shouldBe true - responseAs[String] shouldBe "Alright, user 1 is authorized" - } - } - - it should "authenticate correctly even without the 'Bearer' prefix on the Authorization header" in { - val referenceAuthToken = AuthToken("unprefixed_token") - - Get("/valid/attempt/?a=2&b=5").addHeader( - RawHeader(ContextHeaders.AuthenticationTokenHeader, referenceAuthToken.value) - ) ~> - authorize(TestRoleAllowedPermission) { ctx => - complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") - } ~> - check { - handled shouldBe true - responseAs[String] shouldBe "Alright, user 1 is authorized" - } - } - - it should "authorize permission found in permissions token" in { - import spray.json._ - - val claim = JsObject( - Map( - "iss" -> JsString(tokenIssuer), - "sub" -> JsString("1"), - "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) - )).prettyPrint - val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) - val referenceAuthToken = AuthToken("I am token") - val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) - - Get("/alic/attempt/?a=2&b=5") - .addHeader(referenceAuthHeader) - .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> - authorize(TestRoleAllowedByTokenPermission) { ctx => - complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") - } ~> - check { - handled shouldBe true - responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" - } - } -} diff --git a/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/src/test/scala/xyz/driver/core/GeneratorsTest.scala deleted file mode 100644 index 7e740a4..0000000 --- a/src/test/scala/xyz/driver/core/GeneratorsTest.scala +++ /dev/null @@ -1,264 +0,0 @@ -package xyz.driver.core - -import org.scalatest.{Assertions, FlatSpec, Matchers} - -import scala.collection.immutable.IndexedSeq - -class GeneratorsTest extends FlatSpec with Matchers with Assertions { - import generators._ - - "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in { - - val generatedId1 = nextId[String]() - val generatedId2 = nextId[String]() - val generatedId3 = nextId[Long]() - - generatedId1.length should be >= 0 - generatedId2.length should be >= 0 - generatedId3.length should be >= 0 - generatedId1 should not be generatedId2 - generatedId2 should !==(generatedId3) - } - - it should "be able to generate com.drivergrp.core.Id identifiers with max value" in { - - val generatedLimitedId1 = nextId[String](5) - val generatedLimitedId2 = nextId[String](4) - val generatedLimitedId3 = nextId[Long](3) - - generatedLimitedId1.length should be >= 0 - generatedLimitedId1.length should be < 6 - generatedLimitedId2.length should be >= 0 - generatedLimitedId2.length should be < 5 - generatedLimitedId3.length should be >= 0 - generatedLimitedId3.length should be < 4 - generatedLimitedId1 should not be generatedLimitedId2 - generatedLimitedId2 should !==(generatedLimitedId3) - } - - it should "be able to generate com.drivergrp.core.Name names" in { - - Seq.fill(10)(nextName[String]()).distinct.size should be > 1 - nextName[String]().value.length should be >= 0 - - val fixedLengthName = nextName[String](10) - fixedLengthName.length should be <= 10 - assert(!fixedLengthName.value.exists(_.isControl)) - } - - it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in { - - assert(nextNonEmptyName[String]().value.value.nonEmpty) - } - - it should "be able to generate proper UUIDs" in { - - nextUuid() should not be nextUuid() - nextUuid().toString.length should be(36) - } - - it should "be able to generate new Revisions" in { - - nextRevision[String]() should not be nextRevision[String]() - nextRevision[String]().id.length should be > 0 - } - - it should "be able to generate strings" in { - - nextString() should not be nextString() - nextString().length should be >= 0 - - val fixedLengthString = nextString(20) - fixedLengthString.length should be <= 20 - assert(!fixedLengthString.exists(_.isControl)) - } - - it should "be able to generate strings non-empty strings whic are non empty" in { - - assert(nextNonEmptyString().value.nonEmpty) - } - - it should "be able to generate options which are sometimes have values and sometimes not" in { - - val generatedOption = nextOption("2") - - generatedOption should not contain "1" - assert(generatedOption === Some("2") || generatedOption === None) - } - - it should "be able to generate a pair of two generated values" in { - - val constantPair = nextPair("foo", 1L) - constantPair._1 should be("foo") - constantPair._2 should be(1L) - - val generatedPair = nextPair(nextId[Int](), nextName[Int]()) - - generatedPair._1.length should be > 0 - generatedPair._2.length should be > 0 - - nextPair(nextId[Int](), nextName[Int]()) should not be - nextPair(nextId[Int](), nextName[Int]()) - } - - it should "be able to generate a triad of two generated values" in { - - val constantTriad = nextTriad("foo", "bar", 1L) - constantTriad._1 should be("foo") - constantTriad._2 should be("bar") - constantTriad._3 should be(1L) - - val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) - - generatedTriad._1.length should be > 0 - generatedTriad._2.length should be > 0 - generatedTriad._3 should be >= BigDecimal(0.00) - - nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be - nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) - } - - it should "be able to generate a time value" in { - - val generatedTime = nextTime() - val currentTime = System.currentTimeMillis() - - generatedTime.millis should be >= 0L - generatedTime.millis should be <= currentTime - } - - it should "be able to generate a time range value" in { - - val generatedTimeRange = nextTimeRange() - val currentTime = System.currentTimeMillis() - - generatedTimeRange.start.millis should be >= 0L - generatedTimeRange.start.millis should be <= currentTime - generatedTimeRange.end.millis should be >= 0L - generatedTimeRange.end.millis should be <= currentTime - generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis - } - - it should "be able to generate a BigDecimal value" in { - - val defaultGeneratedBigDecimal = nextBigDecimal() - - defaultGeneratedBigDecimal should be >= BigDecimal(0.00) - defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00) - defaultGeneratedBigDecimal.precision should be(2) - - val unitIntervalBigDecimal = nextBigDecimal(1.00, 8) - - unitIntervalBigDecimal should be >= BigDecimal(0.00) - unitIntervalBigDecimal should be <= BigDecimal(1.00) - unitIntervalBigDecimal.precision should be(8) - } - - it should "be able to generate a specific value from a set of values" in { - - val possibleOptions = Set(1, 3, 5, 123, 0, 9) - - val pick1 = generators.oneOf(possibleOptions) - val pick2 = generators.oneOf(possibleOptions) - val pick3 = generators.oneOf(possibleOptions) - - possibleOptions should contain(pick1) - possibleOptions should contain(pick2) - possibleOptions should contain(pick3) - - val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9) - val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9) - val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9) - - possibleOptions should contain(pick4) - possibleOptions should contain(pick5) - possibleOptions should contain(pick6) - - Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1 - } - - it should "be able to generate a specific value from an enumeratum enum" in { - - import enumeratum._ - sealed trait TestEnumValue extends EnumEntry - object TestEnum extends Enum[TestEnumValue] { - case object Value1 extends TestEnumValue - case object Value2 extends TestEnumValue - case object Value3 extends TestEnumValue - case object Value4 extends TestEnumValue - val values: IndexedSeq[TestEnumValue] = findValues - } - - val picks = (1 to 100).map(_ => generators.oneOf(TestEnum)) - - TestEnum.values should contain allElementsOf picks - picks.toSet.size should be >= 1 - } - - it should "be able to generate array with values generated by generators" in { - - val arrayOfTimes = arrayOf(nextTime(), 16) - arrayOfTimes.length should be <= 16 - - val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8) - arrayOfBigDecimals.length should be <= 8 - } - - it should "be able to generate seq with values generated by generators" in { - - val seqOfTimes = seqOf(nextTime(), 16) - seqOfTimes.size should be <= 16 - - val seqOfBigDecimals = seqOf(nextBigDecimal(), 8) - seqOfBigDecimals.size should be <= 8 - } - - it should "be able to generate vector with values generated by generators" in { - - val vectorOfTimes = vectorOf(nextTime(), 16) - vectorOfTimes.size should be <= 16 - - val vectorOfStrings = seqOf(nextString(), 8) - vectorOfStrings.size should be <= 8 - } - - it should "be able to generate list with values generated by generators" in { - - val listOfTimes = listOf(nextTime(), 16) - listOfTimes.size should be <= 16 - - val listOfBigDecimals = seqOf(nextBigDecimal(), 8) - listOfBigDecimals.size should be <= 8 - } - - it should "be able to generate set with values generated by generators" in { - - val setOfTimes = vectorOf(nextTime(), 16) - setOfTimes.size should be <= 16 - - val setOfBigDecimals = seqOf(nextBigDecimal(), 8) - setOfBigDecimals.size should be <= 8 - } - - it should "be able to generate maps with keys and values generated by generators" in { - - val generatedConstantMap = mapOf("key", 123, 10) - generatedConstantMap.size should be <= 1 - assert(generatedConstantMap.keys.forall(_ == "key")) - assert(generatedConstantMap.values.forall(_ == 123)) - - val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10) - assert(generatedMap.keys.forall(_.length <= 10)) - assert(generatedMap.values.forall(_ >= BigDecimal(0.00))) - } - - it should "compose deeply" in { - - val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10) - - generatedNestedMap.size should be <= 10 - generatedNestedMap.keySet.size should be <= 10 - generatedNestedMap.values.size should be <= 10 - assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123))) - } -} diff --git a/src/test/scala/xyz/driver/core/JsonTest.scala b/src/test/scala/xyz/driver/core/JsonTest.scala deleted file mode 100644 index fd693f9..0000000 --- a/src/test/scala/xyz/driver/core/JsonTest.scala +++ /dev/null @@ -1,521 +0,0 @@ -package xyz.driver.core - -import java.net.InetAddress -import java.time.{Instant, LocalDate} - -import akka.http.scaladsl.model.Uri -import akka.http.scaladsl.server.PathMatcher -import akka.http.scaladsl.server.PathMatcher.Matched -import com.neovisionaries.i18n.{CountryCode, CurrencyCode} -import enumeratum._ -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.numeric.Positive -import eu.timepit.refined.refineMV -import org.scalatest.{Inspectors, Matchers, WordSpec} -import spray.json._ -import xyz.driver.core.TestTypes.CustomGADT -import xyz.driver.core.auth.AuthCredentials -import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.json._ -import xyz.driver.core.json.enumeratum.HasJsonFormat -import xyz.driver.core.tagging._ -import xyz.driver.core.time.provider.SystemTimeProvider -import xyz.driver.core.time.{Time, TimeOfDay} - -import scala.collection.immutable.IndexedSeq -import scala.language.postfixOps - -class JsonTest extends WordSpec with Matchers with Inspectors { - import DefaultJsonProtocol._ - - "Json format for Id" should { - "read and write correct JSON" in { - - val referenceId = Id[String]("1312-34A") - - val writtenJson = json.idFormat.write(referenceId) - writtenJson.prettyPrint should be("\"1312-34A\"") - - val parsedId = json.idFormat.read(writtenJson) - parsedId should be(referenceId) - } - } - - "Json format for @@" should { - "read and write correct JSON" in { - trait Irrelevant - val reference = Id[JsonTest]("SomeID").tagged[Irrelevant] - - val format = json.taggedFormat[Id[JsonTest], Irrelevant] - - val writtenJson = format.write(reference) - writtenJson shouldBe JsString("SomeID") - - val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson) - parsedId shouldBe reference - } - - "read and write correct JSON when there's an implicit conversion defined" in { - val input = " some string " - - JsString(input).convertTo[String @@ Trimmed] shouldBe input.trim() - - val trimmed: String @@ Trimmed = input - trimmed.toJson shouldBe JsString(trimmed) - } - } - - "Json format for Name" should { - "read and write correct JSON" in { - - val referenceName = Name[String]("Homer") - - val writtenJson = json.nameFormat.write(referenceName) - writtenJson.prettyPrint should be("\"Homer\"") - - val parsedName = json.nameFormat.read(writtenJson) - parsedName should be(referenceName) - } - - "read and write correct JSON for Name @@ Trimmed" in { - trait Irrelevant - JsString(" some name ").convertTo[Name[Irrelevant] @@ Trimmed] shouldBe Name[Irrelevant]("some name") - - val trimmed: Name[Irrelevant] @@ Trimmed = Name(" some name ") - trimmed.toJson shouldBe JsString("some name") - } - } - - "Json format for NonEmptyName" should { - "read and write correct JSON" in { - - val jsonFormat = json.nonEmptyNameFormat[String] - - val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer")) - - val writtenJson = jsonFormat.write(referenceNonEmptyName) - writtenJson.prettyPrint should be("\"Homer\"") - - val parsedNonEmptyName = jsonFormat.read(writtenJson) - parsedNonEmptyName should be(referenceNonEmptyName) - } - } - - "Json format for Time" should { - "read and write correct JSON" in { - - val referenceTime = new SystemTimeProvider().currentTime() - - val writtenJson = json.timeFormat.write(referenceTime) - writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}") - - val parsedTime = json.timeFormat.read(writtenJson) - parsedTime should be(referenceTime) - } - - "read from inputs compatible with Instant" in { - val referenceTime = new SystemTimeProvider().currentTime() - - val jsons = Seq(JsNumber(referenceTime.millis), JsString(Instant.ofEpochMilli(referenceTime.millis).toString)) - - forAll(jsons) { json => - json.convertTo[Time] shouldBe referenceTime - } - } - } - - "Json format for TimeOfDay" should { - "read and write correct JSON" in { - val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") - val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") - val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) - writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) - val parsed = json.timeOfDayFormat.read(writtenJson) - parsed should be(referenceTimeOfDay) - } - } - - "Json format for Date" should { - "read and write correct JSON" in { - import date._ - - val referenceDate = Date(1941, Month.DECEMBER, 7) - - val writtenJson = json.dateFormat.write(referenceDate) - writtenJson.prettyPrint should be("\"1941-12-07\"") - - val parsedDate = json.dateFormat.read(writtenJson) - parsedDate should be(referenceDate) - } - } - - "Json format for java.time.Instant" should { - - val isoString = "2018-08-08T08:08:08.888Z" - val instant = Instant.parse(isoString) - - "read correct JSON when value is an epoch milli number" in { - JsNumber(instant.toEpochMilli).convertTo[Instant] shouldBe instant - } - - "read correct JSON when value is an ISO timestamp string" in { - JsString(isoString).convertTo[Instant] shouldBe instant - } - - "read correct JSON when value is an object with nested 'timestamp'/millis field" in { - val json = JsObject( - "timestamp" -> JsNumber(instant.toEpochMilli) - ) - - json.convertTo[Instant] shouldBe instant - } - - "write correct JSON" in { - instant.toJson shouldBe JsString(isoString) - } - } - - "Path matcher for Instant" should { - - val isoString = "2018-08-08T08:08:08.888Z" - val instant = Instant.parse(isoString) - - val matcher = PathMatcher("foo") / InstantInPath / - - "read instant from millis" in { - matcher(Uri.Path("foo") / ("+" + instant.toEpochMilli) / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) - } - - "read instant from ISO timestamp string" in { - matcher(Uri.Path("foo") / isoString / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) - } - } - - "Json format for java.time.LocalDate" should { - - "read and write correct JSON" in { - val dateString = "2018-08-08" - val date = LocalDate.parse(dateString) - - date.toJson shouldBe JsString(dateString) - JsString(dateString).convertTo[LocalDate] shouldBe date - } - } - - "Json format for Revision" should { - "read and write correct JSON" in { - - val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4") - - val writtenJson = json.revisionFormat.write(referenceRevision) - writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"") - - val parsedRevision = json.revisionFormat.read(writtenJson) - parsedRevision should be(referenceRevision) - } - } - - "Json format for Email" should { - "read and write correct JSON" in { - - val referenceEmail = Email("test", "drivergrp.com") - - val writtenJson = json.emailFormat.write(referenceEmail) - writtenJson should be("\"test@drivergrp.com\"".parseJson) - - val parsedEmail = json.emailFormat.read(writtenJson) - parsedEmail should be(referenceEmail) - } - } - - "Json format for PhoneNumber" should { - "read and write correct JSON" in { - - val referencePhoneNumber = PhoneNumber("1", "4243039608") - - val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber) - writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson) - - val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson) - parsedPhoneNumber should be(referencePhoneNumber) - } - - "reject an invalid phone number" in { - val phoneJson = """{"countryCode":"1","number":"111-111-1113"}""".parseJson - - intercept[DeserializationException] { - json.phoneNumberFormat.read(phoneJson) - }.getMessage shouldBe "Invalid phone number" - } - - "parse phone number from string" in { - JsString("+14243039608").convertTo[PhoneNumber] shouldBe PhoneNumber("1", "4243039608") - } - } - - "Path matcher for PhoneNumber" should { - "read valid phone number" in { - val string = "+14243039608x23" - val phone = PhoneNumber("1", "4243039608", Some("23")) - - val matcher = PathMatcher("foo") / PhoneInPath - - matcher(Uri.Path("foo") / string / "bar") shouldBe Matched(Uri.Path./("bar"), Tuple1(phone)) - } - } - - "Json format for ADT mappings" should { - "read and write correct JSON" in { - - sealed trait EnumVal - case object Val1 extends EnumVal - case object Val2 extends EnumVal - case object Val3 extends EnumVal - - val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3) - - val referenceEnumValue1 = Val2 - val referenceEnumValue2 = Val3 - - val writtenJson1 = format.write(referenceEnumValue1) - writtenJson1.prettyPrint should be("\"b\"") - - val writtenJson2 = format.write(referenceEnumValue2) - writtenJson2.prettyPrint should be("\"c\"") - - val parsedEnumValue1 = format.read(writtenJson1) - val parsedEnumValue2 = format.read(writtenJson2) - - parsedEnumValue1 should be(referenceEnumValue1) - parsedEnumValue2 should be(referenceEnumValue2) - } - } - - "Json format for Enums (external)" should { - "read and write correct JSON" in { - - sealed trait MyEnum extends EnumEntry - object MyEnum extends Enum[MyEnum] { - case object Val1 extends MyEnum - case object `Val 2` extends MyEnum - case object `Val/3` extends MyEnum - - val values: IndexedSeq[MyEnum] = findValues - } - - val format = new enumeratum.EnumJsonFormat(MyEnum) - - val referenceEnumValue1 = MyEnum.`Val 2` - val referenceEnumValue2 = MyEnum.`Val/3` - - val writtenJson1 = format.write(referenceEnumValue1) - writtenJson1 shouldBe JsString("Val 2") - - val writtenJson2 = format.write(referenceEnumValue2) - writtenJson2 shouldBe JsString("Val/3") - - val parsedEnumValue1 = format.read(writtenJson1) - val parsedEnumValue2 = format.read(writtenJson2) - - parsedEnumValue1 shouldBe referenceEnumValue1 - parsedEnumValue2 shouldBe referenceEnumValue2 - - intercept[DeserializationException] { - format.read(JsString("Val4")) - }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" - } - } - - "Json format for Enums (automatic)" should { - "read and write correct JSON and not require import" in { - - sealed trait MyEnum extends EnumEntry - object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] { - case object Val1 extends MyEnum - case object `Val 2` extends MyEnum - case object `Val/3` extends MyEnum - - val values: IndexedSeq[MyEnum] = findValues - } - - val referenceEnumValue1: MyEnum = MyEnum.`Val 2` - val referenceEnumValue2: MyEnum = MyEnum.`Val/3` - - val writtenJson1 = referenceEnumValue1.toJson - writtenJson1 shouldBe JsString("Val 2") - - val writtenJson2 = referenceEnumValue2.toJson - writtenJson2 shouldBe JsString("Val/3") - - import spray.json._ - - val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum] - val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum] - - parsedEnumValue1 should be(referenceEnumValue1) - parsedEnumValue2 should be(referenceEnumValue2) - - intercept[DeserializationException] { - JsString("Val4").convertTo[MyEnum] - }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" - } - } - - // Should be defined outside of case to have a TypeTag - case class CustomWrapperClass(value: Int) - - "Json format for Value classes" should { - "read and write correct JSON" in { - - val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt)) - - val referenceValue1 = CustomWrapperClass(-2) - val referenceValue2 = CustomWrapperClass(10) - - val writtenJson1 = format.write(referenceValue1) - writtenJson1.prettyPrint should be("-2") - - val writtenJson2 = format.write(referenceValue2) - writtenJson2.prettyPrint should be("10") - - val parsedValue1 = format.read(writtenJson1) - val parsedValue2 = format.read(writtenJson2) - - parsedValue1 should be(referenceValue1) - parsedValue2 should be(referenceValue2) - } - } - - "Json format for classes GADT" should { - "read and write correct JSON" in { - - import CustomGADT._ - import DefaultJsonProtocol._ - implicit val case1Format = jsonFormat1(GadtCase1) - implicit val case2Format = jsonFormat1(GadtCase2) - implicit val case3Format = jsonFormat1(GadtCase3) - - val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") { - case _: CustomGADT.GadtCase1 => "case1" - case _: CustomGADT.GadtCase2 => "case2" - case _: CustomGADT.GadtCase3 => "case3" - } { - case "case1" => case1Format - case "case2" => case2Format - case "case3" => case3Format - } - - val referenceValue1 = CustomGADT.GadtCase1("4") - val referenceValue2 = CustomGADT.GadtCase2("Hi!") - - val writtenJson1 = format.write(referenceValue1) - writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson) - - val writtenJson2 = format.write(referenceValue2) - writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson) - - val parsedValue1 = format.read(writtenJson1) - val parsedValue2 = format.read(writtenJson2) - - parsedValue1 should be(referenceValue1) - parsedValue2 should be(referenceValue2) - } - } - - "Json format for a Refined value" should { - "read and write correct JSON" in { - - val jsonFormat = json.refinedJsonFormat[Int, Positive] - - val referenceRefinedNumber = refineMV[Positive](42) - - val writtenJson = jsonFormat.write(referenceRefinedNumber) - writtenJson should be("42".parseJson) - - val parsedRefinedNumber = jsonFormat.read(writtenJson) - parsedRefinedNumber should be(referenceRefinedNumber) - } - } - - "InetAddress format" should { - "read and write correct JSON" in { - val address = InetAddress.getByName("127.0.0.1") - val json = inetAddressFormat.write(address) - - json shouldBe JsString("127.0.0.1") - - val parsed = inetAddressFormat.read(json) - parsed shouldBe address - } - - "throw a DeserializationException for an invalid IP Address" in { - assertThrows[DeserializationException] { - val invalidAddress = JsString("foobar:") - inetAddressFormat.read(invalidAddress) - } - } - } - - "AuthCredentials format" should { - "read and write correct JSON" in { - val email = Email("someone", "noehere.com") - val phoneId = PhoneNumber.parse("1 207 8675309") - val password = "nopassword" - - phoneId.isDefined should be(true) // test this real quick - - val emailAuth = AuthCredentials(email.toString, password) - val pnAuth = AuthCredentials(phoneId.get.toString, password) - - val emailWritten = authCredentialsFormat.write(emailAuth) - emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson) - - val phoneWritten = authCredentialsFormat.write(pnAuth) - phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson) - - val identifierEmailParsed = - authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson) - var written = authCredentialsFormat.write(identifierEmailParsed) - written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) - - val emailEmailParsed = - authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson) - written = authCredentialsFormat.write(emailEmailParsed) - written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) - - } - } - - "CountryCode format" should { - "read and write correct JSON" in { - val samples = Seq( - "US" -> CountryCode.US, - "CN" -> CountryCode.CN, - "AT" -> CountryCode.AT - ) - - forAll(samples) { - case (serialized, enumValue) => - countryCodeFormat.write(enumValue) shouldBe JsString(serialized) - countryCodeFormat.read(JsString(serialized)) shouldBe enumValue - } - } - } - - "CurrencyCode format" should { - "read and write correct JSON" in { - val samples = Seq( - "USD" -> CurrencyCode.USD, - "CNY" -> CurrencyCode.CNY, - "EUR" -> CurrencyCode.EUR - ) - - forAll(samples) { - case (serialized, enumValue) => - currencyCodeFormat.write(enumValue) shouldBe JsString(serialized) - currencyCodeFormat.read(JsString(serialized)) shouldBe enumValue - } - } - } - -} diff --git a/src/test/scala/xyz/driver/core/TestTypes.scala b/src/test/scala/xyz/driver/core/TestTypes.scala deleted file mode 100644 index bb25deb..0000000 --- a/src/test/scala/xyz/driver/core/TestTypes.scala +++ /dev/null @@ -1,14 +0,0 @@ -package xyz.driver.core - -object TestTypes { - - sealed trait CustomGADT { - val field: String - } - - object CustomGADT { - final case class GadtCase1(field: String) extends CustomGADT - final case class GadtCase2(field: String) extends CustomGADT - final case class GadtCase3(field: String) extends CustomGADT - } -} diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala deleted file mode 100644 index 324c8d8..0000000 --- a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala +++ /dev/null @@ -1,89 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{HttpMethod, StatusCodes} -import akka.http.scaladsl.server.{Directives, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import com.typesafe.config.ConfigFactory -import org.scalatest.{AsyncFlatSpec, Matchers} -import xyz.driver.core.app.{DriverApp, SimpleModule} - -class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { - val config = ConfigFactory.parseString(""" - |application { - | cors { - | allowedOrigins: ["example.com"] - | } - |} - """.stripMargin).withFallback(ConfigFactory.load) - - val origin = Origin(HttpOrigin("https", Host("example.com"))) - val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) - val allowedMethods: collection.immutable.Seq[HttpMethod] = { - import akka.http.scaladsl.model.HttpMethods._ - collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE) - } - - import scala.reflect.runtime.universe.typeOf - class TestApp(testRoute: Route) - extends DriverApp( - appName = "test-app", - version = "0.0.1", - gitHash = "deadb33f", - modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), - config = config, - log = xyz.driver.core.logging.NoLogger - ) - - it should "respond with the correct CORS headers for the swagger OPTIONS route" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods - } - } - - it should "respond with the correct CORS headers for the test route" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - } - } - - it should "respond with the correct CORS headers for a concatenated route" in { - val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) - Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - } - } - - it should "allow subdomains of allowed origin suffixes" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") - .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) - } - } - - it should "respond with default domains for invalid origins" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") - .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) - } - } - - it should "respond with Pragma and Cache-Control (no-cache) headers" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - header("Pragma").map(_.value()) should contain("no-cache") - header[`Cache-Control`].map(_.value()) should contain("no-cache") - } - } -} diff --git a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala deleted file mode 100644 index cc0019a..0000000 --- a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala +++ /dev/null @@ -1,121 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes -import akka.http.scaladsl.model.headers.Connection -import akka.http.scaladsl.server.Directives.{complete => akkaComplete} -import akka.http.scaladsl.server.{Directives, RejectionHandler, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import com.typesafe.scalalogging.Logger -import org.scalatest.{AsyncFlatSpec, Matchers} -import xyz.driver.core.json.serviceExceptionFormat -import xyz.driver.core.logging.NoLogger -import xyz.driver.core.rest.errors._ - -import scala.concurrent.Future - -class DriverRouteTest - extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives { - class TestRoute(override val route: Route) extends DriverRoute { - override def log: Logger = NoLogger - } - - "DriverRoute" should "respond with 200 OK for a basic route" in { - val route = new TestRoute(akkaComplete(StatusCodes.OK)) - - Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.OK - } - } - - it should "respond with a 401 for an InvalidInputException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - responseAs[ServiceException] shouldBe InvalidInputException() - } - } - - it should "respond with a 403 for InvalidActionException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.Forbidden - responseAs[ServiceException] shouldBe InvalidActionException() - } - } - - it should "respond with a 404 for ResourceNotFoundException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.NotFound - responseAs[ServiceException] shouldBe ResourceNotFoundException() - } - } - - it should "respond with a 500 for ExternalServiceException" in { - val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None) - val route = new TestRoute(akkaComplete(Future.failed[String](error))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.InternalServerError - responseAs[ServiceException] shouldBe error - } - } - - it should "allow pass-through of external service exceptions" in { - val innerError = InvalidInputException() - val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError)) - val future = Future.failed[String](error) - val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException)) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - responseAs[ServiceException] shouldBe innerError - } - } - - it should "respond with a 503 for ExternalServiceTimeoutException" in { - val error = ExternalServiceTimeoutException("GET /api/v1/users/") - val route = new TestRoute(akkaComplete(Future.failed[String](error))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.GatewayTimeout - responseAs[ServiceException] shouldBe error - } - } - - it should "respond with a 500 for DatabaseException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.InternalServerError - responseAs[ServiceException] shouldBe DatabaseException() - } - } - - it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in { - val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result() - val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK")))) - - Get("/foo") ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.OK - headers should contain(Connection("close")) - } - - Get("/bar") ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.NotFound - headers should contain(Connection("close")) - } - } -} diff --git a/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala deleted file mode 100644 index 987717d..0000000 --- a/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala +++ /dev/null @@ -1,101 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.`Content-Type` -import akka.http.scaladsl.server.{Directives, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import org.scalatest.{FlatSpec, Matchers} -import spray.json._ -import xyz.driver.core.{Id, Name} -import xyz.driver.core.json._ - -import scala.concurrent.Future - -class PatchDirectivesTest - extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol - with Directives with PatchDirectives { - case class Bar(name: Name[Bar], size: Int) - case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) - implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) - implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) - - val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) - - def route(retrieve: => Future[Option[Foo]]): Route = - Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => - entity(as[Patchable[Foo]]) { fooPatchable => - mergePatch(fooPatchable, retrieve) { updatedFoo => - complete(updatedFoo) - } - } - }) - - val MergePatchContentType = ContentType(`application/merge-patch+json`) - val ContentTypeHeader = `Content-Type`(MergePatchContentType) - def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = - HttpEntity(contentType, json) - - "PatchSupport" should "allow partial updates to an existing object" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4) - } - } - - it should "merge deeply nested objects" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) - } - } - - it should "return a 404 if the object is not found" in { - val fooRetrieve = Future.successful(None) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - status shouldBe StatusCodes.NotFound - } - } - - it should "handle nulls on optional values correctly" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(bar = None) - } - } - - it should "handle optional values correctly when old value is null" in { - val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None))) - - Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10))) - } - } - - it should "return a 400 for nulls on non-optional values" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - } - } - - it should "return a 415 for incorrect Content-Type" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { - status shouldBe StatusCodes.UnsupportedMediaType - responseAs[String] should include("application/merge-patch+json") - } - } -} diff --git a/src/test/scala/xyz/driver/core/rest/RestTest.scala b/src/test/scala/xyz/driver/core/rest/RestTest.scala deleted file mode 100644 index 19e4ed1..0000000 --- a/src/test/scala/xyz/driver/core/rest/RestTest.scala +++ /dev/null @@ -1,151 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.StatusCodes -import akka.http.scaladsl.server.{Directives, Route, ValidationRejection} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import akka.util.ByteString -import org.scalatest.{Matchers, WordSpec} -import xyz.driver.core.rest - -import scala.concurrent.Future -import scala.util.Random - -class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives { - "`escapeScriptTags` function" should { - "escape script tags properly" in { - val dirtyString = " - complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}") - } - "accept a pagination" in { - Get("/?pageNumber=2&pageSize=42") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "2,42") - } - } - "provide a default pagination" in { - Get("/") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "1,100") - } - } - "provide default values for a partial pagination" in { - Get("/?pageSize=2") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "1,2") - } - } - "reject an invalid pagination" in { - Get("/?pageNumber=-1") ~> route ~> check { - assert(rejection.isInstanceOf[ValidationRejection]) - } - } - } - - "optional paginated directive" should { - val route: Route = rest.optionalPagination { paginated => - complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination")) - } - "accept a pagination" in { - Get("/?pageNumber=2&pageSize=42") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "2,42") - } - } - "without pagination" in { - Get("/") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "no pagination") - } - } - "reject an invalid pagination" in { - Get("/?pageNumber=1") ~> route ~> check { - assert(rejection.isInstanceOf[ValidationRejection]) - } - } - } - - "completeWithPagination directive" when { - import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ - import spray.json.DefaultJsonProtocol._ - - val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString) - val route: Route = - parameter('empty.as[Boolean] ? false) { isEmpty => - completeWithPagination[String] { - case Some(pagination) if isEmpty => - Future.successful(ListResponse(Seq(), 0, Some(pagination))) - case Some(pagination) => - val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize) - Future.successful(ListResponse(filtered, data.size, Some(pagination))) - case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None)) - case None => Future.successful(ListResponse(data, data.size, None)) - } - } - - "pagination is defined" should { - "return a response with pagination headers" in { - Get("/?pageNumber=2&pageSize=10") ~> route ~> check { - responseAs[Seq[String]] shouldBe data.slice(10, 20) - header(ContextHeaders.ResourceCount).map(_.value) should contain("103") - header(ContextHeaders.PageCount).map(_.value) should contain("11") - } - } - - "disallow pageSize <= 0" in { - Get("/?pageNumber=2&pageSize=0") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - - Get("/?pageNumber=2&pageSize=-1") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - } - - "disallow pageNumber <= 0" in { - Get("/?pageNumber=0&pageSize=10") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - - Get("/?pageNumber=-1&pageSize=10") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - } - - "return PageCount == 0 if returning an empty list" in { - Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check { - responseAs[Seq[String]] shouldBe empty - header(ContextHeaders.ResourceCount).map(_.value) should contain("0") - header(ContextHeaders.PageCount).map(_.value) should contain("0") - } - } - } - - "pagination is not defined" should { - "return a response with pagination headers and PageCount == 1" in { - Get("/") ~> route ~> check { - responseAs[Seq[String]] shouldBe data - header(ContextHeaders.ResourceCount).map(_.value) should contain("103") - header(ContextHeaders.PageCount).map(_.value) should contain("1") - } - } - - "return PageCount == 0 if returning an empty list" in { - Get("/?empty=true") ~> route ~> check { - responseAs[Seq[String]] shouldBe empty - header(ContextHeaders.ResourceCount).map(_.value) should contain("0") - header(ContextHeaders.PageCount).map(_.value) should contain("0") - } - } - } - } -} -- cgit v1.2.3