From 4d1197099ce4e721c18bf4cacbb2e1980e4210b5 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Wed, 12 Sep 2018 16:40:57 -0700 Subject: Move REST functionality to separate project --- src/main/scala/xyz/driver/core/auth.scala | 43 --- src/main/scala/xyz/driver/core/generators.scala | 143 -------- src/main/scala/xyz/driver/core/json.scala | 398 --------------------- .../scala/xyz/driver/core/rest/DnsDiscovery.scala | 11 - .../scala/xyz/driver/core/rest/DriverRoute.scala | 122 ------- .../core/rest/HttpRestServiceTransport.scala | 103 ------ .../xyz/driver/core/rest/PatchDirectives.scala | 104 ------ .../xyz/driver/core/rest/PooledHttpClient.scala | 67 ---- .../scala/xyz/driver/core/rest/ProxyRoute.scala | 26 -- .../scala/xyz/driver/core/rest/RestService.scala | 86 ----- .../xyz/driver/core/rest/ServiceDescriptor.scala | 16 - .../driver/core/rest/SingleRequestHttpClient.scala | 29 -- src/main/scala/xyz/driver/core/rest/Swagger.scala | 144 -------- .../core/rest/auth/AlwaysAllowAuthorization.scala | 14 - .../xyz/driver/core/rest/auth/AuthProvider.scala | 75 ---- .../xyz/driver/core/rest/auth/Authorization.scala | 11 - .../core/rest/auth/AuthorizationResult.scala | 22 -- .../core/rest/auth/CachedTokenAuthorization.scala | 55 --- .../core/rest/auth/ChainedAuthorization.scala | 27 -- .../core/rest/directives/AuthDirectives.scala | 19 - .../core/rest/directives/CorsDirectives.scala | 72 ---- .../driver/core/rest/directives/Directives.scala | 6 - .../driver/core/rest/directives/PathMatchers.scala | 85 ----- .../core/rest/directives/Unmarshallers.scala | 40 --- .../driver/core/rest/errors/serviceException.scala | 27 -- .../xyz/driver/core/rest/headers/Traceparent.scala | 33 -- src/main/scala/xyz/driver/core/rest/package.scala | 323 ----------------- .../xyz/driver/core/rest/serviceDiscovery.scala | 24 -- .../driver/core/rest/serviceRequestContext.scala | 87 ----- 29 files changed, 2212 deletions(-) delete mode 100644 src/main/scala/xyz/driver/core/auth.scala delete mode 100644 src/main/scala/xyz/driver/core/generators.scala delete mode 100644 src/main/scala/xyz/driver/core/json.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/DriverRoute.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/PatchDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ProxyRoute.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/RestService.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/Swagger.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/Authorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/Directives.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/errors/serviceException.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/package.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala delete mode 100644 src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala (limited to 'src/main/scala/xyz') diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala deleted file mode 100644 index 896bd89..0000000 --- a/src/main/scala/xyz/driver/core/auth.scala +++ /dev/null @@ -1,43 +0,0 @@ -package xyz.driver.core - -import xyz.driver.core.domain.Email -import xyz.driver.core.time.Time -import scalaz.Equal - -object auth { - - trait Permission - - final case class Role(id: Id[Role], name: Name[Role]) { - - def oneOf(roles: Role*): Boolean = roles.contains(this) - - def oneOf(roles: Set[Role]): Boolean = roles.contains(this) - } - - object Role { - implicit def idEqual: Equal[Role] = Equal.equal[Role](_ == _) - } - - trait User { - def id: Id[User] - } - - final case class AuthToken(value: String) - - final case class AuthTokenUserInfo( - id: Id[User], - email: Email, - emailVerified: Boolean, - audience: String, - roles: Set[Role], - expirationTime: Time) - extends User - - final case class RefreshToken(value: String) - final case class PermissionsToken(value: String) - - final case class PasswordHash(value: String) - - final case class AuthCredentials(identifier: String, password: String) -} diff --git a/src/main/scala/xyz/driver/core/generators.scala b/src/main/scala/xyz/driver/core/generators.scala deleted file mode 100644 index d00b6dd..0000000 --- a/src/main/scala/xyz/driver/core/generators.scala +++ /dev/null @@ -1,143 +0,0 @@ -package xyz.driver.core - -import enumeratum._ -import java.math.MathContext -import java.time.{Instant, LocalDate, ZoneOffset} -import java.util.UUID - -import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} -import xyz.driver.core.date.{Date, DayOfWeek} - -import scala.reflect.ClassTag -import scala.util.Random -import eu.timepit.refined.refineV -import eu.timepit.refined.api.Refined -import eu.timepit.refined.collection._ - -object generators { - - private val random = new Random - import random._ - private val secureRandom = new java.security.SecureRandom() - - private val DefaultMaxLength = 10 - private val StringLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ".toSet - private val NonAmbigiousCharacters = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789" - private val Numbers = "0123456789" - - private def nextTokenString(length: Int, chars: IndexedSeq[Char]): String = { - val builder = new StringBuilder - for (_ <- 0 until length) { - builder += chars(secureRandom.nextInt(chars.length)) - } - builder.result() - } - - /** Creates a random invitation token. - * - * This token is meant fo human input and avoids using ambiguous characters such as 'O' and '0'. It - * therefore contains less entropy and is not meant to be used as a cryptographic secret. */ - @deprecated( - "The term 'token' is too generic and security and readability conventions are not well defined. " + - "Services should implement their own version that suits their security requirements.", - "1.11.0" - ) - def nextToken(length: Int): String = nextTokenString(length, NonAmbigiousCharacters) - - @deprecated( - "The term 'token' is too generic and security and readability conventions are not well defined. " + - "Services should implement their own version that suits their security requirements.", - "1.11.0" - ) - def nextNumericToken(length: Int): String = nextTokenString(length, Numbers) - - def nextInt(maxValue: Int, minValue: Int = 0): Int = random.nextInt(maxValue - minValue) + minValue - - def nextBoolean(): Boolean = random.nextBoolean() - - def nextDouble(): Double = random.nextDouble() - - def nextId[T](): Id[T] = Id[T](nextUuid().toString) - - def nextId[T](maxLength: Int): Id[T] = Id[T](nextString(maxLength)) - - def nextNumericId[T](): Id[T] = Id[T](nextLong.abs.toString) - - def nextNumericId[T](maxValue: Int): Id[T] = Id[T](nextInt(maxValue).toString) - - def nextName[T](maxLength: Int = DefaultMaxLength): Name[T] = Name[T](nextString(maxLength)) - - def nextNonEmptyName[T](maxLength: Int = DefaultMaxLength): NonEmptyName[T] = - NonEmptyName[T](nextNonEmptyString(maxLength)) - - def nextUuid(): UUID = java.util.UUID.randomUUID - - def nextRevision[T](): Revision[T] = Revision[T](nextUuid().toString) - - def nextString(maxLength: Int = DefaultMaxLength): String = - (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString - - def nextNonEmptyString(maxLength: Int = DefaultMaxLength): String Refined NonEmpty = { - refineV[NonEmpty]( - (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString - ).right.get - } - - def nextOption[T](value: => T): Option[T] = if (nextBoolean()) Option(value) else None - - def nextPair[L, R](left: => L, right: => R): (L, R) = (left, right) - - def nextTriad[F, S, T](first: => F, second: => S, third: => T): (F, S, T) = (first, second, third) - - def nextInstant(): Instant = Instant.ofEpochMilli(math.abs(nextLong() % System.currentTimeMillis)) - - def nextTime(): Time = nextInstant() - - def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) - - def nextTimeRange(): TimeRange = { - val oneTime = nextTime() - val anotherTime = nextTime() - - TimeRange( - Time(scala.math.min(oneTime.millis, anotherTime.millis)), - Time(scala.math.max(oneTime.millis, anotherTime.millis))) - } - - def nextDate(): Date = nextTime().toDate(java.util.TimeZone.getTimeZone("UTC")) - - def nextLocalDate(): LocalDate = nextInstant().atZone(ZoneOffset.UTC).toLocalDate - - def nextDayOfWeek(): DayOfWeek = oneOf(DayOfWeek.All) - - def nextBigDecimal(multiplier: Double = 1000000.00, precision: Int = 2): BigDecimal = - BigDecimal(multiplier * nextDouble, new MathContext(precision)) - - def oneOf[T](items: T*): T = oneOf(items.toSet) - - def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) - - def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) - - def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = - Array.fill(nextInt(maxLength, minLength))(generator) - - def seqOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Seq[T] = - Seq.fill(nextInt(maxLength, minLength))(generator) - - def vectorOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Vector[T] = - Vector.fill(nextInt(maxLength, minLength))(generator) - - def listOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): List[T] = - List.fill(nextInt(maxLength, minLength))(generator) - - def setOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Set[T] = - seqOf(generator, maxLength, minLength).toSet - - def mapOf[K, V]( - keyGenerator: => K, - valueGenerator: => V, - maxLength: Int = DefaultMaxLength, - minLength: Int = 0): Map[K, V] = - seqOf(nextPair(keyGenerator, valueGenerator), maxLength, minLength).toMap -} diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala deleted file mode 100644 index edc2347..0000000 --- a/src/main/scala/xyz/driver/core/json.scala +++ /dev/null @@ -1,398 +0,0 @@ -package xyz.driver.core - -import java.net.InetAddress -import java.time.format.DateTimeFormatter -import java.time.{Instant, LocalDate} -import java.util.{TimeZone, UUID} - -import akka.http.scaladsl.unmarshalling.Unmarshaller -import com.neovisionaries.i18n.{CountryCode, CurrencyCode} -import enumeratum._ -import eu.timepit.refined.api.{Refined, Validate} -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.refineV -import spray.json._ -import xyz.driver.core.auth.AuthCredentials -import xyz.driver.core.date.{Date, DayOfWeek, Month} -import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.rest.directives.{PathMatchers, Unmarshallers} -import xyz.driver.core.rest.errors._ -import xyz.driver.core.time.{Time, TimeOfDay} - -import scala.reflect.runtime.universe._ -import scala.reflect.{ClassTag, classTag} -import scala.util.Try -import scala.util.control.NonFatal - -object json extends PathMatchers with Unmarshallers { - import DefaultJsonProtocol._ - - implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { - def write(id: Id[T]) = JsString(id.value) - - def read(value: JsValue): Id[T] = value match { - case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase) - case JsString(id) => Id[T](id) - case _ => throw DeserializationException("Id expects string") - } - } - - implicit def taggedFormat[F, T](implicit underlying: JsonFormat[F], convert: F => F @@ T = null): JsonFormat[F @@ T] = - new JsonFormat[F @@ T] { - import tagging._ - - private val transformReadValue = Option(convert).getOrElse((_: F).tagged[T]) - - override def write(obj: F @@ T): JsValue = underlying.write(obj) - - override def read(json: JsValue): F @@ T = transformReadValue(underlying.read(json)) - } - - implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { - def write(name: Name[T]) = JsString(name.value) - - def read(value: JsValue): Name[T] = value match { - case JsString(name) => Name[T](name) - case _ => throw DeserializationException("Name expects string") - } - } - - implicit val timeFormat: RootJsonFormat[Time] = new RootJsonFormat[Time] { - def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) - - def read(value: JsValue): Time = Time(instantFormat.read(value)) - } - - implicit val instantFormat: JsonFormat[Instant] = new JsonFormat[Instant] { - def write(instant: Instant): JsValue = JsString(instant.toString) - - def read(value: JsValue): Instant = value match { - case JsObject(fields) => - fields - .get("timestamp") - .flatMap { - case JsNumber(millis) => Some(Instant.ofEpochMilli(millis.longValue())) - case _ => None - } - .getOrElse(deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}")) - case JsNumber(millis) => Instant.ofEpochMilli(millis.longValue()) - case JsString(str) => - try Instant.parse(str) - catch { case NonFatal(_) => deserializationError(s"Instant expects ISO timestamp but got $str") } - case _ => deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}") - } - } - - implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { - private val formatter = TimeOfDay.getFormatter - def read(json: JsValue): java.time.LocalTime = json match { - case JsString(chars) => - java.time.LocalTime.parse(chars) - case _ => deserializationError(s"Expected time string got ${json.toString}") - } - - def write(obj: java.time.LocalTime): JsValue = { - JsString(obj.format(formatter)) - } - } - - implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { - override def write(obj: TimeZone): JsValue = { - JsString(obj.getID()) - } - - override def read(json: JsValue): TimeZone = json match { - case JsString(chars) => - java.util.TimeZone.getTimeZone(chars) - case _ => deserializationError(s"Expected time zone string got ${json.toString}") - } - } - - implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) - - implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) - - implicit val dateFormat = new RootJsonFormat[Date] { - def write(date: Date) = JsString(date.toString) - def read(value: JsValue): Date = value match { - case JsString(dateString) => - Date - .fromString(dateString) - .getOrElse( - throw DeserializationException(s"Misformated ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.")) - case _ => throw DeserializationException(s"Date expects a string, but got $value.") - } - } - - implicit val localDateFormat = new RootJsonFormat[LocalDate] { - val format = DateTimeFormatter.ISO_LOCAL_DATE - - def write(date: LocalDate): JsValue = JsString(date.format(format)) - def read(value: JsValue): LocalDate = value match { - case JsString(dateString) => - try LocalDate.parse(dateString, format) - catch { - case NonFatal(_) => - throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.") - } - case _ => - throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got ${value.compactPrint}.") - } - } - - implicit val monthFormat = new RootJsonFormat[Month] { - def write(month: Month) = JsNumber(month) - def read(value: JsValue): Month = value match { - case JsNumber(month) if 0 <= month && month <= 11 => Month(month.toInt) - case _ => throw DeserializationException("Expected a number from 0 to 11") - } - } - - implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { - def write(revision: Revision[T]) = JsString(revision.id.toString) - - def read(value: JsValue): Revision[T] = value match { - case JsString(revision) => Revision[T](revision) - case _ => throw DeserializationException("Revision expects uuid string") - } - } - - implicit val base64Format = new RootJsonFormat[Base64] { - def write(base64Value: Base64) = JsString(base64Value.value) - - def read(value: JsValue): Base64 = value match { - case JsString(base64Value) => Base64(base64Value) - case _ => throw DeserializationException("Base64 format expects string") - } - } - - implicit val emailFormat = new RootJsonFormat[Email] { - def write(email: Email) = JsString(email.username + "@" + email.domain) - def read(json: JsValue): Email = json match { - - case JsString(value) => - Email.parse(value).getOrElse { - deserializationError("Expected '@' symbol in email string as Email, but got " + json.toString) - } - - case _ => - deserializationError("Expected string as Email, but got " + json.toString) - } - } - - implicit object phoneNumberFormat extends RootJsonFormat[PhoneNumber] { - - private val basicFormat = jsonFormat3(PhoneNumber.apply) - - def write(obj: PhoneNumber): JsValue = basicFormat.write(obj) - - def read(json: JsValue): PhoneNumber = { - val maybePhone = json match { - case JsString(number) => PhoneNumber.parse(number) - case obj: JsObject => PhoneNumber.parse(basicFormat.read(obj).toString) - case _ => None - } - maybePhone.getOrElse(deserializationError("Invalid phone number")) - } - } - - implicit val authCredentialsFormat = new RootJsonFormat[AuthCredentials] { - override def read(json: JsValue): AuthCredentials = { - json match { - case JsObject(fields) => - val emailField = fields.get("email") - val identifierField = fields.get("identifier") - val passwordField = fields.get("password") - - (emailField, identifierField, passwordField) match { - case (_, _, None) => - deserializationError("password field must be set") - case (Some(JsString(em)), _, Some(JsString(pw))) => - val email = Email.parse(em).getOrElse(throw deserializationError(s"failed to parse email $em")) - AuthCredentials(email.toString, pw) - case (_, Some(JsString(id)), Some(JsString(pw))) => AuthCredentials(id.toString, pw.toString) - case (None, None, _) => deserializationError("identifier must be provided") - case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") - } - case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") - } - } - - override def write(obj: AuthCredentials): JsValue = JsObject( - "identifier" -> JsString(obj.identifier), - "password" -> JsString(obj.password) - ) - } - - implicit object inetAddressFormat extends JsonFormat[InetAddress] { - override def read(json: JsValue): InetAddress = json match { - case JsString(ipString) => - Try(InetAddress.getByName(ipString)) - .getOrElse(deserializationError(s"Invalid IP Address: $ipString")) - case _ => deserializationError(s"Expected string for IP Address, got $json") - } - - override def write(obj: InetAddress): JsValue = - JsString(obj.getHostAddress) - } - - implicit val countryCodeFormat: JsonFormat[CountryCode] = javaEnumFormat[CountryCode] - - implicit val currencyCodeFormat: JsonFormat[CurrencyCode] = javaEnumFormat[CurrencyCode] - - object enumeratum { - - def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = - Unmarshaller.strict { value => - enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) - } - - trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => - - implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) - - implicit val unmarshaller: Unmarshaller[String, T] = - Unmarshaller.strict { value => - enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) - } - } - - class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { - override def read(json: JsValue): T = json match { - case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) - case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) - } - - override def write(obj: T): JsValue = JsString(obj.entryName) - } - - private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = - deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") - } - - class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { - private val map = mapping.toMap - - override def write(value: T): JsValue = { - map.find(_._2 == value).map(_._1) match { - case Some(name) => JsString(name) - case _ => serializationError(s"Value $value is not found in the mapping $map") - } - } - - override def read(json: JsValue): T = json match { - case JsString(name) => - map.getOrElse(name, throw DeserializationException(s"Value $name is not found in the mapping $map")) - case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) - } - } - - def javaEnumFormat[T <: java.lang.Enum[_]: ClassTag]: JsonFormat[T] = { - val values = classTag[T].runtimeClass.asInstanceOf[Class[T]].getEnumConstants - new EnumJsonFormat[T](values.map(v => v.name() -> v): _*) - } - - class ValueClassFormat[T: TypeTag](writeValue: T => BigDecimal, create: BigDecimal => T) extends JsonFormat[T] { - def write(valueClass: T) = JsNumber(writeValue(valueClass)) - def read(json: JsValue): T = json match { - case JsNumber(value) => create(value) - case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json.toString) - } - } - - class GadtJsonFormat[T: TypeTag]( - typeField: String, - typeValue: PartialFunction[T, String], - jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) - extends RootJsonFormat[T] { - - def write(value: T): JsValue = { - - val valueType = typeValue.applyOrElse(value, { v: T => - deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v.toString) - }) - - val valueFormat = - jsonFormat.applyOrElse(valueType, { f: String => - deserializationError(s"No Json format for this type of $valueType") - }) - - valueFormat.asInstanceOf[JsonFormat[T]].write(value) match { - case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType))) - case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object") - } - } - - def read(json: JsValue): T = json match { - case JsObject(fields) => - val valueJson = JsObject(fields.filterNot(_._1 == typeField)) - fields(typeField) match { - case JsString(valueType) => - val valueFormat = jsonFormat.applyOrElse(valueType, { t: String => - deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") - }) - valueFormat.read(valueJson) - case _ => - deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") - } - case _ => - deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json.toString) - } - } - - object GadtJsonFormat { - - def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])( - jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = { - - new GadtJsonFormat[T](typeField, typeValue, jsonFormat) - } - } - - /** - * Provides the JsonFormat for the Refined types provided by the Refined library. - * - * @see https://github.com/fthomas/refined - */ - implicit def refinedJsonFormat[T, Predicate]( - implicit valueFormat: JsonFormat[T], - validate: Validate[T, Predicate]): JsonFormat[Refined[T, Predicate]] = - new JsonFormat[Refined[T, Predicate]] { - def write(x: T Refined Predicate): JsValue = valueFormat.write(x.value) - def read(value: JsValue): T Refined Predicate = { - refineV[Predicate](valueFormat.read(value))(validate) match { - case Right(refinedValue) => refinedValue - case Left(refinementError) => deserializationError(refinementError) - } - } - } - - implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = - new RootJsonFormat[NonEmptyName[T]] { - def write(name: NonEmptyName[T]) = JsString(name.value.value) - - def read(value: JsValue): NonEmptyName[T] = - NonEmptyName[T](nonEmptyStringFormat.read(value)) - } - - implicit val serviceExceptionFormat: RootJsonFormat[ServiceException] = - GadtJsonFormat.create[ServiceException]("type") { - case _: InvalidInputException => "InvalidInputException" - case _: InvalidActionException => "InvalidActionException" - case _: UnauthorizedException => "UnauthorizedException" - case _: ResourceNotFoundException => "ResourceNotFoundException" - case _: ExternalServiceException => "ExternalServiceException" - case _: ExternalServiceTimeoutException => "ExternalServiceTimeoutException" - case _: DatabaseException => "DatabaseException" - } { - case "InvalidInputException" => jsonFormat(InvalidInputException, "message") - case "InvalidActionException" => jsonFormat(InvalidActionException, "message") - case "UnauthorizedException" => jsonFormat(UnauthorizedException, "message") - case "ResourceNotFoundException" => jsonFormat(ResourceNotFoundException, "message") - case "ExternalServiceException" => - jsonFormat(ExternalServiceException, "serviceName", "serviceMessage", "serviceException") - case "ExternalServiceTimeoutException" => jsonFormat(ExternalServiceTimeoutException, "message") - case "DatabaseException" => jsonFormat(DatabaseException, "message") - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala b/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala deleted file mode 100644 index 87946e4..0000000 --- a/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala +++ /dev/null @@ -1,11 +0,0 @@ -package xyz.driver.core -package rest - -class DnsDiscovery(transport: HttpRestServiceTransport, overrides: Map[String, String]) { - - def discover[A](implicit descriptor: ServiceDescriptor[A]): A = { - val url = overrides.getOrElse(descriptor.name, s"https://${descriptor.name}") - descriptor.connect(transport, url) - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala deleted file mode 100644 index 911e306..0000000 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ /dev/null @@ -1,122 +0,0 @@ -package xyz.driver.core.rest - -import java.sql.SQLException - -import akka.http.scaladsl.model.headers.CacheDirectives.`no-cache` -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{StatusCodes, _} -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import com.typesafe.scalalogging.Logger -import org.slf4j.MDC -import xyz.driver.core.rest -import xyz.driver.core.rest.errors._ - -import scala.compat.Platform.ConcurrentModificationException - -trait DriverRoute { - def log: Logger - - def route: Route - - def routeWithDefaults: Route = { - (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { - route - } - } - - protected def defaultResponseHeaders: Directive0 = { - extractRequest flatMap { request => - // Needs to happen before any request processing, so all the log messages - // associated with processing of this request are having this `trackingId` - val trackingId = rest.extractTrackingId(request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - MDC.put("trackingId", trackingId) - - respondWithHeaders(tracingHeader +: DriverRoute.DefaultHeaders: _*) - } - } - - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler: PartialFunction[Throwable, Route] = { - case serviceException: ServiceException => - serviceExceptionHandler(serviceException) - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( - ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) - - case t: Exception => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) - } - - protected def serviceExceptionHandler(serviceException: ServiceException): Route = { - val statusCode = serviceException match { - case e: InvalidInputException => - log.info("Invalid client input error", e) - StatusCodes.BadRequest - case e: InvalidActionException => - log.info("Invalid client action error", e) - StatusCodes.Forbidden - case e: UnauthorizedException => - log.info("Unauthorized user error", e) - StatusCodes.Unauthorized - case e: ResourceNotFoundException => - log.info("Resource not found error", e) - StatusCodes.NotFound - case e: ExternalServiceException => - log.error("Error while calling another service", e) - StatusCodes.InternalServerError - case e: ExternalServiceTimeoutException => - log.error("Service timeout error", e) - StatusCodes.GatewayTimeout - case e: DatabaseException => - log.error("Database error", e) - StatusCodes.InternalServerError - } - - { (ctx: RequestContext) => - import xyz.driver.core.json.serviceExceptionFormat - val entity = - HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) - errorResponse(statusCode, entity, serviceException)(ctx) - } - } - - protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = - errorResponse(statusCode, HttpEntity(message), exception) - - protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { - complete(HttpResponse(statusCode, entity = entity)) - } - -} - -object DriverRoute { - val DefaultHeaders: List[HttpHeader] = List( - // This header will eliminate the risk of envoy trying to reuse a connection - // that already timed out on the server side by completely rejecting keep-alive - Connection("close"), - // These 2 headers are the simplest way to prevent IE from caching GET requests - RawHeader("Pragma", "no-cache"), - `Cache-Control`(List(`no-cache`(Nil))) - ) -} diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala deleted file mode 100644 index e31635b..0000000 --- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala +++ /dev/null @@ -1,103 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.RawHeader -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.TcpIdleTimeoutException -import org.slf4j.MDC -import xyz.driver.core.Name -import xyz.driver.core.reporting.Reporter -import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -class HttpRestServiceTransport( - applicationName: Name[App], - applicationVersion: String, - val actorSystem: ActorSystem, - val executionContext: ExecutionContext, - reporter: Reporter) - extends ServiceTransport { - - protected implicit val execution: ExecutionContext = executionContext - - protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { - val tags = Map( - // open tracing semantic tags - "span.kind" -> "client", - "service" -> applicationName.value, - "http.url" -> requestStub.uri.toString, - "http.method" -> requestStub.method.value, - "peer.hostname" -> requestStub.uri.authority.host.toString, - // google's tracing console provides extra search features if we define these tags - "/http/path" -> requestStub.uri.path.toString, - "/http/method" -> requestStub.method.value.toString, - "/http/url" -> requestStub.uri.toString - ) - reporter.traceAsync(s"http_call_rpc", tags) { implicit span => - val requestTime = System.currentTimeMillis() - - val request = requestStub - .withHeaders(context.contextHeaders.toSeq.map { - case (ContextHeaders.TrackingIdHeader, _) => - RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) - case (ContextHeaders.StacktraceHeader, _) => - RawHeader( - ContextHeaders.StacktraceHeader, - Option(MDC.get("stack")) - .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) - .getOrElse("")) - case (header, headerValue) => RawHeader(header, headerValue) - }: _*) - - reporter.debug(s"Sending request to ${request.method} ${request.uri}") - - val response = httpClient.makeRequest(request) - - response.onComplete { - case Success(r) => - val responseLatency = System.currentTimeMillis() - requestTime - reporter.debug( - s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") - - case Failure(t: Throwable) => - val responseLatency = System.currentTimeMillis() - requestTime - reporter.warn( - s"Failed to receive response from ${request.method.value} ${request.uri} in $responseLatency ms", - t) - }(executionContext) - - response.recoverWith { - case _: TcpIdleTimeoutException => - val serviceCalled = s"${requestStub.method.value} ${requestStub.uri}" - Future.failed(ExternalServiceTimeoutException(serviceCalled)) - case t: Throwable => Future.failed(t) - } - }(context.spanContext) - } - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { - - sendRequestGetResponse(context)(requestStub) flatMap { response => - if (response.status == StatusCodes.NotFound) { - Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) - } else if (response.status.isFailure()) { - val serviceCalled = s"${requestStub.method} ${requestStub.uri}" - Unmarshal(response.entity).to[String] flatMap { errorString => - import spray.json._ - import xyz.driver.core.json._ - val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption - Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) - } - } else { - Future.successful(Unmarshal(response.entity)) - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala deleted file mode 100644 index f33bf9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala +++ /dev/null @@ -1,104 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.javadsl.server.Rejections -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} -import spray.json._ - -import scala.concurrent.Future -import scala.util.{Failure, Success, Try} - -trait PatchDirectives extends Directives with SprayJsonSupport { - - /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ - val `application/merge-patch+json`: MediaType.WithFixedCharset = - MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) - - /** Wraps a JSON value that represents a patch. - * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ - case class PatchValue(value: JsValue) { - - /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ - def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) - } - - /** Witness that the given patch may be applied to an original domain value. - * @tparam A type of the domain value - * @param patch the patch that may be applied to a domain value - * @param format a JSON format that enables serialization and deserialization of a domain value */ - case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { - - /** Applies the patch to a given domain object. The result will be a combination - * of the original value, updates with the fields specified in this witness' patch. */ - def applyTo(original: A): A = { - val serialized = format.write(original) - val merged = patch.applyTo(serialized) - val deserialized = format.read(merged) - deserialized - } - } - - implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = - Unmarshaller.byteStringUnmarshaller - .andThen(sprayJsValueByteStringUnmarshaller) - .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) - .map(js => PatchValue(js)) - - implicit def patchableUnmarshaller[A]( - implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], - format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { - patchUnmarshaller.map(patch => Patchable[A](patch, format)) - } - - protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { - JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => - val oldValue = oldObj.fields.getOrElse(key, JsNull) - val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) - key -> newValue - })(collection.breakOut): _*) - } - - protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { - def mergeError(typ: String): Nothing = - deserializationError(s"Expected $typ value, got $newValue") - - if (maxLevels.exists(_ < 0)) oldValue - else { - (oldValue, newValue) match { - case (_: JsString, newString @ (JsString(_) | JsNull)) => newString - case (_: JsString, _) => mergeError("string") - case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber - case (_: JsNumber, _) => mergeError("number") - case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool - case (_: JsBoolean, _) => mergeError("boolean") - case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr - case (_: JsArray, _) => mergeError("array") - case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) - case (_: JsObject, JsNull) => JsNull - case (_: JsObject, _) => mergeError("object") - case (JsNull, _) => newValue - } - } - } - - def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = - Directive { inner => requestCtx => - onSuccess(retrieve)({ - case Some(oldT) => - Try(patchable.applyTo(oldT)) - .transform[Route]( - mergedT => scala.util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => reject() - })(requestCtx) - } -} - -object PatchDirectives extends PatchDirectives diff --git a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala deleted file mode 100644 index 2854257..0000000 --- a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala +++ /dev/null @@ -1,67 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.model.headers.`User-Agent` -import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.stream.scaladsl.{Keep, Sink, Source} -import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} -import xyz.driver.core.Name - -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.concurrent.duration._ -import scala.util.{Failure, Success} - -class PooledHttpClient( - baseUri: Uri, - applicationName: Name[App], - applicationVersion: String, - requestRateLimit: Int = 64, - requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) - extends HttpClient { - - private val host = baseUri.authority.host.toString() - private val port = baseUri.effectivePort - private val scheme = baseUri.scheme - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - private val pool = if (scheme.equalsIgnoreCase("https")) { - Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } else { - Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } - - private val queue = Source - .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) - .via(pool) - .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) - .toMat(Sink.foreach({ - case ((Success(resp), p)) => p.success(resp) - case ((Failure(e), p)) => p.failure(e) - }))(Keep.left) - .run - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - val responsePromise = Promise[HttpResponse]() - - queue.offer(request -> responsePromise).flatMap { - case QueueOfferResult.Enqueued => - responsePromise.future - case QueueOfferResult.Dropped => - Future.failed(new Exception(s"Request queue to the host $host is overflown")) - case QueueOfferResult.Failure(ex) => - Future.failed(ex) - case QueueOfferResult.QueueClosed => - Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala deleted file mode 100644 index c0e9f99..0000000 --- a/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala +++ /dev/null @@ -1,26 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} -import com.typesafe.config.Config -import xyz.driver.core.Name - -import scala.concurrent.ExecutionContext - -trait ProxyRoute extends DriverRoute { - implicit val executionContext: ExecutionContext - val config: Config - val httpClient: HttpClient - - protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => - val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") - val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") - - val originalUri = ctx.request.uri - val originalRequest = ctx.request - - val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) - val newRequest = originalRequest.withUri(newUri) - - httpClient.makeRequest(newRequest).map(RouteResult.Complete) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/RestService.scala b/src/main/scala/xyz/driver/core/rest/RestService.scala deleted file mode 100644 index 09d98b8..0000000 --- a/src/main/scala/xyz/driver/core/rest/RestService.scala +++ /dev/null @@ -1,86 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model._ -import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} -import akka.stream.Materializer - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.{ListT, OptionT} - -trait RestService extends Service { - - import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ - import spray.json._ - - protected implicit val exec: ExecutionContext - protected implicit val materializer: Materializer - - implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { - def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = - if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] - } - - protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = - OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) - - protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = - OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) - - protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = - ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) - - protected def jsonEntity(json: JsValue): RequestEntity = - HttpEntity(ContentTypes.`application/json`, json.compactPrint) - - protected def mergePatchJsonEntity(json: JsValue): RequestEntity = - HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) - - protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) - - protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) - - protected def postJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) - - protected def putJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) - - protected def patchJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) - - protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) - - protected def endpointUri(baseUri: Uri, path: String): Uri = - baseUri.withPath(Uri.Path(path)) - - protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = - baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) - - protected def responseToListResponse[T: JsonFormat](pagination: Option[Pagination])( - response: HttpResponse): Future[ListResponse[T]] = { - import DefaultJsonProtocol._ - val resourceCount = response.headers - .find(_.name() equalsIgnoreCase ContextHeaders.ResourceCount) - .map(_.value().toInt) - .getOrElse(0) - val meta = ListResponse.Meta(resourceCount, pagination.getOrElse(Pagination(resourceCount max 1, 1))) - Unmarshal(response.entity).to[List[T]].map(ListResponse(_, meta)) - } - - protected def responseToListResponse[T: JsonFormat](pagination: Pagination)( - response: HttpResponse): Future[ListResponse[T]] = responseToListResponse(Some(pagination))(response) -} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala b/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala deleted file mode 100644 index 646fae8..0000000 --- a/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala +++ /dev/null @@ -1,16 +0,0 @@ -package xyz.driver.core -package rest -import scala.annotation.implicitNotFound - -@implicitNotFound( - "Don't know how to communicate with service ${S}. Make sure an implicit ServiceDescriptor is" + - "available. A good place to put one is in the service's companion object.") -trait ServiceDescriptor[S] { - - /** The service's name. Must be unique among all services. */ - def name: String - - /** Get an instance of the service. */ - def connect(transport: HttpRestServiceTransport, url: String): S - -} diff --git a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala deleted file mode 100644 index 964a5a2..0000000 --- a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala +++ /dev/null @@ -1,29 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.model.headers.`User-Agent` -import akka.http.scaladsl.model.{HttpRequest, HttpResponse} -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.stream.ActorMaterializer -import xyz.driver.core.Name - -import scala.concurrent.Future - -class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) - extends HttpClient { - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - private val client = Http()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - client.singleRequest(request, settings = connectionPoolSettings) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/Swagger.scala b/src/main/scala/xyz/driver/core/rest/Swagger.scala deleted file mode 100644 index 5ceac54..0000000 --- a/src/main/scala/xyz/driver/core/rest/Swagger.scala +++ /dev/null @@ -1,144 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} -import akka.http.scaladsl.server.Route -import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile -import akka.stream.ActorAttributes -import akka.stream.scaladsl.{Framing, StreamConverters} -import akka.util.ByteString -import com.github.swagger.akka.SwaggerHttpService -import com.github.swagger.akka.model._ -import com.typesafe.config.Config -import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme -import io.swagger.models.auth.{ApiKeyAuthDefinition, In} -import io.swagger.util.Json - -import scala.util.control.NonFatal - -class Swagger( - override val host: String, - accessSchemes: List[String], - version: String, - override val apiClasses: Set[Class[_]], - val config: Config, - val logger: Logger) - extends SwaggerHttpService { - - override val schemes = accessSchemes.map { s => - Scheme.forValue(s) - } - - // Note that the reason for overriding this is a subtle chain of causality: - // - // 1. Some of our endpoints require a single trailing slash and will not - // function if it is omitted - // 2. Swagger omits trailing slashes in its generated api doc - // 3. To work around that, a space is added after the trailing slash in the - // swagger Path annotations - // 4. This space is removed manually in the code below - // - // TODO: Ideally we'd like to drop this custom override and fix the issue in - // 1, by dropping the slash requirement and accepting api endpoints with and - // without trailing slashes. This will require inspecting and potentially - // fixing all service endpoints. - override def generateSwaggerJson: String = { - import io.swagger.models.{Swagger => JSwagger} - - import scala.collection.JavaConverters._ - try { - val swagger: JSwagger = reader.read(apiClasses.asJava) - - val paths = if (swagger.getPaths == null) { - Map.empty - } else { - swagger.getPaths.asScala - } - - // Removing trailing spaces - val fixedPaths = paths.map { - case (key, path) => - key.trim -> path - } - - swagger.setPaths(fixedPaths.asJava) - - Json.pretty().writeValueAsString(swagger) - } catch { - case NonFatal(t) => - logger.error("Issue with creating swagger.json", t) - throw t - } - } - - override val securitySchemeDefinitions = Map( - "token" -> { - val definition = new ApiKeyAuthDefinition("Authorization", In.HEADER) - definition.setDescription("Authentication token") - definition - } - ) - - override val basePath: String = config.getString("swagger.basePath") - override val apiDocsPath: String = config.getString("swagger.docsPath") - - override val info = Info( - config.getString("swagger.apiInfo.description"), - version, - config.getString("swagger.apiInfo.title"), - config.getString("swagger.apiInfo.termsOfServiceUrl"), - contact = Some( - Contact( - config.getString("swagger.apiInfo.contact.name"), - config.getString("swagger.apiInfo.contact.url"), - config.getString("swagger.apiInfo.contact.email") - )), - license = Some( - License( - config.getString("swagger.apiInfo.license"), - config.getString("swagger.apiInfo.licenseUrl") - )), - vendorExtensions = Map.empty[String, AnyRef] - ) - - /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ - private def getTemplatedResource( - resourceName: String, - contentType: ContentType, - substitution: (String, String)): Route = get { - Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { - case Some(ResourceFile(url, length @ _, _)) => - extractSettings { settings => - val stream = StreamConverters - .fromInputStream(() => url.openStream()) - .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) - .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) - .map { line => - line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) - } - .map(line => ByteString(line + "\n")) - complete( - HttpEntity(contentType, stream) - ) - } - case None => reject - } - } - - def swaggerUI: Route = - pathEndOrSingleSlash { - getTemplatedResource( - "swagger-ui/index.html", - ContentTypes.`text/html(UTF-8)`, - "title" -> config.getString("swagger.apiInfo.title")) - } ~ getFromResourceDirectory("swagger-ui") - - def swaggerUINew: Route = - pathEndOrSingleSlash { - getTemplatedResource( - "swagger-ui-dist/index.html", - ContentTypes.`text/html(UTF-8)`, - "title" -> config.getString("swagger.apiInfo.title")) - } ~ getFromResourceDirectory("swagger-ui-dist") - -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala deleted file mode 100644 index 5007774..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala +++ /dev/null @@ -1,14 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future - -class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - val permissionsMap = permissions.map(_ -> true).toMap - Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala deleted file mode 100644 index e1a94e1..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala +++ /dev/null @@ -1,75 +0,0 @@ -package xyz.driver.core.rest.auth - -import akka.http.scaladsl.server.directives.Credentials -import com.typesafe.scalalogging.Logger -import scalaz.OptionT -import xyz.driver.core.auth.{AuthToken, Permission, User} -import xyz.driver.core.rest.errors.{ExternalServiceException, UnauthorizedException} -import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ContextHeaders, ServiceRequestContext, serviceContext} - -import scala.concurrent.{ExecutionContext, Future} - -abstract class AuthProvider[U <: User]( - val authorization: Authorization[U], - log: Logger, - val realm: String -)(implicit execution: ExecutionContext) { - - import akka.http.scaladsl.server._ - import Directives.{authorize => akkaAuthorize, _} - - def this(authorization: Authorization[U], log: Logger)(implicit executionContext: ExecutionContext) = - this(authorization, log, "driver.xyz") - - /** - * Specific implementation on how to extract user from request context, - * can either need to do a network call to auth server or extract everything from self-contained token - * - * @param ctx set of request values which can be relevant to authenticate user - * @return authenticated user - */ - def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] - - protected def authenticator(context: ServiceRequestContext): AsyncAuthenticator[U] = { - case Credentials.Missing => - log.info(s"Request (${context.trackingId}) missing authentication credentials") - Future.successful(None) - case Credentials.Provided(authToken) => - authenticatedUser(context.withAuthToken(AuthToken(authToken))).run.recover({ - case ExternalServiceException(_, _, Some(UnauthorizedException(_))) => None - }) - } - - /** - * Verifies that a user agent is properly authenticated, and (optionally) authorized with the specified permissions - */ - def authorize( - context: ServiceRequestContext, - permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - authenticateOAuth2Async[U](realm, authenticator(context)) flatMap { authenticatedUser => - val authCtx = context.withAuthenticatedUser(context.authToken.get, authenticatedUser) - onSuccess(authorization.userHasPermissions(authenticatedUser, permissions)(authCtx)) flatMap { - case AuthorizationResult(authorized, token) => - val allAuthorized = permissions.forall(authorized.getOrElse(_, false)) - akkaAuthorize(allAuthorized) tflatMap { _ => - val cachedPermissionsCtx = token.fold(authCtx)(authCtx.withPermissionsToken) - provide(cachedPermissionsCtx) - } - } - } - } - - /** - * Verifies if request is authenticated and authorized to have `permissions` - */ - def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - serviceContext flatMap (authorize(_, permissions: _*)) - } -} - -object AuthProvider { - val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader - val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader - val SetAuthenticationTokenHeader: String = "set-authorization" - val SetPermissionsTokenHeader: String = "set-permissions" -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala deleted file mode 100644 index 1a5e9be..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala +++ /dev/null @@ -1,11 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future - -trait Authorization[U <: User] { - def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala deleted file mode 100644 index efe28c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala +++ /dev/null @@ -1,22 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, PermissionsToken} - -import scalaz.Scalaz.mapMonoid -import scalaz.Semigroup -import scalaz.syntax.semigroup._ - -final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) -object AuthorizationResult { - val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) - - implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { - private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) - private implicit val permissionsTokenSemigroup = - Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) - - override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { - AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala deleted file mode 100644 index 66de4ef..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala +++ /dev/null @@ -1,55 +0,0 @@ -package xyz.driver.core.rest.auth - -import java.nio.file.{Files, Path} -import java.security.{KeyFactory, PublicKey} -import java.security.spec.X509EncodedKeySpec - -import pdi.jwt.{Jwt, JwtAlgorithm} -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future -import scalaz.syntax.std.boolean._ - -class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - import spray.json._ - - def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = - tokenObject.fields.get("permissions").collect { - case JsObject(fields) => - fields.collect { - case (key, JsBoolean(value)) => key -> value - } - } - - val result = for { - token <- ctx.permissionsToken - jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption - jwtJson = jwt.parseJson.asJsObject - - // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None - _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) - _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) - - permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) - - authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap - } yield AuthorizationResult(authorized, Some(token)) - - Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) - } -} - -object CachedTokenAuthorization { - def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { - lazy val publicKey: PublicKey = { - val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim - val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) - val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) - KeyFactory.getInstance("RSA").generatePublic(spec) - } - new CachedTokenAuthorization[U](publicKey, issuer) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala deleted file mode 100644 index 131e7fc..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.Scalaz.{futureInstance, listInstance} -import scalaz.syntax.semigroup._ -import scalaz.syntax.traverse._ - -class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) - extends Authorization[U] { - - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = - permissions.forall(permissionsMap.getOrElse(_, false)) - - authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { - (authResult, authorization) => - if (allAuthorized(authResult.authorized)) Future.successful(authResult) - else { - authorization.userHasPermissions(user, permissions).map(authResult |+| _) - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala deleted file mode 100644 index ff3424d..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala +++ /dev/null @@ -1,19 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.auth.AuthProvider - -/** Authentication and authorization directives. */ -trait AuthDirectives extends AkkaDirectives { - - /** Authenticate a user based on service request headers and check if they have all given permissions. */ - def authenticateAndAuthorize[U <: User]( - authProvider: AuthProvider[U], - permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - authProvider.authorize(permissions: _*) - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala deleted file mode 100644 index 5a6bbfd..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala +++ /dev/null @@ -1,72 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import akka.http.scaladsl.model.HttpMethods._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{HttpResponse, StatusCodes} -import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} - -/** Directives to handle Cross-Origin Resource Sharing (CORS). */ -trait CorsDirectives extends AkkaDirectives { - - /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request - * origin. - * - * In a microservice environment, it can be difficult to know in advance the exact origin - * from which requests may be issued [1]. For example, the request may come from a web page served from - * any of the services, on any namespace or from other documentation sites. In general, only a set - * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that - * implement CORS require exact specification of allowed origins, including full host name and scheme, - * in order to send credentials and headers with requests to other origins. - * - * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by - * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. - * - * Note that the simplicity of this implementation comes with two notable drawbacks: - * - * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. - * - * - Allowed methods and headers can not be customized on a per-request basis. All standard - * HTTP methods are allowed, and allowed headers are specified for all inner routes. - * - * This handler is not suited for cases where more fine-grained control of responses is required. - * - * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through - * a common gateway. - * - * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed - * origins. - * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. - * @param inner Route into which CORS headers will be injected. - */ - def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { - optionalHeaderValueByType[Origin](()) { maybeOrigin => - val allowedOrigins: HttpOriginRange = maybeOrigin match { - // Note that this is not a security issue: the client will never send credentials if the allowed - // origin is set to *. This case allows us to deal with clients that do not send an origin header. - case None => HttpOriginRange.* - case Some(requestOrigin) => - val allowedOrigin = requestOrigin.origins.find(origin => - allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) - allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) - } - - respondWithHeaders( - `Access-Control-Allow-Origin`.forRange(allowedOrigins), - `Access-Control-Allow-Credentials`(true), - `Access-Control-Allow-Headers`(allowedHeaders: _*), - `Access-Control-Expose-Headers`(allowedHeaders: _*) - ) { - options { // options is used during preflight check - complete( - HttpResponse(StatusCodes.OK) - .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) - } ~ inner // in case of non-preflight check we don't do anything special - } - } - } - -} - -object CorsDirectives extends CorsDirectives diff --git a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala deleted file mode 100644 index 0cd4ef1..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala +++ /dev/null @@ -1,6 +0,0 @@ -package xyz.driver.core -package rest -package directives - -trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers -object Directives extends Directives diff --git a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala deleted file mode 100644 index 218c9ae..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala +++ /dev/null @@ -1,85 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import java.time.Instant -import java.util.UUID - -import akka.http.scaladsl.model.Uri.Path -import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} -import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.refineV -import xyz.driver.core.domain.PhoneNumber -import xyz.driver.core.time.Time - -import scala.util.control.NonFatal - -/** Akka-HTTP path matchers for custom core types. */ -trait PathMatchers { - - private def UuidInPath[T]: PathMatcher1[Id[T]] = - AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) - - def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) - case _ => Unmatched - } - } - - def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) - case _ => Unmatched - } - } - - private def timestampInPath: PathMatcher1[Long] = - PathMatcher("""[+-]?\d*""".r) flatMap { string => - try Some(string.toLong) - catch { case _: IllegalArgumentException => None } - } - - def InstantInPath: PathMatcher1[Instant] = - new PathMatcher1[Instant] { - def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { - case Path.Segment(head, tail) => - try Matched(tail, Tuple1(Instant.parse(head))) - catch { - case NonFatal(_) => Unmatched - } - case _ => Unmatched - } - } | timestampInPath.map(Instant.ofEpochMilli) - - def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) - - def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - refineV[NonEmpty](segment) match { - case Left(_) => Unmatched - case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) - } - case _ => Unmatched - } - } - - def RevisionInPath[T]: PathMatcher1[Revision[T]] = - PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => - Some(Revision[T](string)) - } - - def PhoneInPath: PathMatcher1[PhoneNumber] = new PathMatcher1[PhoneNumber] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - PhoneNumber - .parse(segment) - .map(parsed => Matched(tail, Tuple1(parsed))) - .getOrElse(Unmatched) - case _ => Unmatched - } - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala deleted file mode 100644 index 6c45d15..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala +++ /dev/null @@ -1,40 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import java.util.UUID - -import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} -import akka.http.scaladsl.unmarshalling.Unmarshaller -import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} - -/** Akka-HTTP unmarshallers for custom core types. */ -trait Unmarshallers { - - implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = - Unmarshaller.strict[String, Id[A]] { str => - Id[A](UUID.fromString(str).toString) - } - - implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = - Unmarshaller.firstOf( - Unmarshaller.strict((JsString(_: String)) andThen reader.read), - stringToValueUnmarshaller[T] - ) - - implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = - Unmarshaller.strict[String, Revision[T]](Revision[T]) - - val jsValueToStringMarshaller: Marshaller[JsValue, String] = - Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) - - def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = - jsValueToStringMarshaller.compose[T](jsonFormat.write) - - val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = - Unmarshaller.strict[String, JsValue](value => JsonParser(value)) - - def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = - stringToJsValueUnmarshaller.map[T](jsonFormat.read) - -} diff --git a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala deleted file mode 100644 index f2962c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest.errors - -sealed abstract class ServiceException(val message: String) extends Exception(message) - -final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) - -final case class InvalidActionException(override val message: String = "This action is not allowed") - extends ServiceException(message) - -final case class UnauthorizedException( - override val message: String = "The user's authentication credentials are invalid or missing") - extends ServiceException(message) - -final case class ResourceNotFoundException(override val message: String = "Resource not found") - extends ServiceException(message) - -final case class ExternalServiceException( - serviceName: String, - serviceMessage: String, - serviceException: Option[ServiceException]) - extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") - -final case class ExternalServiceTimeoutException(serviceName: String) - extends ServiceException(s"$serviceName took too long to respond") - -final case class DatabaseException(override val message: String = "Database access error") - extends ServiceException(message) diff --git a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala deleted file mode 100644 index 866476d..0000000 --- a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala +++ /dev/null @@ -1,33 +0,0 @@ -package xyz.driver.core -package rest -package headers - -import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} -import xyz.driver.core.reporting.SpanContext - -import scala.util.Try - -/** Encapsulates a trace context in an HTTP header for propagation across services. - * - * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) - * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is - * ignored. - */ -final case class Traceparent(spanContext: SpanContext) extends ModeledCustomHeader[Traceparent] { - override def renderInRequests = true - override def renderInResponses = true - override val companion: Traceparent.type = Traceparent - override def value: String = f"01-${spanContext.traceId}-${spanContext.spanId}-00" -} -object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { - override val name = "traceparent" - override def parse(value: String) = Try { - val Array(version, traceId, spanId, _) = value.split("-") - require( - version == "01", - s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") - new Traceparent( - new SpanContext(traceId, spanId) - ) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala deleted file mode 100644 index 34a4a9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ /dev/null @@ -1,323 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.Flow -import akka.util.ByteString -import scalaz.Scalaz.{intInstance, stringInstance} -import scalaz.syntax.equal._ -import scalaz.{Functor, OptionT} -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.errors.ExternalServiceException -import xyz.driver.core.rest.headers.Traceparent -import xyz.driver.tracing.TracingDirectives - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.Try - -trait Service - -object Service - -trait HttpClient { - def makeRequest(request: HttpRequest): Future[HttpResponse] -} - -trait ServiceTransport { - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] -} - -sealed trait SortingOrder -object SortingOrder { - case object Asc extends SortingOrder - case object Desc extends SortingOrder -} - -final case class SortingField(name: String, sortingOrder: SortingOrder) -final case class Sorting(sortingFields: Seq[SortingField]) - -final case class Pagination(pageSize: Int, pageNumber: Int) { - require(pageSize > 0, "Page size must be greater than zero") - require(pageNumber > 0, "Page number must be greater than zero") - - def offset: Int = pageSize * (pageNumber - 1) -} - -final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) - -object ListResponse { - - def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = - ListResponse( - items = items, - meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) - - final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) - - object Meta { - def apply(itemsCount: Int, pagination: Pagination): Meta = - Meta(itemsCount, pagination.pageNumber, pagination.pageSize) - } - -} - -object `package` { - - implicit class FutureExtensions[T](future: Future[T]) { - def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = - future.transform(identity, { - case ExternalServiceException(_, _, Some(e)) => e - case t: Throwable => t - }) - } - - implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { - def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( - implicit F: Functor[Future], - em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { - optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) - } - } - - object ContextHeaders { - val AuthenticationTokenHeader: String = "Authorization" - val PermissionsTokenHeader: String = "Permissions" - val AuthenticationHeaderPrefix: String = "Bearer" - val ClientFingerprintHeader: String = "X-Client-Fingerprint" - val TrackingIdHeader: String = "X-Trace" - val StacktraceHeader: String = "X-Stacktrace" - val OriginatingIpHeader: String = "X-Forwarded-For" - val ResourceCount: String = "X-Resource-Count" - val PageCount: String = "X-Page-Count" - val TraceHeaderName: String = TracingDirectives.TraceHeaderName - val SpanHeaderName: String = TracingDirectives.SpanHeaderName - } - - val AllowedHeaders: Seq[String] = - Seq( - "Origin", - "X-Requested-With", - "Content-Type", - "Content-Length", - "Accept", - "X-Trace", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Origin", - "Access-Control-Allow-Headers", - "Server", - "Date", - ContextHeaders.ClientFingerprintHeader, - ContextHeaders.TrackingIdHeader, - ContextHeaders.TraceHeaderName, - ContextHeaders.SpanHeaderName, - ContextHeaders.StacktraceHeader, - ContextHeaders.AuthenticationTokenHeader, - ContextHeaders.OriginatingIpHeader, - ContextHeaders.ResourceCount, - ContextHeaders.PageCount, - "X-Frame-Options", - "X-Content-Type-Options", - "Strict-Transport-Security", - AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader, - "Traceparent" - ) - - def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - - def serviceContext: Directive1[ServiceRequestContext] = { - def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { - headers.map({ header => - if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( - ContextHeaders.AuthenticationHeaderPrefix)) { - Authorization(OAuth2BearerToken(header.value)) - } else header - })(collection.breakOut) - } - extractClientIP flatMap { remoteAddress => - mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => - extract(ctx => extractServiceContext(ctx.request, remoteAddress)) - } - } - } - - def respondWithCorsAllowedHeaders: Directive0 = { - respondWithHeaders( - List[HttpHeader]( - `Access-Control-Allow-Headers`(AllowedHeaders: _*), - `Access-Control-Expose-Headers`(AllowedHeaders: _*) - )) - } - - def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { - respondWithHeader { - `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) - } - } - - def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { - respondWithHeaders( - List[HttpHeader]( - Allow(methods.to[collection.immutable.Seq]), - `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) - )) - } - - def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = - new ServiceRequestContext( - extractTrackingId(request), - extractOriginatingIP(request, remoteAddress), - extractContextHeaders(request)) - - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name === ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } - - def extractFingerprintHash(request: HttpRequest): Option[String] = { - request.headers - .find(_.name === ContextHeaders.ClientFingerprintHeader) - .map(_.value()) - } - - def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { - request.headers - .find(_.name === ContextHeaders.OriginatingIpHeader) - .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) - .orElse(remoteAddress.toOption) - } - - def extractStacktrace(request: HttpRequest): Array[String] = - request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") - - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers - .filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || - h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || - h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader || - h.name === Traceparent.name - } - .map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value - } - } - .toMap - } - - private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { - @annotation.tailrec - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse - else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) - } - } - } - - val indices = dirtyIndices(0, Nil) - - indices.headOption.fold(byteString) { head => - val builder = ByteString.newBuilder - builder ++= byteString.take(head) - - (indices :+ byteString.length).sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - case Seq(_) => // Should not match; sliding on at least 2 elements - assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") - } - builder.result - } - } - - val sanitizeRequestEntity: Directive0 = { - mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) - } - - val paginated: Directive1[Pagination] = - parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) - - private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = - (pageSizeOpt, pageNumberOpt) match { - case (Some(size), Some(number)) => Option(Pagination(size, number)) - case (None, None) => Option.empty[Pagination] - case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") - } - - val optionalPagination: Directive1[Option[Pagination]] = - parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) - - def paginationQuery(pagination: Pagination) = - Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) - - def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])( - implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = { - optionalPagination { pagination => - onSuccess(handler(pagination)) { - case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) => - val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1) - val headers = List( - RawHeader(ContextHeaders.ResourceCount, count.toString), - RawHeader(ContextHeaders.PageCount, pageCount.toString)) - - respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart))) - } - } - } - - private def extractSorting(sortingString: Option[String]): Sorting = { - val sortingFields = sortingString.fold(Seq.empty[SortingField])( - _.split(",") - .filter(_.length > 0) - .map { sortingParam => - if (sortingParam.startsWith("-")) { - SortingField(sortingParam.substring(1), SortingOrder.Desc) - } else { - val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam - SortingField(fieldName, SortingOrder.Asc) - } - } - .toSeq) - - Sorting(sortingFields) - } - - val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) - - def sortingQuery(sorting: Sorting): Seq[(String, String)] = { - val sortingString = sorting.sortingFields - .map { sortingField => - sortingField.sortingOrder match { - case SortingOrder.Asc => sortingField.name - case SortingOrder.Desc => s"-${sortingField.name}" - } - } - .mkString(",") - Seq("sort" -> sortingString) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala deleted file mode 100644 index 55f1a2e..0000000 --- a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala +++ /dev/null @@ -1,24 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.Name - -trait ServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T -} - -trait SavingUsedServiceDiscovery { - private val usedServices = new scala.collection.mutable.HashSet[String]() - - def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { - usedServices += serviceName.value - } - - def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } -} - -class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T = - throw new IllegalArgumentException(s"Service with name $serviceName is unknown") -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala deleted file mode 100644 index d2e4bc3..0000000 --- a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala +++ /dev/null @@ -1,87 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} -import xyz.driver.core.generators -import scalaz.Scalaz.{mapEqual, stringInstance} -import scalaz.syntax.equal._ -import xyz.driver.core.reporting.SpanContext -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.headers.Traceparent - -import scala.util.Try - -class ServiceRequestContext( - val trackingId: String = generators.nextUuid().toString, - val originatingIp: Option[InetAddress] = None, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - - def permissionsToken: Option[PermissionsToken] = - contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - - def withAuthToken(authToken: AuthToken): ServiceRequestContext = - new ServiceRequestContext( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) - ) - - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user - ) - - override def hashCode(): Int = - Seq[Any](trackingId, originatingIp, contextHeaders) - .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) - - override def equals(obj: Any): Boolean = obj match { - case ctx: ServiceRequestContext => - trackingId === ctx.trackingId && - originatingIp == originatingIp && - contextHeaders === ctx.contextHeaders - case _ => false - } - - def spanContext: SpanContext = { - val validHeader = Try { - contextHeaders(Traceparent.name) - }.flatMap { value => - Traceparent.parse(value) - } - validHeader.map(_.spanContext).getOrElse(SpanContext.fresh()) - } - - override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" -} - -class AuthorizedServiceRequestContext[U <: User]( - override val trackingId: String = generators.nextUuid().toString, - override val originatingIp: Option[InetAddress] = None, - override val contextHeaders: Map[String, String] = Map.empty[String, String], - val authenticatedUser: U) - extends ServiceRequestContext { - - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext[U]( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), - authenticatedUser) - - override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false - } - - override def toString: String = - s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" -} -- cgit v1.2.3