From 4d1197099ce4e721c18bf4cacbb2e1980e4210b5 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 12 Sep 2018 16:40:57 -0700 Subject: Move REST functionality to separate project --- .../src/main/scala/xyz/driver/core/auth.scala | 43 ++ .../main/scala/xyz/driver/core/generators.scala | 143 ++++++ .../src/main/scala/xyz/driver/core/json.scala | 398 ++++++++++++++++ .../scala/xyz/driver/core/rest/DnsDiscovery.scala | 11 + .../scala/xyz/driver/core/rest/DriverRoute.scala | 122 +++++ .../core/rest/HttpRestServiceTransport.scala | 103 ++++ .../xyz/driver/core/rest/PatchDirectives.scala | 104 ++++ .../xyz/driver/core/rest/PooledHttpClient.scala | 67 +++ .../scala/xyz/driver/core/rest/ProxyRoute.scala | 26 + .../scala/xyz/driver/core/rest/RestService.scala | 86 ++++ .../xyz/driver/core/rest/ServiceDescriptor.scala | 16 + .../driver/core/rest/SingleRequestHttpClient.scala | 29 ++ .../main/scala/xyz/driver/core/rest/Swagger.scala | 144 ++++++ .../core/rest/auth/AlwaysAllowAuthorization.scala | 14 + .../xyz/driver/core/rest/auth/AuthProvider.scala | 75 +++ .../xyz/driver/core/rest/auth/Authorization.scala | 11 + .../core/rest/auth/AuthorizationResult.scala | 22 + .../core/rest/auth/CachedTokenAuthorization.scala | 55 +++ .../core/rest/auth/ChainedAuthorization.scala | 27 ++ .../core/rest/directives/AuthDirectives.scala | 19 + .../core/rest/directives/CorsDirectives.scala | 72 +++ .../driver/core/rest/directives/Directives.scala | 6 + .../driver/core/rest/directives/PathMatchers.scala | 85 ++++ .../core/rest/directives/Unmarshallers.scala | 40 ++ .../driver/core/rest/errors/serviceException.scala | 27 ++ .../xyz/driver/core/rest/headers/Traceparent.scala | 33 ++ .../main/scala/xyz/driver/core/rest/package.scala | 323 +++++++++++++ .../xyz/driver/core/rest/serviceDiscovery.scala | 24 + .../driver/core/rest/serviceRequestContext.scala | 87 ++++ .../src/test/scala/xyz/driver/core/AuthTest.scala | 165 +++++++ .../scala/xyz/driver/core/GeneratorsTest.scala | 264 +++++++++++ .../src/test/scala/xyz/driver/core/JsonTest.scala | 521 +++++++++++++++++++++ .../src/test/scala/xyz/driver/core/TestTypes.scala | 14 + .../scala/xyz/driver/core/rest/DriverAppTest.scala | 89 ++++ .../xyz/driver/core/rest/DriverRouteTest.scala | 121 +++++ .../xyz/driver/core/rest/PatchDirectivesTest.scala | 101 ++++ .../test/scala/xyz/driver/core/rest/RestTest.scala | 151 ++++++ 37 files changed, 3638 insertions(+) create mode 100644 core-rest/src/main/scala/xyz/driver/core/auth.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/generators.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/json.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/package.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala create mode 100644 core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/AuthTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/JsonTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/TestTypes.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala create mode 100644 core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala (limited to 'core-rest') diff --git a/core-rest/src/main/scala/xyz/driver/core/auth.scala b/core-rest/src/main/scala/xyz/driver/core/auth.scala new file mode 100644 index 0000000..896bd89 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/auth.scala @@ -0,0 +1,43 @@ +package xyz.driver.core + +import xyz.driver.core.domain.Email +import xyz.driver.core.time.Time +import scalaz.Equal + +object auth { + + trait Permission + + final case class Role(id: Id[Role], name: Name[Role]) { + + def oneOf(roles: Role*): Boolean = roles.contains(this) + + def oneOf(roles: Set[Role]): Boolean = roles.contains(this) + } + + object Role { + implicit def idEqual: Equal[Role] = Equal.equal[Role](_ == _) + } + + trait User { + def id: Id[User] + } + + final case class AuthToken(value: String) + + final case class AuthTokenUserInfo( + id: Id[User], + email: Email, + emailVerified: Boolean, + audience: String, + roles: Set[Role], + expirationTime: Time) + extends User + + final case class RefreshToken(value: String) + final case class PermissionsToken(value: String) + + final case class PasswordHash(value: String) + + final case class AuthCredentials(identifier: String, password: String) +} diff --git a/core-rest/src/main/scala/xyz/driver/core/generators.scala b/core-rest/src/main/scala/xyz/driver/core/generators.scala new file mode 100644 index 0000000..d00b6dd --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/generators.scala @@ -0,0 +1,143 @@ +package xyz.driver.core + +import enumeratum._ +import java.math.MathContext +import java.time.{Instant, LocalDate, ZoneOffset} +import java.util.UUID + +import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} +import xyz.driver.core.date.{Date, DayOfWeek} + +import scala.reflect.ClassTag +import scala.util.Random +import eu.timepit.refined.refineV +import eu.timepit.refined.api.Refined +import eu.timepit.refined.collection._ + +object generators { + + private val random = new Random + import random._ + private val secureRandom = new java.security.SecureRandom() + + private val DefaultMaxLength = 10 + private val StringLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ".toSet + private val NonAmbigiousCharacters = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789" + private val Numbers = "0123456789" + + private def nextTokenString(length: Int, chars: IndexedSeq[Char]): String = { + val builder = new StringBuilder + for (_ <- 0 until length) { + builder += chars(secureRandom.nextInt(chars.length)) + } + builder.result() + } + + /** Creates a random invitation token. + * + * This token is meant fo human input and avoids using ambiguous characters such as 'O' and '0'. It + * therefore contains less entropy and is not meant to be used as a cryptographic secret. */ + @deprecated( + "The term 'token' is too generic and security and readability conventions are not well defined. " + + "Services should implement their own version that suits their security requirements.", + "1.11.0" + ) + def nextToken(length: Int): String = nextTokenString(length, NonAmbigiousCharacters) + + @deprecated( + "The term 'token' is too generic and security and readability conventions are not well defined. " + + "Services should implement their own version that suits their security requirements.", + "1.11.0" + ) + def nextNumericToken(length: Int): String = nextTokenString(length, Numbers) + + def nextInt(maxValue: Int, minValue: Int = 0): Int = random.nextInt(maxValue - minValue) + minValue + + def nextBoolean(): Boolean = random.nextBoolean() + + def nextDouble(): Double = random.nextDouble() + + def nextId[T](): Id[T] = Id[T](nextUuid().toString) + + def nextId[T](maxLength: Int): Id[T] = Id[T](nextString(maxLength)) + + def nextNumericId[T](): Id[T] = Id[T](nextLong.abs.toString) + + def nextNumericId[T](maxValue: Int): Id[T] = Id[T](nextInt(maxValue).toString) + + def nextName[T](maxLength: Int = DefaultMaxLength): Name[T] = Name[T](nextString(maxLength)) + + def nextNonEmptyName[T](maxLength: Int = DefaultMaxLength): NonEmptyName[T] = + NonEmptyName[T](nextNonEmptyString(maxLength)) + + def nextUuid(): UUID = java.util.UUID.randomUUID + + def nextRevision[T](): Revision[T] = Revision[T](nextUuid().toString) + + def nextString(maxLength: Int = DefaultMaxLength): String = + (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString + + def nextNonEmptyString(maxLength: Int = DefaultMaxLength): String Refined NonEmpty = { + refineV[NonEmpty]( + (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString + ).right.get + } + + def nextOption[T](value: => T): Option[T] = if (nextBoolean()) Option(value) else None + + def nextPair[L, R](left: => L, right: => R): (L, R) = (left, right) + + def nextTriad[F, S, T](first: => F, second: => S, third: => T): (F, S, T) = (first, second, third) + + def nextInstant(): Instant = Instant.ofEpochMilli(math.abs(nextLong() % System.currentTimeMillis)) + + def nextTime(): Time = nextInstant() + + def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) + + def nextTimeRange(): TimeRange = { + val oneTime = nextTime() + val anotherTime = nextTime() + + TimeRange( + Time(scala.math.min(oneTime.millis, anotherTime.millis)), + Time(scala.math.max(oneTime.millis, anotherTime.millis))) + } + + def nextDate(): Date = nextTime().toDate(java.util.TimeZone.getTimeZone("UTC")) + + def nextLocalDate(): LocalDate = nextInstant().atZone(ZoneOffset.UTC).toLocalDate + + def nextDayOfWeek(): DayOfWeek = oneOf(DayOfWeek.All) + + def nextBigDecimal(multiplier: Double = 1000000.00, precision: Int = 2): BigDecimal = + BigDecimal(multiplier * nextDouble, new MathContext(precision)) + + def oneOf[T](items: T*): T = oneOf(items.toSet) + + def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) + + def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) + + def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = + Array.fill(nextInt(maxLength, minLength))(generator) + + def seqOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Seq[T] = + Seq.fill(nextInt(maxLength, minLength))(generator) + + def vectorOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Vector[T] = + Vector.fill(nextInt(maxLength, minLength))(generator) + + def listOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): List[T] = + List.fill(nextInt(maxLength, minLength))(generator) + + def setOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Set[T] = + seqOf(generator, maxLength, minLength).toSet + + def mapOf[K, V]( + keyGenerator: => K, + valueGenerator: => V, + maxLength: Int = DefaultMaxLength, + minLength: Int = 0): Map[K, V] = + seqOf(nextPair(keyGenerator, valueGenerator), maxLength, minLength).toMap +} diff --git a/core-rest/src/main/scala/xyz/driver/core/json.scala b/core-rest/src/main/scala/xyz/driver/core/json.scala new file mode 100644 index 0000000..edc2347 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/json.scala @@ -0,0 +1,398 @@ +package xyz.driver.core + +import java.net.InetAddress +import java.time.format.DateTimeFormatter +import java.time.{Instant, LocalDate} +import java.util.{TimeZone, UUID} + +import akka.http.scaladsl.unmarshalling.Unmarshaller +import com.neovisionaries.i18n.{CountryCode, CurrencyCode} +import enumeratum._ +import eu.timepit.refined.api.{Refined, Validate} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import spray.json._ +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.date.{Date, DayOfWeek, Month} +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.rest.directives.{PathMatchers, Unmarshallers} +import xyz.driver.core.rest.errors._ +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.reflect.runtime.universe._ +import scala.reflect.{ClassTag, classTag} +import scala.util.Try +import scala.util.control.NonFatal + +object json extends PathMatchers with Unmarshallers { + import DefaultJsonProtocol._ + + implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { + def write(id: Id[T]) = JsString(id.value) + + def read(value: JsValue): Id[T] = value match { + case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase) + case JsString(id) => Id[T](id) + case _ => throw DeserializationException("Id expects string") + } + } + + implicit def taggedFormat[F, T](implicit underlying: JsonFormat[F], convert: F => F @@ T = null): JsonFormat[F @@ T] = + new JsonFormat[F @@ T] { + import tagging._ + + private val transformReadValue = Option(convert).getOrElse((_: F).tagged[T]) + + override def write(obj: F @@ T): JsValue = underlying.write(obj) + + override def read(json: JsValue): F @@ T = transformReadValue(underlying.read(json)) + } + + implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { + def write(name: Name[T]) = JsString(name.value) + + def read(value: JsValue): Name[T] = value match { + case JsString(name) => Name[T](name) + case _ => throw DeserializationException("Name expects string") + } + } + + implicit val timeFormat: RootJsonFormat[Time] = new RootJsonFormat[Time] { + def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) + + def read(value: JsValue): Time = Time(instantFormat.read(value)) + } + + implicit val instantFormat: JsonFormat[Instant] = new JsonFormat[Instant] { + def write(instant: Instant): JsValue = JsString(instant.toString) + + def read(value: JsValue): Instant = value match { + case JsObject(fields) => + fields + .get("timestamp") + .flatMap { + case JsNumber(millis) => Some(Instant.ofEpochMilli(millis.longValue())) + case _ => None + } + .getOrElse(deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}")) + case JsNumber(millis) => Instant.ofEpochMilli(millis.longValue()) + case JsString(str) => + try Instant.parse(str) + catch { case NonFatal(_) => deserializationError(s"Instant expects ISO timestamp but got $str") } + case _ => deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}") + } + } + + implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { + private val formatter = TimeOfDay.getFormatter + def read(json: JsValue): java.time.LocalTime = json match { + case JsString(chars) => + java.time.LocalTime.parse(chars) + case _ => deserializationError(s"Expected time string got ${json.toString}") + } + + def write(obj: java.time.LocalTime): JsValue = { + JsString(obj.format(formatter)) + } + } + + implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { + override def write(obj: TimeZone): JsValue = { + JsString(obj.getID()) + } + + override def read(json: JsValue): TimeZone = json match { + case JsString(chars) => + java.util.TimeZone.getTimeZone(chars) + case _ => deserializationError(s"Expected time zone string got ${json.toString}") + } + } + + implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) + + implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) + + implicit val dateFormat = new RootJsonFormat[Date] { + def write(date: Date) = JsString(date.toString) + def read(value: JsValue): Date = value match { + case JsString(dateString) => + Date + .fromString(dateString) + .getOrElse( + throw DeserializationException(s"Misformated ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.")) + case _ => throw DeserializationException(s"Date expects a string, but got $value.") + } + } + + implicit val localDateFormat = new RootJsonFormat[LocalDate] { + val format = DateTimeFormatter.ISO_LOCAL_DATE + + def write(date: LocalDate): JsValue = JsString(date.format(format)) + def read(value: JsValue): LocalDate = value match { + case JsString(dateString) => + try LocalDate.parse(dateString, format) + catch { + case NonFatal(_) => + throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.") + } + case _ => + throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got ${value.compactPrint}.") + } + } + + implicit val monthFormat = new RootJsonFormat[Month] { + def write(month: Month) = JsNumber(month) + def read(value: JsValue): Month = value match { + case JsNumber(month) if 0 <= month && month <= 11 => Month(month.toInt) + case _ => throw DeserializationException("Expected a number from 0 to 11") + } + } + + implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { + def write(revision: Revision[T]) = JsString(revision.id.toString) + + def read(value: JsValue): Revision[T] = value match { + case JsString(revision) => Revision[T](revision) + case _ => throw DeserializationException("Revision expects uuid string") + } + } + + implicit val base64Format = new RootJsonFormat[Base64] { + def write(base64Value: Base64) = JsString(base64Value.value) + + def read(value: JsValue): Base64 = value match { + case JsString(base64Value) => Base64(base64Value) + case _ => throw DeserializationException("Base64 format expects string") + } + } + + implicit val emailFormat = new RootJsonFormat[Email] { + def write(email: Email) = JsString(email.username + "@" + email.domain) + def read(json: JsValue): Email = json match { + + case JsString(value) => + Email.parse(value).getOrElse { + deserializationError("Expected '@' symbol in email string as Email, but got " + json.toString) + } + + case _ => + deserializationError("Expected string as Email, but got " + json.toString) + } + } + + implicit object phoneNumberFormat extends RootJsonFormat[PhoneNumber] { + + private val basicFormat = jsonFormat3(PhoneNumber.apply) + + def write(obj: PhoneNumber): JsValue = basicFormat.write(obj) + + def read(json: JsValue): PhoneNumber = { + val maybePhone = json match { + case JsString(number) => PhoneNumber.parse(number) + case obj: JsObject => PhoneNumber.parse(basicFormat.read(obj).toString) + case _ => None + } + maybePhone.getOrElse(deserializationError("Invalid phone number")) + } + } + + implicit val authCredentialsFormat = new RootJsonFormat[AuthCredentials] { + override def read(json: JsValue): AuthCredentials = { + json match { + case JsObject(fields) => + val emailField = fields.get("email") + val identifierField = fields.get("identifier") + val passwordField = fields.get("password") + + (emailField, identifierField, passwordField) match { + case (_, _, None) => + deserializationError("password field must be set") + case (Some(JsString(em)), _, Some(JsString(pw))) => + val email = Email.parse(em).getOrElse(throw deserializationError(s"failed to parse email $em")) + AuthCredentials(email.toString, pw) + case (_, Some(JsString(id)), Some(JsString(pw))) => AuthCredentials(id.toString, pw.toString) + case (None, None, _) => deserializationError("identifier must be provided") + case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") + } + case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") + } + } + + override def write(obj: AuthCredentials): JsValue = JsObject( + "identifier" -> JsString(obj.identifier), + "password" -> JsString(obj.password) + ) + } + + implicit object inetAddressFormat extends JsonFormat[InetAddress] { + override def read(json: JsValue): InetAddress = json match { + case JsString(ipString) => + Try(InetAddress.getByName(ipString)) + .getOrElse(deserializationError(s"Invalid IP Address: $ipString")) + case _ => deserializationError(s"Expected string for IP Address, got $json") + } + + override def write(obj: InetAddress): JsValue = + JsString(obj.getHostAddress) + } + + implicit val countryCodeFormat: JsonFormat[CountryCode] = javaEnumFormat[CountryCode] + + implicit val currencyCodeFormat: JsonFormat[CurrencyCode] = javaEnumFormat[CurrencyCode] + + object enumeratum { + + def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + + trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => + + implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) + + implicit val unmarshaller: Unmarshaller[String, T] = + Unmarshaller.strict { value => + enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) + } + } + + class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { + override def read(json: JsValue): T = json match { + case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + + override def write(obj: T): JsValue = JsString(obj.entryName) + } + + private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = + deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") + } + + class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { + private val map = mapping.toMap + + override def write(value: T): JsValue = { + map.find(_._2 == value).map(_._1) match { + case Some(name) => JsString(name) + case _ => serializationError(s"Value $value is not found in the mapping $map") + } + } + + override def read(json: JsValue): T = json match { + case JsString(name) => + map.getOrElse(name, throw DeserializationException(s"Value $name is not found in the mapping $map")) + case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) + } + } + + def javaEnumFormat[T <: java.lang.Enum[_]: ClassTag]: JsonFormat[T] = { + val values = classTag[T].runtimeClass.asInstanceOf[Class[T]].getEnumConstants + new EnumJsonFormat[T](values.map(v => v.name() -> v): _*) + } + + class ValueClassFormat[T: TypeTag](writeValue: T => BigDecimal, create: BigDecimal => T) extends JsonFormat[T] { + def write(valueClass: T) = JsNumber(writeValue(valueClass)) + def read(json: JsValue): T = json match { + case JsNumber(value) => create(value) + case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json.toString) + } + } + + class GadtJsonFormat[T: TypeTag]( + typeField: String, + typeValue: PartialFunction[T, String], + jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) + extends RootJsonFormat[T] { + + def write(value: T): JsValue = { + + val valueType = typeValue.applyOrElse(value, { v: T => + deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v.toString) + }) + + val valueFormat = + jsonFormat.applyOrElse(valueType, { f: String => + deserializationError(s"No Json format for this type of $valueType") + }) + + valueFormat.asInstanceOf[JsonFormat[T]].write(value) match { + case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType))) + case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object") + } + } + + def read(json: JsValue): T = json match { + case JsObject(fields) => + val valueJson = JsObject(fields.filterNot(_._1 == typeField)) + fields(typeField) match { + case JsString(valueType) => + val valueFormat = jsonFormat.applyOrElse(valueType, { t: String => + deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") + }) + valueFormat.read(valueJson) + case _ => + deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") + } + case _ => + deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json.toString) + } + } + + object GadtJsonFormat { + + def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])( + jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = { + + new GadtJsonFormat[T](typeField, typeValue, jsonFormat) + } + } + + /** + * Provides the JsonFormat for the Refined types provided by the Refined library. + * + * @see https://github.com/fthomas/refined + */ + implicit def refinedJsonFormat[T, Predicate]( + implicit valueFormat: JsonFormat[T], + validate: Validate[T, Predicate]): JsonFormat[Refined[T, Predicate]] = + new JsonFormat[Refined[T, Predicate]] { + def write(x: T Refined Predicate): JsValue = valueFormat.write(x.value) + def read(value: JsValue): T Refined Predicate = { + refineV[Predicate](valueFormat.read(value))(validate) match { + case Right(refinedValue) => refinedValue + case Left(refinementError) => deserializationError(refinementError) + } + } + } + + implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = + new RootJsonFormat[NonEmptyName[T]] { + def write(name: NonEmptyName[T]) = JsString(name.value.value) + + def read(value: JsValue): NonEmptyName[T] = + NonEmptyName[T](nonEmptyStringFormat.read(value)) + } + + implicit val serviceExceptionFormat: RootJsonFormat[ServiceException] = + GadtJsonFormat.create[ServiceException]("type") { + case _: InvalidInputException => "InvalidInputException" + case _: InvalidActionException => "InvalidActionException" + case _: UnauthorizedException => "UnauthorizedException" + case _: ResourceNotFoundException => "ResourceNotFoundException" + case _: ExternalServiceException => "ExternalServiceException" + case _: ExternalServiceTimeoutException => "ExternalServiceTimeoutException" + case _: DatabaseException => "DatabaseException" + } { + case "InvalidInputException" => jsonFormat(InvalidInputException, "message") + case "InvalidActionException" => jsonFormat(InvalidActionException, "message") + case "UnauthorizedException" => jsonFormat(UnauthorizedException, "message") + case "ResourceNotFoundException" => jsonFormat(ResourceNotFoundException, "message") + case "ExternalServiceException" => + jsonFormat(ExternalServiceException, "serviceName", "serviceMessage", "serviceException") + case "ExternalServiceTimeoutException" => jsonFormat(ExternalServiceTimeoutException, "message") + case "DatabaseException" => jsonFormat(DatabaseException, "message") + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala b/core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala new file mode 100644 index 0000000..87946e4 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala @@ -0,0 +1,11 @@ +package xyz.driver.core +package rest + +class DnsDiscovery(transport: HttpRestServiceTransport, overrides: Map[String, String]) { + + def discover[A](implicit descriptor: ServiceDescriptor[A]): A = { + val url = overrides.getOrElse(descriptor.name, s"https://${descriptor.name}") + descriptor.connect(transport, url) + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala new file mode 100644 index 0000000..911e306 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/DriverRoute.scala @@ -0,0 +1,122 @@ +package xyz.driver.core.rest + +import java.sql.SQLException + +import akka.http.scaladsl.model.headers.CacheDirectives.`no-cache` +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{StatusCodes, _} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import com.typesafe.scalalogging.Logger +import org.slf4j.MDC +import xyz.driver.core.rest +import xyz.driver.core.rest.errors._ + +import scala.compat.Platform.ConcurrentModificationException + +trait DriverRoute { + def log: Logger + + def route: Route + + def routeWithDefaults: Route = { + (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { + route + } + } + + protected def defaultResponseHeaders: Directive0 = { + extractRequest flatMap { request => + // Needs to happen before any request processing, so all the log messages + // associated with processing of this request are having this `trackingId` + val trackingId = rest.extractTrackingId(request) + val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) + MDC.put("trackingId", trackingId) + + respondWithHeaders(tracingHeader +: DriverRoute.DefaultHeaders: _*) + } + } + + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ + protected def exceptionHandler: PartialFunction[Throwable, Route] = { + case serviceException: ServiceException => + serviceExceptionHandler(serviceException) + + case is: IllegalStateException => + ctx => + log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) + errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) + + case cm: ConcurrentModificationException => + ctx => + log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) + errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( + ctx) + + case se: SQLException => + ctx => + log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) + errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) + + case t: Exception => + ctx => + log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) + errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) + } + + protected def serviceExceptionHandler(serviceException: ServiceException): Route = { + val statusCode = serviceException match { + case e: InvalidInputException => + log.info("Invalid client input error", e) + StatusCodes.BadRequest + case e: InvalidActionException => + log.info("Invalid client action error", e) + StatusCodes.Forbidden + case e: UnauthorizedException => + log.info("Unauthorized user error", e) + StatusCodes.Unauthorized + case e: ResourceNotFoundException => + log.info("Resource not found error", e) + StatusCodes.NotFound + case e: ExternalServiceException => + log.error("Error while calling another service", e) + StatusCodes.InternalServerError + case e: ExternalServiceTimeoutException => + log.error("Service timeout error", e) + StatusCodes.GatewayTimeout + case e: DatabaseException => + log.error("Database error", e) + StatusCodes.InternalServerError + } + + { (ctx: RequestContext) => + import xyz.driver.core.json.serviceExceptionFormat + val entity = + HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) + errorResponse(statusCode, entity, serviceException)(ctx) + } + } + + protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = + errorResponse(statusCode, HttpEntity(message), exception) + + protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { + complete(HttpResponse(statusCode, entity = entity)) + } + +} + +object DriverRoute { + val DefaultHeaders: List[HttpHeader] = List( + // This header will eliminate the risk of envoy trying to reuse a connection + // that already timed out on the server side by completely rejecting keep-alive + Connection("close"), + // These 2 headers are the simplest way to prevent IE from caching GET requests + RawHeader("Pragma", "no-cache"), + `Cache-Control`(List(`no-cache`(Nil))) + ) +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala new file mode 100644 index 0000000..e31635b --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala @@ -0,0 +1,103 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.TcpIdleTimeoutException +import org.slf4j.MDC +import xyz.driver.core.Name +import xyz.driver.core.reporting.Reporter +import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +class HttpRestServiceTransport( + applicationName: Name[App], + applicationVersion: String, + val actorSystem: ActorSystem, + val executionContext: ExecutionContext, + reporter: Reporter) + extends ServiceTransport { + + protected implicit val execution: ExecutionContext = executionContext + + protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { + val tags = Map( + // open tracing semantic tags + "span.kind" -> "client", + "service" -> applicationName.value, + "http.url" -> requestStub.uri.toString, + "http.method" -> requestStub.method.value, + "peer.hostname" -> requestStub.uri.authority.host.toString, + // google's tracing console provides extra search features if we define these tags + "/http/path" -> requestStub.uri.path.toString, + "/http/method" -> requestStub.method.value.toString, + "/http/url" -> requestStub.uri.toString + ) + reporter.traceAsync(s"http_call_rpc", tags) { implicit span => + val requestTime = System.currentTimeMillis() + + val request = requestStub + .withHeaders(context.contextHeaders.toSeq.map { + case (ContextHeaders.TrackingIdHeader, _) => + RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) + case (ContextHeaders.StacktraceHeader, _) => + RawHeader( + ContextHeaders.StacktraceHeader, + Option(MDC.get("stack")) + .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) + .getOrElse("")) + case (header, headerValue) => RawHeader(header, headerValue) + }: _*) + + reporter.debug(s"Sending request to ${request.method} ${request.uri}") + + val response = httpClient.makeRequest(request) + + response.onComplete { + case Success(r) => + val responseLatency = System.currentTimeMillis() - requestTime + reporter.debug( + s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") + + case Failure(t: Throwable) => + val responseLatency = System.currentTimeMillis() - requestTime + reporter.warn( + s"Failed to receive response from ${request.method.value} ${request.uri} in $responseLatency ms", + t) + }(executionContext) + + response.recoverWith { + case _: TcpIdleTimeoutException => + val serviceCalled = s"${requestStub.method.value} ${requestStub.uri}" + Future.failed(ExternalServiceTimeoutException(serviceCalled)) + case t: Throwable => Future.failed(t) + } + }(context.spanContext) + } + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { + + sendRequestGetResponse(context)(requestStub) flatMap { response => + if (response.status == StatusCodes.NotFound) { + Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) + } else if (response.status.isFailure()) { + val serviceCalled = s"${requestStub.method} ${requestStub.uri}" + Unmarshal(response.entity).to[String] flatMap { errorString => + import spray.json._ + import xyz.driver.core.json._ + val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption + Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) + } + } else { + Future.successful(Unmarshal(response.entity)) + } + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala new file mode 100644 index 0000000..f33bf9d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala @@ -0,0 +1,104 @@ +package xyz.driver.core.rest + +import akka.http.javadsl.server.Rejections +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} +import spray.json._ + +import scala.concurrent.Future +import scala.util.{Failure, Success, Try} + +trait PatchDirectives extends Directives with SprayJsonSupport { + + /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + val `application/merge-patch+json`: MediaType.WithFixedCharset = + MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) + + /** Wraps a JSON value that represents a patch. + * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ + case class PatchValue(value: JsValue) { + + /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ + def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) + } + + /** Witness that the given patch may be applied to an original domain value. + * @tparam A type of the domain value + * @param patch the patch that may be applied to a domain value + * @param format a JSON format that enables serialization and deserialization of a domain value */ + case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { + + /** Applies the patch to a given domain object. The result will be a combination + * of the original value, updates with the fields specified in this witness' patch. */ + def applyTo(original: A): A = { + val serialized = format.write(original) + val merged = patch.applyTo(serialized) + val deserialized = format.read(merged) + deserialized + } + } + + implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = + Unmarshaller.byteStringUnmarshaller + .andThen(sprayJsValueByteStringUnmarshaller) + .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) + .map(js => PatchValue(js)) + + implicit def patchableUnmarshaller[A]( + implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], + format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { + patchUnmarshaller.map(patch => Patchable[A](patch, format)) + } + + protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { + JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => + val oldValue = oldObj.fields.getOrElse(key, JsNull) + val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) + key -> newValue + })(collection.breakOut): _*) + } + + protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { + def mergeError(typ: String): Nothing = + deserializationError(s"Expected $typ value, got $newValue") + + if (maxLevels.exists(_ < 0)) oldValue + else { + (oldValue, newValue) match { + case (_: JsString, newString @ (JsString(_) | JsNull)) => newString + case (_: JsString, _) => mergeError("string") + case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber + case (_: JsNumber, _) => mergeError("number") + case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool + case (_: JsBoolean, _) => mergeError("boolean") + case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr + case (_: JsArray, _) => mergeError("array") + case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) + case (_: JsObject, JsNull) => JsNull + case (_: JsObject, _) => mergeError("object") + case (JsNull, _) => newValue + } + } + } + + def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = + Directive { inner => requestCtx => + onSuccess(retrieve)({ + case Some(oldT) => + Try(patchable.applyTo(oldT)) + .transform[Route]( + mergedT => scala.util.Success(inner(Tuple1(mergedT))), { + case jsonException: DeserializationException => + Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) + case t => Failure(t) + } + ) + .get // intentionally re-throw all other errors + case None => reject() + })(requestCtx) + } +} + +object PatchDirectives extends PatchDirectives diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala new file mode 100644 index 0000000..2854257 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala @@ -0,0 +1,67 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.scaladsl.{Keep, Sink, Source} +import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} +import xyz.driver.core.Name + +import scala.concurrent.{ExecutionContext, Future, Promise} +import scala.concurrent.duration._ +import scala.util.{Failure, Success} + +class PooledHttpClient( + baseUri: Uri, + applicationName: Name[App], + applicationVersion: String, + requestRateLimit: Int = 64, + requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) + extends HttpClient { + + private val host = baseUri.authority.host.toString() + private val port = baseUri.effectivePort + private val scheme = baseUri.scheme + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + private val pool = if (scheme.equalsIgnoreCase("https")) { + Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } else { + Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) + } + + private val queue = Source + .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) + .via(pool) + .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) + .toMat(Sink.foreach({ + case ((Success(resp), p)) => p.success(resp) + case ((Failure(e), p)) => p.failure(e) + }))(Keep.left) + .run + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + val responsePromise = Promise[HttpResponse]() + + queue.offer(request -> responsePromise).flatMap { + case QueueOfferResult.Enqueued => + responsePromise.future + case QueueOfferResult.Dropped => + Future.failed(new Exception(s"Request queue to the host $host is overflown")) + case QueueOfferResult.Failure(ex) => + Future.failed(ex) + case QueueOfferResult.QueueClosed => + Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala new file mode 100644 index 0000000..c0e9f99 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala @@ -0,0 +1,26 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} +import com.typesafe.config.Config +import xyz.driver.core.Name + +import scala.concurrent.ExecutionContext + +trait ProxyRoute extends DriverRoute { + implicit val executionContext: ExecutionContext + val config: Config + val httpClient: HttpClient + + protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => + val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") + val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") + + val originalUri = ctx.request.uri + val originalRequest = ctx.request + + val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) + val newRequest = originalRequest.withUri(newUri) + + httpClient.makeRequest(newRequest).map(RouteResult.Complete) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala b/core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala new file mode 100644 index 0000000..09d98b8 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/RestService.scala @@ -0,0 +1,86 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model._ +import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} +import akka.stream.Materializer + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.{ListT, OptionT} + +trait RestService extends Service { + + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json._ + + protected implicit val exec: ExecutionContext + protected implicit val materializer: Materializer + + implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { + def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = + if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] + } + + protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = + OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) + + protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = + OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) + + protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( + implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = + ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) + + protected def jsonEntity(json: JsValue): RequestEntity = + HttpEntity(ContentTypes.`application/json`, json.compactPrint) + + protected def mergePatchJsonEntity(json: JsValue): RequestEntity = + HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) + + protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) + + protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) + + protected def postJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) + + protected def putJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) + + protected def patchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) + + protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = + HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) + + protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = + HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) + + protected def endpointUri(baseUri: Uri, path: String): Uri = + baseUri.withPath(Uri.Path(path)) + + protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = + baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) + + protected def responseToListResponse[T: JsonFormat](pagination: Option[Pagination])( + response: HttpResponse): Future[ListResponse[T]] = { + import DefaultJsonProtocol._ + val resourceCount = response.headers + .find(_.name() equalsIgnoreCase ContextHeaders.ResourceCount) + .map(_.value().toInt) + .getOrElse(0) + val meta = ListResponse.Meta(resourceCount, pagination.getOrElse(Pagination(resourceCount max 1, 1))) + Unmarshal(response.entity).to[List[T]].map(ListResponse(_, meta)) + } + + protected def responseToListResponse[T: JsonFormat](pagination: Pagination)( + response: HttpResponse): Future[ListResponse[T]] = responseToListResponse(Some(pagination))(response) +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala b/core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala new file mode 100644 index 0000000..646fae8 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala @@ -0,0 +1,16 @@ +package xyz.driver.core +package rest +import scala.annotation.implicitNotFound + +@implicitNotFound( + "Don't know how to communicate with service ${S}. Make sure an implicit ServiceDescriptor is" + + "available. A good place to put one is in the service's companion object.") +trait ServiceDescriptor[S] { + + /** The service's name. Must be unique among all services. */ + def name: String + + /** Get an instance of the service. */ + def connect(transport: HttpRestServiceTransport, url: String): S + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala new file mode 100644 index 0000000..964a5a2 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala @@ -0,0 +1,29 @@ +package xyz.driver.core.rest + +import akka.actor.ActorSystem +import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers.`User-Agent` +import akka.http.scaladsl.model.{HttpRequest, HttpResponse} +import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} +import akka.stream.ActorMaterializer +import xyz.driver.core.Name + +import scala.concurrent.Future + +class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) + extends HttpClient { + + protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) + private val client = Http()(actorSystem) + + private val clientConnectionSettings: ClientConnectionSettings = + ClientConnectionSettings(actorSystem).withUserAgentHeader( + Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) + + private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) + .withConnectionSettings(clientConnectionSettings) + + def makeRequest(request: HttpRequest): Future[HttpResponse] = { + client.singleRequest(request, settings = connectionPoolSettings) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala b/core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala new file mode 100644 index 0000000..5ceac54 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/Swagger.scala @@ -0,0 +1,144 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} +import akka.http.scaladsl.server.Route +import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile +import akka.stream.ActorAttributes +import akka.stream.scaladsl.{Framing, StreamConverters} +import akka.util.ByteString +import com.github.swagger.akka.SwaggerHttpService +import com.github.swagger.akka.model._ +import com.typesafe.config.Config +import com.typesafe.scalalogging.Logger +import io.swagger.models.Scheme +import io.swagger.models.auth.{ApiKeyAuthDefinition, In} +import io.swagger.util.Json + +import scala.util.control.NonFatal + +class Swagger( + override val host: String, + accessSchemes: List[String], + version: String, + override val apiClasses: Set[Class[_]], + val config: Config, + val logger: Logger) + extends SwaggerHttpService { + + override val schemes = accessSchemes.map { s => + Scheme.forValue(s) + } + + // Note that the reason for overriding this is a subtle chain of causality: + // + // 1. Some of our endpoints require a single trailing slash and will not + // function if it is omitted + // 2. Swagger omits trailing slashes in its generated api doc + // 3. To work around that, a space is added after the trailing slash in the + // swagger Path annotations + // 4. This space is removed manually in the code below + // + // TODO: Ideally we'd like to drop this custom override and fix the issue in + // 1, by dropping the slash requirement and accepting api endpoints with and + // without trailing slashes. This will require inspecting and potentially + // fixing all service endpoints. + override def generateSwaggerJson: String = { + import io.swagger.models.{Swagger => JSwagger} + + import scala.collection.JavaConverters._ + try { + val swagger: JSwagger = reader.read(apiClasses.asJava) + + val paths = if (swagger.getPaths == null) { + Map.empty + } else { + swagger.getPaths.asScala + } + + // Removing trailing spaces + val fixedPaths = paths.map { + case (key, path) => + key.trim -> path + } + + swagger.setPaths(fixedPaths.asJava) + + Json.pretty().writeValueAsString(swagger) + } catch { + case NonFatal(t) => + logger.error("Issue with creating swagger.json", t) + throw t + } + } + + override val securitySchemeDefinitions = Map( + "token" -> { + val definition = new ApiKeyAuthDefinition("Authorization", In.HEADER) + definition.setDescription("Authentication token") + definition + } + ) + + override val basePath: String = config.getString("swagger.basePath") + override val apiDocsPath: String = config.getString("swagger.docsPath") + + override val info = Info( + config.getString("swagger.apiInfo.description"), + version, + config.getString("swagger.apiInfo.title"), + config.getString("swagger.apiInfo.termsOfServiceUrl"), + contact = Some( + Contact( + config.getString("swagger.apiInfo.contact.name"), + config.getString("swagger.apiInfo.contact.url"), + config.getString("swagger.apiInfo.contact.email") + )), + license = Some( + License( + config.getString("swagger.apiInfo.license"), + config.getString("swagger.apiInfo.licenseUrl") + )), + vendorExtensions = Map.empty[String, AnyRef] + ) + + /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ + private def getTemplatedResource( + resourceName: String, + contentType: ContentType, + substitution: (String, String)): Route = get { + Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { + case Some(ResourceFile(url, length @ _, _)) => + extractSettings { settings => + val stream = StreamConverters + .fromInputStream(() => url.openStream()) + .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) + .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) + .map { line => + line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) + } + .map(line => ByteString(line + "\n")) + complete( + HttpEntity(contentType, stream) + ) + } + case None => reject + } + } + + def swaggerUI: Route = + pathEndOrSingleSlash { + getTemplatedResource( + "swagger-ui/index.html", + ContentTypes.`text/html(UTF-8)`, + "title" -> config.getString("swagger.apiInfo.title")) + } ~ getFromResourceDirectory("swagger-ui") + + def swaggerUINew: Route = + pathEndOrSingleSlash { + getTemplatedResource( + "swagger-ui-dist/index.html", + ContentTypes.`text/html(UTF-8)`, + "title" -> config.getString("swagger.apiInfo.title")) + } ~ getFromResourceDirectory("swagger-ui-dist") + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala new file mode 100644 index 0000000..5007774 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala @@ -0,0 +1,14 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val permissionsMap = permissions.map(_ -> true).toMap + Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala new file mode 100644 index 0000000..e1a94e1 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala @@ -0,0 +1,75 @@ +package xyz.driver.core.rest.auth + +import akka.http.scaladsl.server.directives.Credentials +import com.typesafe.scalalogging.Logger +import scalaz.OptionT +import xyz.driver.core.auth.{AuthToken, Permission, User} +import xyz.driver.core.rest.errors.{ExternalServiceException, UnauthorizedException} +import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ContextHeaders, ServiceRequestContext, serviceContext} + +import scala.concurrent.{ExecutionContext, Future} + +abstract class AuthProvider[U <: User]( + val authorization: Authorization[U], + log: Logger, + val realm: String +)(implicit execution: ExecutionContext) { + + import akka.http.scaladsl.server._ + import Directives.{authorize => akkaAuthorize, _} + + def this(authorization: Authorization[U], log: Logger)(implicit executionContext: ExecutionContext) = + this(authorization, log, "driver.xyz") + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param ctx set of request values which can be relevant to authenticate user + * @return authenticated user + */ + def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] + + protected def authenticator(context: ServiceRequestContext): AsyncAuthenticator[U] = { + case Credentials.Missing => + log.info(s"Request (${context.trackingId}) missing authentication credentials") + Future.successful(None) + case Credentials.Provided(authToken) => + authenticatedUser(context.withAuthToken(AuthToken(authToken))).run.recover({ + case ExternalServiceException(_, _, Some(UnauthorizedException(_))) => None + }) + } + + /** + * Verifies that a user agent is properly authenticated, and (optionally) authorized with the specified permissions + */ + def authorize( + context: ServiceRequestContext, + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + authenticateOAuth2Async[U](realm, authenticator(context)) flatMap { authenticatedUser => + val authCtx = context.withAuthenticatedUser(context.authToken.get, authenticatedUser) + onSuccess(authorization.userHasPermissions(authenticatedUser, permissions)(authCtx)) flatMap { + case AuthorizationResult(authorized, token) => + val allAuthorized = permissions.forall(authorized.getOrElse(_, false)) + akkaAuthorize(allAuthorized) tflatMap { _ => + val cachedPermissionsCtx = token.fold(authCtx)(authCtx.withPermissionsToken) + provide(cachedPermissionsCtx) + } + } + } + } + + /** + * Verifies if request is authenticated and authorized to have `permissions` + */ + def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + serviceContext flatMap (authorize(_, permissions: _*)) + } +} + +object AuthProvider { + val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader + val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader + val SetAuthenticationTokenHeader: String = "set-authorization" + val SetPermissionsTokenHeader: String = "set-permissions" +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala new file mode 100644 index 0000000..1a5e9be --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala @@ -0,0 +1,11 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future + +trait Authorization[U <: User] { + def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala new file mode 100644 index 0000000..efe28c9 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala @@ -0,0 +1,22 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, PermissionsToken} + +import scalaz.Scalaz.mapMonoid +import scalaz.Semigroup +import scalaz.syntax.semigroup._ + +final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) +object AuthorizationResult { + val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) + + implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { + private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) + private implicit val permissionsTokenSemigroup = + Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) + + override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { + AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala new file mode 100644 index 0000000..66de4ef --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala @@ -0,0 +1,55 @@ +package xyz.driver.core.rest.auth + +import java.nio.file.{Files, Path} +import java.security.{KeyFactory, PublicKey} +import java.security.spec.X509EncodedKeySpec + +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.Future +import scalaz.syntax.std.boolean._ + +class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + import spray.json._ + + def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = + tokenObject.fields.get("permissions").collect { + case JsObject(fields) => + fields.collect { + case (key, JsBoolean(value)) => key -> value + } + } + + val result = for { + token <- ctx.permissionsToken + jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption + jwtJson = jwt.parseJson.asJsObject + + // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None + _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) + _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) + + permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) + + authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap + } yield AuthorizationResult(authorized, Some(token)) + + Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) + } +} + +object CachedTokenAuthorization { + def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { + lazy val publicKey: PublicKey = { + val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim + val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) + val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) + KeyFactory.getInstance("RSA").generatePublic(spec) + } + new CachedTokenAuthorization[U](publicKey, issuer) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala new file mode 100644 index 0000000..131e7fc --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.auth + +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.ServiceRequestContext + +import scala.concurrent.{ExecutionContext, Future} +import scalaz.Scalaz.{futureInstance, listInstance} +import scalaz.syntax.semigroup._ +import scalaz.syntax.traverse._ + +class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) + extends Authorization[U] { + + override def userHasPermissions(user: U, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = + permissions.forall(permissionsMap.getOrElse(_, false)) + + authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { + (authResult, authorization) => + if (allAuthorized(authResult.authorized)) Future.successful(authResult) + else { + authorization.userHasPermissions(user, permissions).map(authResult |+| _) + } + } + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala new file mode 100644 index 0000000..ff3424d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala @@ -0,0 +1,19 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} +import xyz.driver.core.auth.{Permission, User} +import xyz.driver.core.rest.auth.AuthProvider + +/** Authentication and authorization directives. */ +trait AuthDirectives extends AkkaDirectives { + + /** Authenticate a user based on service request headers and check if they have all given permissions. */ + def authenticateAndAuthorize[U <: User]( + authProvider: AuthProvider[U], + permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { + authProvider.authorize(permissions: _*) + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala new file mode 100644 index 0000000..5a6bbfd --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala @@ -0,0 +1,72 @@ +package xyz.driver.core +package rest +package directives + +import akka.http.scaladsl.model.HttpMethods._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpResponse, StatusCodes} +import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} + +/** Directives to handle Cross-Origin Resource Sharing (CORS). */ +trait CorsDirectives extends AkkaDirectives { + + /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request + * origin. + * + * In a microservice environment, it can be difficult to know in advance the exact origin + * from which requests may be issued [1]. For example, the request may come from a web page served from + * any of the services, on any namespace or from other documentation sites. In general, only a set + * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that + * implement CORS require exact specification of allowed origins, including full host name and scheme, + * in order to send credentials and headers with requests to other origins. + * + * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by + * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. + * + * Note that the simplicity of this implementation comes with two notable drawbacks: + * + * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. + * + * - Allowed methods and headers can not be customized on a per-request basis. All standard + * HTTP methods are allowed, and allowed headers are specified for all inner routes. + * + * This handler is not suited for cases where more fine-grained control of responses is required. + * + * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through + * a common gateway. + * + * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed + * origins. + * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. + * @param inner Route into which CORS headers will be injected. + */ + def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { + optionalHeaderValueByType[Origin](()) { maybeOrigin => + val allowedOrigins: HttpOriginRange = maybeOrigin match { + // Note that this is not a security issue: the client will never send credentials if the allowed + // origin is set to *. This case allows us to deal with clients that do not send an origin header. + case None => HttpOriginRange.* + case Some(requestOrigin) => + val allowedOrigin = requestOrigin.origins.find(origin => + allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) + allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) + } + + respondWithHeaders( + `Access-Control-Allow-Origin`.forRange(allowedOrigins), + `Access-Control-Allow-Credentials`(true), + `Access-Control-Allow-Headers`(allowedHeaders: _*), + `Access-Control-Expose-Headers`(allowedHeaders: _*) + ) { + options { // options is used during preflight check + complete( + HttpResponse(StatusCodes.OK) + .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) + } ~ inner // in case of non-preflight check we don't do anything special + } + } + } + +} + +object CorsDirectives extends CorsDirectives diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala new file mode 100644 index 0000000..0cd4ef1 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Directives.scala @@ -0,0 +1,6 @@ +package xyz.driver.core +package rest +package directives + +trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers +object Directives extends Directives diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala new file mode 100644 index 0000000..218c9ae --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala @@ -0,0 +1,85 @@ +package xyz.driver.core +package rest +package directives + +import java.time.Instant +import java.util.UUID + +import akka.http.scaladsl.model.Uri.Path +import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} +import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.refineV +import xyz.driver.core.domain.PhoneNumber +import xyz.driver.core.time.Time + +import scala.util.control.NonFatal + +/** Akka-HTTP path matchers for custom core types. */ +trait PathMatchers { + + private def UuidInPath[T]: PathMatcher1[Id[T]] = + AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) + + def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) + case _ => Unmatched + } + } + + def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) + case _ => Unmatched + } + } + + private def timestampInPath: PathMatcher1[Long] = + PathMatcher("""[+-]?\d*""".r) flatMap { string => + try Some(string.toLong) + catch { case _: IllegalArgumentException => None } + } + + def InstantInPath: PathMatcher1[Instant] = + new PathMatcher1[Instant] { + def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { + case Path.Segment(head, tail) => + try Matched(tail, Tuple1(Instant.parse(head))) + catch { + case NonFatal(_) => Unmatched + } + case _ => Unmatched + } + } | timestampInPath.map(Instant.ofEpochMilli) + + def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) + + def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + refineV[NonEmpty](segment) match { + case Left(_) => Unmatched + case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) + } + case _ => Unmatched + } + } + + def RevisionInPath[T]: PathMatcher1[Revision[T]] = + PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => + Some(Revision[T](string)) + } + + def PhoneInPath: PathMatcher1[PhoneNumber] = new PathMatcher1[PhoneNumber] { + def apply(path: Path) = path match { + case Path.Segment(segment, tail) => + PhoneNumber + .parse(segment) + .map(parsed => Matched(tail, Tuple1(parsed))) + .getOrElse(Unmatched) + case _ => Unmatched + } + } + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala new file mode 100644 index 0000000..6c45d15 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala @@ -0,0 +1,40 @@ +package xyz.driver.core +package rest +package directives + +import java.util.UUID + +import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} +import akka.http.scaladsl.unmarshalling.Unmarshaller +import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} + +/** Akka-HTTP unmarshallers for custom core types. */ +trait Unmarshallers { + + implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = + Unmarshaller.strict[String, Id[A]] { str => + Id[A](UUID.fromString(str).toString) + } + + implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = + Unmarshaller.firstOf( + Unmarshaller.strict((JsString(_: String)) andThen reader.read), + stringToValueUnmarshaller[T] + ) + + implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = + Unmarshaller.strict[String, Revision[T]](Revision[T]) + + val jsValueToStringMarshaller: Marshaller[JsValue, String] = + Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) + + def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = + jsValueToStringMarshaller.compose[T](jsonFormat.write) + + val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = + Unmarshaller.strict[String, JsValue](value => JsonParser(value)) + + def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = + stringToJsValueUnmarshaller.map[T](jsonFormat.read) + +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala new file mode 100644 index 0000000..f2962c9 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala @@ -0,0 +1,27 @@ +package xyz.driver.core.rest.errors + +sealed abstract class ServiceException(val message: String) extends Exception(message) + +final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) + +final case class InvalidActionException(override val message: String = "This action is not allowed") + extends ServiceException(message) + +final case class UnauthorizedException( + override val message: String = "The user's authentication credentials are invalid or missing") + extends ServiceException(message) + +final case class ResourceNotFoundException(override val message: String = "Resource not found") + extends ServiceException(message) + +final case class ExternalServiceException( + serviceName: String, + serviceMessage: String, + serviceException: Option[ServiceException]) + extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") + +final case class ExternalServiceTimeoutException(serviceName: String) + extends ServiceException(s"$serviceName took too long to respond") + +final case class DatabaseException(override val message: String = "Database access error") + extends ServiceException(message) diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala new file mode 100644 index 0000000..866476d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala @@ -0,0 +1,33 @@ +package xyz.driver.core +package rest +package headers + +import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} +import xyz.driver.core.reporting.SpanContext + +import scala.util.Try + +/** Encapsulates a trace context in an HTTP header for propagation across services. + * + * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) + * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is + * ignored. + */ +final case class Traceparent(spanContext: SpanContext) extends ModeledCustomHeader[Traceparent] { + override def renderInRequests = true + override def renderInResponses = true + override val companion: Traceparent.type = Traceparent + override def value: String = f"01-${spanContext.traceId}-${spanContext.spanId}-00" +} +object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { + override val name = "traceparent" + override def parse(value: String) = Try { + val Array(version, traceId, spanId, _) = value.split("-") + require( + version == "01", + s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") + new Traceparent( + new SpanContext(traceId, spanId) + ) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/package.scala b/core-rest/src/main/scala/xyz/driver/core/rest/package.scala new file mode 100644 index 0000000..34a4a9d --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/package.scala @@ -0,0 +1,323 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.unmarshalling.Unmarshal +import akka.stream.Materializer +import akka.stream.scaladsl.Flow +import akka.util.ByteString +import scalaz.Scalaz.{intInstance, stringInstance} +import scalaz.syntax.equal._ +import scalaz.{Functor, OptionT} +import xyz.driver.core.rest.auth.AuthProvider +import xyz.driver.core.rest.errors.ExternalServiceException +import xyz.driver.core.rest.headers.Traceparent +import xyz.driver.tracing.TracingDirectives + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.Try + +trait Service + +object Service + +trait HttpClient { + def makeRequest(request: HttpRequest): Future[HttpResponse] +} + +trait ServiceTransport { + + def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] + + def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( + implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] +} + +sealed trait SortingOrder +object SortingOrder { + case object Asc extends SortingOrder + case object Desc extends SortingOrder +} + +final case class SortingField(name: String, sortingOrder: SortingOrder) +final case class Sorting(sortingFields: Seq[SortingField]) + +final case class Pagination(pageSize: Int, pageNumber: Int) { + require(pageSize > 0, "Page size must be greater than zero") + require(pageNumber > 0, "Page number must be greater than zero") + + def offset: Int = pageSize * (pageNumber - 1) +} + +final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) + +object ListResponse { + + def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = + ListResponse( + items = items, + meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) + + final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) + + object Meta { + def apply(itemsCount: Int, pagination: Pagination): Meta = + Meta(itemsCount, pagination.pageNumber, pagination.pageSize) + } + +} + +object `package` { + + implicit class FutureExtensions[T](future: Future[T]) { + def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = + future.transform(identity, { + case ExternalServiceException(_, _, Some(e)) => e + case t: Throwable => t + }) + } + + implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { + def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( + implicit F: Functor[Future], + em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { + optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) + } + } + + object ContextHeaders { + val AuthenticationTokenHeader: String = "Authorization" + val PermissionsTokenHeader: String = "Permissions" + val AuthenticationHeaderPrefix: String = "Bearer" + val ClientFingerprintHeader: String = "X-Client-Fingerprint" + val TrackingIdHeader: String = "X-Trace" + val StacktraceHeader: String = "X-Stacktrace" + val OriginatingIpHeader: String = "X-Forwarded-For" + val ResourceCount: String = "X-Resource-Count" + val PageCount: String = "X-Page-Count" + val TraceHeaderName: String = TracingDirectives.TraceHeaderName + val SpanHeaderName: String = TracingDirectives.SpanHeaderName + } + + val AllowedHeaders: Seq[String] = + Seq( + "Origin", + "X-Requested-With", + "Content-Type", + "Content-Length", + "Accept", + "X-Trace", + "Access-Control-Allow-Methods", + "Access-Control-Allow-Origin", + "Access-Control-Allow-Headers", + "Server", + "Date", + ContextHeaders.ClientFingerprintHeader, + ContextHeaders.TrackingIdHeader, + ContextHeaders.TraceHeaderName, + ContextHeaders.SpanHeaderName, + ContextHeaders.StacktraceHeader, + ContextHeaders.AuthenticationTokenHeader, + ContextHeaders.OriginatingIpHeader, + ContextHeaders.ResourceCount, + ContextHeaders.PageCount, + "X-Frame-Options", + "X-Content-Type-Options", + "Strict-Transport-Security", + AuthProvider.SetAuthenticationTokenHeader, + AuthProvider.SetPermissionsTokenHeader, + "Traceparent" + ) + + def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = + `Access-Control-Allow-Origin`( + originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) + + def serviceContext: Directive1[ServiceRequestContext] = { + def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { + headers.map({ header => + if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( + ContextHeaders.AuthenticationHeaderPrefix)) { + Authorization(OAuth2BearerToken(header.value)) + } else header + })(collection.breakOut) + } + extractClientIP flatMap { remoteAddress => + mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => + extract(ctx => extractServiceContext(ctx.request, remoteAddress)) + } + } + } + + def respondWithCorsAllowedHeaders: Directive0 = { + respondWithHeaders( + List[HttpHeader]( + `Access-Control-Allow-Headers`(AllowedHeaders: _*), + `Access-Control-Expose-Headers`(AllowedHeaders: _*) + )) + } + + def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { + respondWithHeader { + `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) + } + } + + def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { + respondWithHeaders( + List[HttpHeader]( + Allow(methods.to[collection.immutable.Seq]), + `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) + )) + } + + def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = + new ServiceRequestContext( + extractTrackingId(request), + extractOriginatingIP(request, remoteAddress), + extractContextHeaders(request)) + + def extractTrackingId(request: HttpRequest): String = { + request.headers + .find(_.name === ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } + + def extractFingerprintHash(request: HttpRequest): Option[String] = { + request.headers + .find(_.name === ContextHeaders.ClientFingerprintHeader) + .map(_.value()) + } + + def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { + request.headers + .find(_.name === ContextHeaders.OriginatingIpHeader) + .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) + .orElse(remoteAddress.toOption) + } + + def extractStacktrace(request: HttpRequest): Array[String] = + request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") + + def extractContextHeaders(request: HttpRequest): Map[String, String] = { + request.headers + .filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || + h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || + h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || + h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader || + h.name === Traceparent.name + } + .map { header => + if (header.name === ContextHeaders.AuthenticationTokenHeader) { + header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim + } else { + header.name -> header.value + } + } + .toMap + } + + private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { + @annotation.tailrec + def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { + val index = byteString.indexOf('/', from) + if (index === -1) descIndices.reverse + else { + val (init, tail) = byteString.splitAt(index) + if ((init endsWith "<") && (tail startsWith "/sc")) { + dirtyIndices(index + 1, index :: descIndices) + } else { + dirtyIndices(index + 1, descIndices) + } + } + } + + val indices = dirtyIndices(0, Nil) + + indices.headOption.fold(byteString) { head => + val builder = ByteString.newBuilder + builder ++= byteString.take(head) + + (indices :+ byteString.length).sliding(2).foreach { + case Seq(start, end) => + builder += ' ' + builder ++= byteString.slice(start, end) + case Seq(_) => // Should not match; sliding on at least 2 elements + assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") + } + builder.result + } + } + + val sanitizeRequestEntity: Directive0 = { + mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) + } + + val paginated: Directive1[Pagination] = + parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) + + private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = + (pageSizeOpt, pageNumberOpt) match { + case (Some(size), Some(number)) => Option(Pagination(size, number)) + case (None, None) => Option.empty[Pagination] + case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") + } + + val optionalPagination: Directive1[Option[Pagination]] = + parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) + + def paginationQuery(pagination: Pagination) = + Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) + + def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])( + implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = { + optionalPagination { pagination => + onSuccess(handler(pagination)) { + case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) => + val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1) + val headers = List( + RawHeader(ContextHeaders.ResourceCount, count.toString), + RawHeader(ContextHeaders.PageCount, pageCount.toString)) + + respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart))) + } + } + } + + private def extractSorting(sortingString: Option[String]): Sorting = { + val sortingFields = sortingString.fold(Seq.empty[SortingField])( + _.split(",") + .filter(_.length > 0) + .map { sortingParam => + if (sortingParam.startsWith("-")) { + SortingField(sortingParam.substring(1), SortingOrder.Desc) + } else { + val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam + SortingField(fieldName, SortingOrder.Asc) + } + } + .toSeq) + + Sorting(sortingFields) + } + + val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) + + def sortingQuery(sorting: Sorting): Seq[(String, String)] = { + val sortingString = sorting.sortingFields + .map { sortingField => + sortingField.sortingOrder match { + case SortingOrder.Asc => sortingField.name + case SortingOrder.Desc => s"-${sortingField.name}" + } + } + .mkString(",") + Seq("sort" -> sortingString) + } +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala new file mode 100644 index 0000000..55f1a2e --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala @@ -0,0 +1,24 @@ +package xyz.driver.core.rest + +import xyz.driver.core.Name + +trait ServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T +} + +trait SavingUsedServiceDiscovery { + private val usedServices = new scala.collection.mutable.HashSet[String]() + + def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { + usedServices += serviceName.value + } + + def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } +} + +class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { + + def discover[T <: Service](serviceName: Name[Service]): T = + throw new IllegalArgumentException(s"Service with name $serviceName is unknown") +} diff --git a/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala new file mode 100644 index 0000000..d2e4bc3 --- /dev/null +++ b/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala @@ -0,0 +1,87 @@ +package xyz.driver.core.rest + +import java.net.InetAddress + +import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} +import xyz.driver.core.generators +import scalaz.Scalaz.{mapEqual, stringInstance} +import scalaz.syntax.equal._ +import xyz.driver.core.reporting.SpanContext +import xyz.driver.core.rest.auth.AuthProvider +import xyz.driver.core.rest.headers.Traceparent + +import scala.util.Try + +class ServiceRequestContext( + val trackingId: String = generators.nextUuid().toString, + val originatingIp: Option[InetAddress] = None, + val contextHeaders: Map[String, String] = Map.empty[String, String]) { + def authToken: Option[AuthToken] = + contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) + + def permissionsToken: Option[PermissionsToken] = + contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) + + def withAuthToken(authToken: AuthToken): ServiceRequestContext = + new ServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) + ) + + def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), + user + ) + + override def hashCode(): Int = + Seq[Any](trackingId, originatingIp, contextHeaders) + .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) + + override def equals(obj: Any): Boolean = obj match { + case ctx: ServiceRequestContext => + trackingId === ctx.trackingId && + originatingIp == originatingIp && + contextHeaders === ctx.contextHeaders + case _ => false + } + + def spanContext: SpanContext = { + val validHeader = Try { + contextHeaders(Traceparent.name) + }.flatMap { value => + Traceparent.parse(value) + } + validHeader.map(_.spanContext).getOrElse(SpanContext.fresh()) + } + + override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" +} + +class AuthorizedServiceRequestContext[U <: User]( + override val trackingId: String = generators.nextUuid().toString, + override val originatingIp: Option[InetAddress] = None, + override val contextHeaders: Map[String, String] = Map.empty[String, String], + val authenticatedUser: U) + extends ServiceRequestContext { + + def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = + new AuthorizedServiceRequestContext[U]( + trackingId, + originatingIp, + contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), + authenticatedUser) + + override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser + case _ => false + } + + override def toString: String = + s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" +} diff --git a/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala new file mode 100644 index 0000000..2e772fb --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/AuthTest.scala @@ -0,0 +1,165 @@ +package xyz.driver.core + +import akka.http.scaladsl.model.headers.{ + HttpChallenges, + OAuth2BearerToken, + RawHeader, + Authorization => AkkaAuthorization +} +import akka.http.scaladsl.server.Directives._ +import akka.http.scaladsl.server._ +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import pdi.jwt.{Jwt, JwtAlgorithm} +import xyz.driver.core.auth._ +import xyz.driver.core.domain.Email +import xyz.driver.core.logging._ +import xyz.driver.core.rest._ +import xyz.driver.core.rest.auth._ +import xyz.driver.core.time.Time + +import scala.concurrent.Future +import scalaz.OptionT + +class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest { + + case object TestRoleAllowedPermission extends Permission + case object TestRoleAllowedByTokenPermission extends Permission + case object TestRoleNotAllowedPermission extends Permission + + val TestRole = Role(Id("1"), Name("testRole")) + + val (publicKey, privateKey) = { + import java.security.KeyPairGenerator + + val keygen = KeyPairGenerator.getInstance("RSA") + keygen.initialize(2048) + + val keyPair = keygen.generateKeyPair() + (keyPair.getPublic, keyPair.getPrivate) + } + + val basicAuthorization: Authorization[User] = new Authorization[User] { + + override def userHasPermissions(user: User, permissions: Seq[Permission])( + implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { + val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap + Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) + } + } + + val tokenIssuer = "users" + val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) + + val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) + + val authStatusService = new AuthProvider[User](authorization, NoLogger) { + override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = + OptionT.optionT[Future] { + if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { + Future.successful( + Some( + AuthTokenUserInfo( + Id[User]("1"), + Email("foo", "bar"), + emailVerified = true, + audience = "driver", + roles = Set(TestRole), + expirationTime = Time(1000000L) + ))) + } else { + Future.successful(Option.empty[User]) + } + } + } + + import authStatusService._ + + "'authorize' directive" should "throw error if auth token is not in the request" in { + + Get("/naive/attempt") ~> + authorize(TestRoleAllowedPermission) { user => + complete("Never going to be here") + } ~> + check { + // handled shouldBe false + rejections should contain( + AuthenticationFailedRejection( + AuthenticationFailedRejection.CredentialsMissing, + HttpChallenges.oAuth2(authStatusService.realm))) + } + } + + it should "throw error if authorized user does not have the requested permission" in { + + val referenceAuthToken = AuthToken("I am a test role's token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Post("/administration/attempt").addHeader( + referenceAuthHeader + ) ~> + authorize(TestRoleNotAllowedPermission) { user => + complete("Never going to get here") + } ~> + check { + handled shouldBe false + rejections should contain(AuthorizationFailedRejection) + } + } + + it should "pass and retrieve the token to client code, if token is in request and user has permission" in { + val referenceAuthToken = AuthToken("I am token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Get("/valid/attempt/?a=2&b=5").addHeader( + referenceAuthHeader + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authenticate correctly even without the 'Bearer' prefix on the Authorization header" in { + val referenceAuthToken = AuthToken("unprefixed_token") + + Get("/valid/attempt/?a=2&b=5").addHeader( + RawHeader(ContextHeaders.AuthenticationTokenHeader, referenceAuthToken.value) + ) ~> + authorize(TestRoleAllowedPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized" + } + } + + it should "authorize permission found in permissions token" in { + import spray.json._ + + val claim = JsObject( + Map( + "iss" -> JsString(tokenIssuer), + "sub" -> JsString("1"), + "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) + )).prettyPrint + val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) + val referenceAuthToken = AuthToken("I am token") + val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) + + Get("/alic/attempt/?a=2&b=5") + .addHeader(referenceAuthHeader) + .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> + authorize(TestRoleAllowedByTokenPermission) { ctx => + complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") + } ~> + check { + handled shouldBe true + responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala new file mode 100644 index 0000000..7e740a4 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/GeneratorsTest.scala @@ -0,0 +1,264 @@ +package xyz.driver.core + +import org.scalatest.{Assertions, FlatSpec, Matchers} + +import scala.collection.immutable.IndexedSeq + +class GeneratorsTest extends FlatSpec with Matchers with Assertions { + import generators._ + + "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in { + + val generatedId1 = nextId[String]() + val generatedId2 = nextId[String]() + val generatedId3 = nextId[Long]() + + generatedId1.length should be >= 0 + generatedId2.length should be >= 0 + generatedId3.length should be >= 0 + generatedId1 should not be generatedId2 + generatedId2 should !==(generatedId3) + } + + it should "be able to generate com.drivergrp.core.Id identifiers with max value" in { + + val generatedLimitedId1 = nextId[String](5) + val generatedLimitedId2 = nextId[String](4) + val generatedLimitedId3 = nextId[Long](3) + + generatedLimitedId1.length should be >= 0 + generatedLimitedId1.length should be < 6 + generatedLimitedId2.length should be >= 0 + generatedLimitedId2.length should be < 5 + generatedLimitedId3.length should be >= 0 + generatedLimitedId3.length should be < 4 + generatedLimitedId1 should not be generatedLimitedId2 + generatedLimitedId2 should !==(generatedLimitedId3) + } + + it should "be able to generate com.drivergrp.core.Name names" in { + + Seq.fill(10)(nextName[String]()).distinct.size should be > 1 + nextName[String]().value.length should be >= 0 + + val fixedLengthName = nextName[String](10) + fixedLengthName.length should be <= 10 + assert(!fixedLengthName.value.exists(_.isControl)) + } + + it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in { + + assert(nextNonEmptyName[String]().value.value.nonEmpty) + } + + it should "be able to generate proper UUIDs" in { + + nextUuid() should not be nextUuid() + nextUuid().toString.length should be(36) + } + + it should "be able to generate new Revisions" in { + + nextRevision[String]() should not be nextRevision[String]() + nextRevision[String]().id.length should be > 0 + } + + it should "be able to generate strings" in { + + nextString() should not be nextString() + nextString().length should be >= 0 + + val fixedLengthString = nextString(20) + fixedLengthString.length should be <= 20 + assert(!fixedLengthString.exists(_.isControl)) + } + + it should "be able to generate strings non-empty strings whic are non empty" in { + + assert(nextNonEmptyString().value.nonEmpty) + } + + it should "be able to generate options which are sometimes have values and sometimes not" in { + + val generatedOption = nextOption("2") + + generatedOption should not contain "1" + assert(generatedOption === Some("2") || generatedOption === None) + } + + it should "be able to generate a pair of two generated values" in { + + val constantPair = nextPair("foo", 1L) + constantPair._1 should be("foo") + constantPair._2 should be(1L) + + val generatedPair = nextPair(nextId[Int](), nextName[Int]()) + + generatedPair._1.length should be > 0 + generatedPair._2.length should be > 0 + + nextPair(nextId[Int](), nextName[Int]()) should not be + nextPair(nextId[Int](), nextName[Int]()) + } + + it should "be able to generate a triad of two generated values" in { + + val constantTriad = nextTriad("foo", "bar", 1L) + constantTriad._1 should be("foo") + constantTriad._2 should be("bar") + constantTriad._3 should be(1L) + + val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + + generatedTriad._1.length should be > 0 + generatedTriad._2.length should be > 0 + generatedTriad._3 should be >= BigDecimal(0.00) + + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be + nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) + } + + it should "be able to generate a time value" in { + + val generatedTime = nextTime() + val currentTime = System.currentTimeMillis() + + generatedTime.millis should be >= 0L + generatedTime.millis should be <= currentTime + } + + it should "be able to generate a time range value" in { + + val generatedTimeRange = nextTimeRange() + val currentTime = System.currentTimeMillis() + + generatedTimeRange.start.millis should be >= 0L + generatedTimeRange.start.millis should be <= currentTime + generatedTimeRange.end.millis should be >= 0L + generatedTimeRange.end.millis should be <= currentTime + generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis + } + + it should "be able to generate a BigDecimal value" in { + + val defaultGeneratedBigDecimal = nextBigDecimal() + + defaultGeneratedBigDecimal should be >= BigDecimal(0.00) + defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00) + defaultGeneratedBigDecimal.precision should be(2) + + val unitIntervalBigDecimal = nextBigDecimal(1.00, 8) + + unitIntervalBigDecimal should be >= BigDecimal(0.00) + unitIntervalBigDecimal should be <= BigDecimal(1.00) + unitIntervalBigDecimal.precision should be(8) + } + + it should "be able to generate a specific value from a set of values" in { + + val possibleOptions = Set(1, 3, 5, 123, 0, 9) + + val pick1 = generators.oneOf(possibleOptions) + val pick2 = generators.oneOf(possibleOptions) + val pick3 = generators.oneOf(possibleOptions) + + possibleOptions should contain(pick1) + possibleOptions should contain(pick2) + possibleOptions should contain(pick3) + + val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9) + val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9) + + possibleOptions should contain(pick4) + possibleOptions should contain(pick5) + possibleOptions should contain(pick6) + + Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1 + } + + it should "be able to generate a specific value from an enumeratum enum" in { + + import enumeratum._ + sealed trait TestEnumValue extends EnumEntry + object TestEnum extends Enum[TestEnumValue] { + case object Value1 extends TestEnumValue + case object Value2 extends TestEnumValue + case object Value3 extends TestEnumValue + case object Value4 extends TestEnumValue + val values: IndexedSeq[TestEnumValue] = findValues + } + + val picks = (1 to 100).map(_ => generators.oneOf(TestEnum)) + + TestEnum.values should contain allElementsOf picks + picks.toSet.size should be >= 1 + } + + it should "be able to generate array with values generated by generators" in { + + val arrayOfTimes = arrayOf(nextTime(), 16) + arrayOfTimes.length should be <= 16 + + val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8) + arrayOfBigDecimals.length should be <= 8 + } + + it should "be able to generate seq with values generated by generators" in { + + val seqOfTimes = seqOf(nextTime(), 16) + seqOfTimes.size should be <= 16 + + val seqOfBigDecimals = seqOf(nextBigDecimal(), 8) + seqOfBigDecimals.size should be <= 8 + } + + it should "be able to generate vector with values generated by generators" in { + + val vectorOfTimes = vectorOf(nextTime(), 16) + vectorOfTimes.size should be <= 16 + + val vectorOfStrings = seqOf(nextString(), 8) + vectorOfStrings.size should be <= 8 + } + + it should "be able to generate list with values generated by generators" in { + + val listOfTimes = listOf(nextTime(), 16) + listOfTimes.size should be <= 16 + + val listOfBigDecimals = seqOf(nextBigDecimal(), 8) + listOfBigDecimals.size should be <= 8 + } + + it should "be able to generate set with values generated by generators" in { + + val setOfTimes = vectorOf(nextTime(), 16) + setOfTimes.size should be <= 16 + + val setOfBigDecimals = seqOf(nextBigDecimal(), 8) + setOfBigDecimals.size should be <= 8 + } + + it should "be able to generate maps with keys and values generated by generators" in { + + val generatedConstantMap = mapOf("key", 123, 10) + generatedConstantMap.size should be <= 1 + assert(generatedConstantMap.keys.forall(_ == "key")) + assert(generatedConstantMap.values.forall(_ == 123)) + + val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10) + assert(generatedMap.keys.forall(_.length <= 10)) + assert(generatedMap.values.forall(_ >= BigDecimal(0.00))) + } + + it should "compose deeply" in { + + val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10) + + generatedNestedMap.size should be <= 10 + generatedNestedMap.keySet.size should be <= 10 + generatedNestedMap.values.size should be <= 10 + assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123))) + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala new file mode 100644 index 0000000..fd693f9 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/JsonTest.scala @@ -0,0 +1,521 @@ +package xyz.driver.core + +import java.net.InetAddress +import java.time.{Instant, LocalDate} + +import akka.http.scaladsl.model.Uri +import akka.http.scaladsl.server.PathMatcher +import akka.http.scaladsl.server.PathMatcher.Matched +import com.neovisionaries.i18n.{CountryCode, CurrencyCode} +import enumeratum._ +import eu.timepit.refined.collection.NonEmpty +import eu.timepit.refined.numeric.Positive +import eu.timepit.refined.refineMV +import org.scalatest.{Inspectors, Matchers, WordSpec} +import spray.json._ +import xyz.driver.core.TestTypes.CustomGADT +import xyz.driver.core.auth.AuthCredentials +import xyz.driver.core.domain.{Email, PhoneNumber} +import xyz.driver.core.json._ +import xyz.driver.core.json.enumeratum.HasJsonFormat +import xyz.driver.core.tagging._ +import xyz.driver.core.time.provider.SystemTimeProvider +import xyz.driver.core.time.{Time, TimeOfDay} + +import scala.collection.immutable.IndexedSeq +import scala.language.postfixOps + +class JsonTest extends WordSpec with Matchers with Inspectors { + import DefaultJsonProtocol._ + + "Json format for Id" should { + "read and write correct JSON" in { + + val referenceId = Id[String]("1312-34A") + + val writtenJson = json.idFormat.write(referenceId) + writtenJson.prettyPrint should be("\"1312-34A\"") + + val parsedId = json.idFormat.read(writtenJson) + parsedId should be(referenceId) + } + } + + "Json format for @@" should { + "read and write correct JSON" in { + trait Irrelevant + val reference = Id[JsonTest]("SomeID").tagged[Irrelevant] + + val format = json.taggedFormat[Id[JsonTest], Irrelevant] + + val writtenJson = format.write(reference) + writtenJson shouldBe JsString("SomeID") + + val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson) + parsedId shouldBe reference + } + + "read and write correct JSON when there's an implicit conversion defined" in { + val input = " some string " + + JsString(input).convertTo[String @@ Trimmed] shouldBe input.trim() + + val trimmed: String @@ Trimmed = input + trimmed.toJson shouldBe JsString(trimmed) + } + } + + "Json format for Name" should { + "read and write correct JSON" in { + + val referenceName = Name[String]("Homer") + + val writtenJson = json.nameFormat.write(referenceName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedName = json.nameFormat.read(writtenJson) + parsedName should be(referenceName) + } + + "read and write correct JSON for Name @@ Trimmed" in { + trait Irrelevant + JsString(" some name ").convertTo[Name[Irrelevant] @@ Trimmed] shouldBe Name[Irrelevant]("some name") + + val trimmed: Name[Irrelevant] @@ Trimmed = Name(" some name ") + trimmed.toJson shouldBe JsString("some name") + } + } + + "Json format for NonEmptyName" should { + "read and write correct JSON" in { + + val jsonFormat = json.nonEmptyNameFormat[String] + + val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer")) + + val writtenJson = jsonFormat.write(referenceNonEmptyName) + writtenJson.prettyPrint should be("\"Homer\"") + + val parsedNonEmptyName = jsonFormat.read(writtenJson) + parsedNonEmptyName should be(referenceNonEmptyName) + } + } + + "Json format for Time" should { + "read and write correct JSON" in { + + val referenceTime = new SystemTimeProvider().currentTime() + + val writtenJson = json.timeFormat.write(referenceTime) + writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}") + + val parsedTime = json.timeFormat.read(writtenJson) + parsedTime should be(referenceTime) + } + + "read from inputs compatible with Instant" in { + val referenceTime = new SystemTimeProvider().currentTime() + + val jsons = Seq(JsNumber(referenceTime.millis), JsString(Instant.ofEpochMilli(referenceTime.millis).toString)) + + forAll(jsons) { json => + json.convertTo[Time] shouldBe referenceTime + } + } + } + + "Json format for TimeOfDay" should { + "read and write correct JSON" in { + val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") + val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") + val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) + writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) + val parsed = json.timeOfDayFormat.read(writtenJson) + parsed should be(referenceTimeOfDay) + } + } + + "Json format for Date" should { + "read and write correct JSON" in { + import date._ + + val referenceDate = Date(1941, Month.DECEMBER, 7) + + val writtenJson = json.dateFormat.write(referenceDate) + writtenJson.prettyPrint should be("\"1941-12-07\"") + + val parsedDate = json.dateFormat.read(writtenJson) + parsedDate should be(referenceDate) + } + } + + "Json format for java.time.Instant" should { + + val isoString = "2018-08-08T08:08:08.888Z" + val instant = Instant.parse(isoString) + + "read correct JSON when value is an epoch milli number" in { + JsNumber(instant.toEpochMilli).convertTo[Instant] shouldBe instant + } + + "read correct JSON when value is an ISO timestamp string" in { + JsString(isoString).convertTo[Instant] shouldBe instant + } + + "read correct JSON when value is an object with nested 'timestamp'/millis field" in { + val json = JsObject( + "timestamp" -> JsNumber(instant.toEpochMilli) + ) + + json.convertTo[Instant] shouldBe instant + } + + "write correct JSON" in { + instant.toJson shouldBe JsString(isoString) + } + } + + "Path matcher for Instant" should { + + val isoString = "2018-08-08T08:08:08.888Z" + val instant = Instant.parse(isoString) + + val matcher = PathMatcher("foo") / InstantInPath / + + "read instant from millis" in { + matcher(Uri.Path("foo") / ("+" + instant.toEpochMilli) / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) + } + + "read instant from ISO timestamp string" in { + matcher(Uri.Path("foo") / isoString / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) + } + } + + "Json format for java.time.LocalDate" should { + + "read and write correct JSON" in { + val dateString = "2018-08-08" + val date = LocalDate.parse(dateString) + + date.toJson shouldBe JsString(dateString) + JsString(dateString).convertTo[LocalDate] shouldBe date + } + } + + "Json format for Revision" should { + "read and write correct JSON" in { + + val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4") + + val writtenJson = json.revisionFormat.write(referenceRevision) + writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"") + + val parsedRevision = json.revisionFormat.read(writtenJson) + parsedRevision should be(referenceRevision) + } + } + + "Json format for Email" should { + "read and write correct JSON" in { + + val referenceEmail = Email("test", "drivergrp.com") + + val writtenJson = json.emailFormat.write(referenceEmail) + writtenJson should be("\"test@drivergrp.com\"".parseJson) + + val parsedEmail = json.emailFormat.read(writtenJson) + parsedEmail should be(referenceEmail) + } + } + + "Json format for PhoneNumber" should { + "read and write correct JSON" in { + + val referencePhoneNumber = PhoneNumber("1", "4243039608") + + val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber) + writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson) + + val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson) + parsedPhoneNumber should be(referencePhoneNumber) + } + + "reject an invalid phone number" in { + val phoneJson = """{"countryCode":"1","number":"111-111-1113"}""".parseJson + + intercept[DeserializationException] { + json.phoneNumberFormat.read(phoneJson) + }.getMessage shouldBe "Invalid phone number" + } + + "parse phone number from string" in { + JsString("+14243039608").convertTo[PhoneNumber] shouldBe PhoneNumber("1", "4243039608") + } + } + + "Path matcher for PhoneNumber" should { + "read valid phone number" in { + val string = "+14243039608x23" + val phone = PhoneNumber("1", "4243039608", Some("23")) + + val matcher = PathMatcher("foo") / PhoneInPath + + matcher(Uri.Path("foo") / string / "bar") shouldBe Matched(Uri.Path./("bar"), Tuple1(phone)) + } + } + + "Json format for ADT mappings" should { + "read and write correct JSON" in { + + sealed trait EnumVal + case object Val1 extends EnumVal + case object Val2 extends EnumVal + case object Val3 extends EnumVal + + val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3) + + val referenceEnumValue1 = Val2 + val referenceEnumValue2 = Val3 + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1.prettyPrint should be("\"b\"") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2.prettyPrint should be("\"c\"") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + } + } + + "Json format for Enums (external)" should { + "read and write correct JSON" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val format = new enumeratum.EnumJsonFormat(MyEnum) + + val referenceEnumValue1 = MyEnum.`Val 2` + val referenceEnumValue2 = MyEnum.`Val/3` + + val writtenJson1 = format.write(referenceEnumValue1) + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = format.write(referenceEnumValue2) + writtenJson2 shouldBe JsString("Val/3") + + val parsedEnumValue1 = format.read(writtenJson1) + val parsedEnumValue2 = format.read(writtenJson2) + + parsedEnumValue1 shouldBe referenceEnumValue1 + parsedEnumValue2 shouldBe referenceEnumValue2 + + intercept[DeserializationException] { + format.read(JsString("Val4")) + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + } + + "Json format for Enums (automatic)" should { + "read and write correct JSON and not require import" in { + + sealed trait MyEnum extends EnumEntry + object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] { + case object Val1 extends MyEnum + case object `Val 2` extends MyEnum + case object `Val/3` extends MyEnum + + val values: IndexedSeq[MyEnum] = findValues + } + + val referenceEnumValue1: MyEnum = MyEnum.`Val 2` + val referenceEnumValue2: MyEnum = MyEnum.`Val/3` + + val writtenJson1 = referenceEnumValue1.toJson + writtenJson1 shouldBe JsString("Val 2") + + val writtenJson2 = referenceEnumValue2.toJson + writtenJson2 shouldBe JsString("Val/3") + + import spray.json._ + + val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum] + val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum] + + parsedEnumValue1 should be(referenceEnumValue1) + parsedEnumValue2 should be(referenceEnumValue2) + + intercept[DeserializationException] { + JsString("Val4").convertTo[MyEnum] + }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" + } + } + + // Should be defined outside of case to have a TypeTag + case class CustomWrapperClass(value: Int) + + "Json format for Value classes" should { + "read and write correct JSON" in { + + val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt)) + + val referenceValue1 = CustomWrapperClass(-2) + val referenceValue2 = CustomWrapperClass(10) + + val writtenJson1 = format.write(referenceValue1) + writtenJson1.prettyPrint should be("-2") + + val writtenJson2 = format.write(referenceValue2) + writtenJson2.prettyPrint should be("10") + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + } + + "Json format for classes GADT" should { + "read and write correct JSON" in { + + import CustomGADT._ + import DefaultJsonProtocol._ + implicit val case1Format = jsonFormat1(GadtCase1) + implicit val case2Format = jsonFormat1(GadtCase2) + implicit val case3Format = jsonFormat1(GadtCase3) + + val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") { + case _: CustomGADT.GadtCase1 => "case1" + case _: CustomGADT.GadtCase2 => "case2" + case _: CustomGADT.GadtCase3 => "case3" + } { + case "case1" => case1Format + case "case2" => case2Format + case "case3" => case3Format + } + + val referenceValue1 = CustomGADT.GadtCase1("4") + val referenceValue2 = CustomGADT.GadtCase2("Hi!") + + val writtenJson1 = format.write(referenceValue1) + writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson) + + val writtenJson2 = format.write(referenceValue2) + writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson) + + val parsedValue1 = format.read(writtenJson1) + val parsedValue2 = format.read(writtenJson2) + + parsedValue1 should be(referenceValue1) + parsedValue2 should be(referenceValue2) + } + } + + "Json format for a Refined value" should { + "read and write correct JSON" in { + + val jsonFormat = json.refinedJsonFormat[Int, Positive] + + val referenceRefinedNumber = refineMV[Positive](42) + + val writtenJson = jsonFormat.write(referenceRefinedNumber) + writtenJson should be("42".parseJson) + + val parsedRefinedNumber = jsonFormat.read(writtenJson) + parsedRefinedNumber should be(referenceRefinedNumber) + } + } + + "InetAddress format" should { + "read and write correct JSON" in { + val address = InetAddress.getByName("127.0.0.1") + val json = inetAddressFormat.write(address) + + json shouldBe JsString("127.0.0.1") + + val parsed = inetAddressFormat.read(json) + parsed shouldBe address + } + + "throw a DeserializationException for an invalid IP Address" in { + assertThrows[DeserializationException] { + val invalidAddress = JsString("foobar:") + inetAddressFormat.read(invalidAddress) + } + } + } + + "AuthCredentials format" should { + "read and write correct JSON" in { + val email = Email("someone", "noehere.com") + val phoneId = PhoneNumber.parse("1 207 8675309") + val password = "nopassword" + + phoneId.isDefined should be(true) // test this real quick + + val emailAuth = AuthCredentials(email.toString, password) + val pnAuth = AuthCredentials(phoneId.get.toString, password) + + val emailWritten = authCredentialsFormat.write(emailAuth) + emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson) + + val phoneWritten = authCredentialsFormat.write(pnAuth) + phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson) + + val identifierEmailParsed = + authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson) + var written = authCredentialsFormat.write(identifierEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + val emailEmailParsed = + authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson) + written = authCredentialsFormat.write(emailEmailParsed) + written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) + + } + } + + "CountryCode format" should { + "read and write correct JSON" in { + val samples = Seq( + "US" -> CountryCode.US, + "CN" -> CountryCode.CN, + "AT" -> CountryCode.AT + ) + + forAll(samples) { + case (serialized, enumValue) => + countryCodeFormat.write(enumValue) shouldBe JsString(serialized) + countryCodeFormat.read(JsString(serialized)) shouldBe enumValue + } + } + } + + "CurrencyCode format" should { + "read and write correct JSON" in { + val samples = Seq( + "USD" -> CurrencyCode.USD, + "CNY" -> CurrencyCode.CNY, + "EUR" -> CurrencyCode.EUR + ) + + forAll(samples) { + case (serialized, enumValue) => + currencyCodeFormat.write(enumValue) shouldBe JsString(serialized) + currencyCodeFormat.read(JsString(serialized)) shouldBe enumValue + } + } + } + +} diff --git a/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala new file mode 100644 index 0000000..bb25deb --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/TestTypes.scala @@ -0,0 +1,14 @@ +package xyz.driver.core + +object TestTypes { + + sealed trait CustomGADT { + val field: String + } + + object CustomGADT { + final case class GadtCase1(field: String) extends CustomGADT + final case class GadtCase2(field: String) extends CustomGADT + final case class GadtCase3(field: String) extends CustomGADT + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala new file mode 100644 index 0000000..324c8d8 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala @@ -0,0 +1,89 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.{HttpMethod, StatusCodes} +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.config.ConfigFactory +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.app.{DriverApp, SimpleModule} + +class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { + val config = ConfigFactory.parseString(""" + |application { + | cors { + | allowedOrigins: ["example.com"] + | } + |} + """.stripMargin).withFallback(ConfigFactory.load) + + val origin = Origin(HttpOrigin("https", Host("example.com"))) + val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) + val allowedMethods: collection.immutable.Seq[HttpMethod] = { + import akka.http.scaladsl.model.HttpMethods._ + collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE) + } + + import scala.reflect.runtime.universe.typeOf + class TestApp(testRoute: Route) + extends DriverApp( + appName = "test-app", + version = "0.0.1", + gitHash = "deadb33f", + modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), + config = config, + log = xyz.driver.core.logging.NoLogger + ) + + it should "respond with the correct CORS headers for the swagger OPTIONS route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods + } + } + + it should "respond with the correct CORS headers for the test route" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + } + } + + it should "respond with the correct CORS headers for a concatenated route" in { + val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) + Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) + } + } + + it should "allow subdomains of allowed origin suffixes" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) + } + } + + it should "respond with default domains for invalid origins" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") + .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) + } + } + + it should "respond with Pragma and Cache-Control (no-cache) headers" in { + val route = new TestApp(get(complete(StatusCodes.OK))) + Get(s"/api/v1/test") ~> route.appRoute ~> check { + status shouldBe StatusCodes.OK + header("Pragma").map(_.value()) should contain("no-cache") + header[`Cache-Control`].map(_.value()) should contain("no-cache") + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala new file mode 100644 index 0000000..cc0019a --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala @@ -0,0 +1,121 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.model.headers.Connection +import akka.http.scaladsl.server.Directives.{complete => akkaComplete} +import akka.http.scaladsl.server.{Directives, RejectionHandler, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import com.typesafe.scalalogging.Logger +import org.scalatest.{AsyncFlatSpec, Matchers} +import xyz.driver.core.json.serviceExceptionFormat +import xyz.driver.core.logging.NoLogger +import xyz.driver.core.rest.errors._ + +import scala.concurrent.Future + +class DriverRouteTest + extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives { + class TestRoute(override val route: Route) extends DriverRoute { + override def log: Logger = NoLogger + } + + "DriverRoute" should "respond with 200 OK for a basic route" in { + val route = new TestRoute(akkaComplete(StatusCodes.OK)) + + Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.OK + } + } + + it should "respond with a 401 for an InvalidInputException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe InvalidInputException() + } + } + + it should "respond with a 403 for InvalidActionException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.Forbidden + responseAs[ServiceException] shouldBe InvalidActionException() + } + } + + it should "respond with a 404 for ResourceNotFoundException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + responseAs[ServiceException] shouldBe ResourceNotFoundException() + } + } + + it should "respond with a 500 for ExternalServiceException" in { + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None) + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe error + } + } + + it should "allow pass-through of external service exceptions" in { + val innerError = InvalidInputException() + val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError)) + val future = Future.failed[String](error) + val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException)) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + responseAs[ServiceException] shouldBe innerError + } + } + + it should "respond with a 503 for ExternalServiceTimeoutException" in { + val error = ExternalServiceTimeoutException("GET /api/v1/users/") + val route = new TestRoute(akkaComplete(Future.failed[String](error))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.GatewayTimeout + responseAs[ServiceException] shouldBe error + } + } + + it should "respond with a 500 for DatabaseException" in { + val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException()))) + + Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { + handled shouldBe true + status shouldBe StatusCodes.InternalServerError + responseAs[ServiceException] shouldBe DatabaseException() + } + } + + it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in { + val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result() + val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK")))) + + Get("/foo") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.OK + headers should contain(Connection("close")) + } + + Get("/bar") ~> route.routeWithDefaults ~> check { + status shouldBe StatusCodes.NotFound + headers should contain(Connection("close")) + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala new file mode 100644 index 0000000..987717d --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala @@ -0,0 +1,101 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport +import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.headers.`Content-Type` +import akka.http.scaladsl.server.{Directives, Route} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import org.scalatest.{FlatSpec, Matchers} +import spray.json._ +import xyz.driver.core.{Id, Name} +import xyz.driver.core.json._ + +import scala.concurrent.Future + +class PatchDirectivesTest + extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol + with Directives with PatchDirectives { + case class Bar(name: Name[Bar], size: Int) + case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) + implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) + implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) + + val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) + + def route(retrieve: => Future[Option[Foo]]): Route = + Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => + entity(as[Patchable[Foo]]) { fooPatchable => + mergePatch(fooPatchable, retrieve) { updatedFoo => + complete(updatedFoo) + } + } + }) + + val MergePatchContentType = ContentType(`application/merge-patch+json`) + val ContentTypeHeader = `Content-Type`(MergePatchContentType) + def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = + HttpEntity(contentType, json) + + "PatchSupport" should "allow partial updates to an existing object" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4) + } + } + + it should "merge deeply nested objects" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 404 if the object is not found" in { + val fooRetrieve = Future.successful(None) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.NotFound + } + } + + it should "handle nulls on optional values correctly" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = None) + } + } + + it should "handle optional values correctly when old value is null" in { + val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None))) + + Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10))) + } + } + + it should "return a 400 for nulls on non-optional values" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { + handled shouldBe true + status shouldBe StatusCodes.BadRequest + } + } + + it should "return a 415 for incorrect Content-Type" in { + val fooRetrieve = Future.successful(Some(testFoo)) + + Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { + status shouldBe StatusCodes.UnsupportedMediaType + responseAs[String] should include("application/merge-patch+json") + } + } +} diff --git a/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala new file mode 100644 index 0000000..19e4ed1 --- /dev/null +++ b/core-rest/src/test/scala/xyz/driver/core/rest/RestTest.scala @@ -0,0 +1,151 @@ +package xyz.driver.core.rest + +import akka.http.scaladsl.model.StatusCodes +import akka.http.scaladsl.server.{Directives, Route, ValidationRejection} +import akka.http.scaladsl.testkit.ScalatestRouteTest +import akka.util.ByteString +import org.scalatest.{Matchers, WordSpec} +import xyz.driver.core.rest + +import scala.concurrent.Future +import scala.util.Random + +class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives { + "`escapeScriptTags` function" should { + "escape script tags properly" in { + val dirtyString = " + complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}") + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "provide a default pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,100") + } + } + "provide default values for a partial pagination" in { + Get("/?pageSize=2") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "1,2") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=-1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "optional paginated directive" should { + val route: Route = rest.optionalPagination { paginated => + complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination")) + } + "accept a pagination" in { + Get("/?pageNumber=2&pageSize=42") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "2,42") + } + } + "without pagination" in { + Get("/") ~> route ~> check { + assert(status == StatusCodes.OK) + assert(entityAs[String] == "no pagination") + } + } + "reject an invalid pagination" in { + Get("/?pageNumber=1") ~> route ~> check { + assert(rejection.isInstanceOf[ValidationRejection]) + } + } + } + + "completeWithPagination directive" when { + import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ + import spray.json.DefaultJsonProtocol._ + + val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString) + val route: Route = + parameter('empty.as[Boolean] ? false) { isEmpty => + completeWithPagination[String] { + case Some(pagination) if isEmpty => + Future.successful(ListResponse(Seq(), 0, Some(pagination))) + case Some(pagination) => + val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize) + Future.successful(ListResponse(filtered, data.size, Some(pagination))) + case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None)) + case None => Future.successful(ListResponse(data, data.size, None)) + } + } + + "pagination is defined" should { + "return a response with pagination headers" in { + Get("/?pageNumber=2&pageSize=10") ~> route ~> check { + responseAs[Seq[String]] shouldBe data.slice(10, 20) + header(ContextHeaders.ResourceCount).map(_.value) should contain("103") + header(ContextHeaders.PageCount).map(_.value) should contain("11") + } + } + + "disallow pageSize <= 0" in { + Get("/?pageNumber=2&pageSize=0") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + + Get("/?pageNumber=2&pageSize=-1") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + } + + "disallow pageNumber <= 0" in { + Get("/?pageNumber=0&pageSize=10") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + + Get("/?pageNumber=-1&pageSize=10") ~> route ~> check { + rejection shouldBe a[ValidationRejection] + } + } + + "return PageCount == 0 if returning an empty list" in { + Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check { + responseAs[Seq[String]] shouldBe empty + header(ContextHeaders.ResourceCount).map(_.value) should contain("0") + header(ContextHeaders.PageCount).map(_.value) should contain("0") + } + } + } + + "pagination is not defined" should { + "return a response with pagination headers and PageCount == 1" in { + Get("/") ~> route ~> check { + responseAs[Seq[String]] shouldBe data + header(ContextHeaders.ResourceCount).map(_.value) should contain("103") + header(ContextHeaders.PageCount).map(_.value) should contain("1") + } + } + + "return PageCount == 0 if returning an empty list" in { + Get("/?empty=true") ~> route ~> check { + responseAs[Seq[String]] shouldBe empty + header(ContextHeaders.ResourceCount).map(_.value) should contain("0") + header(ContextHeaders.PageCount).map(_.value) should contain("0") + } + } + } + } +} -- cgit v1.2.3