diff options
author | Jakob Odersky <jakob@driver.xyz> | 2018-09-12 16:40:57 -0700 |
---|---|---|
committer | Jakob Odersky <jakob@odersky.com> | 2018-10-09 16:19:39 -0700 |
commit | 4d1197099ce4e721c18bf4cacbb2e1980e4210b5 (patch) | |
tree | 1d3c2e482d5acf8ca613ba65fb5401c41ad686b5 /src | |
parent | 7c755c77afbd67ae2ded9d8b004736d4e27e208f (diff) | |
download | driver-core-4d1197099ce4e721c18bf4cacbb2e1980e4210b5.tar.gz driver-core-4d1197099ce4e721c18bf4cacbb2e1980e4210b5.tar.bz2 driver-core-4d1197099ce4e721c18bf4cacbb2e1980e4210b5.zip |
Move REST functionality to separate project
Diffstat (limited to 'src')
37 files changed, 0 insertions, 3638 deletions
diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala deleted file mode 100644 index 896bd89..0000000 --- a/src/main/scala/xyz/driver/core/auth.scala +++ /dev/null @@ -1,43 +0,0 @@ -package xyz.driver.core - -import xyz.driver.core.domain.Email -import xyz.driver.core.time.Time -import scalaz.Equal - -object auth { - - trait Permission - - final case class Role(id: Id[Role], name: Name[Role]) { - - def oneOf(roles: Role*): Boolean = roles.contains(this) - - def oneOf(roles: Set[Role]): Boolean = roles.contains(this) - } - - object Role { - implicit def idEqual: Equal[Role] = Equal.equal[Role](_ == _) - } - - trait User { - def id: Id[User] - } - - final case class AuthToken(value: String) - - final case class AuthTokenUserInfo( - id: Id[User], - email: Email, - emailVerified: Boolean, - audience: String, - roles: Set[Role], - expirationTime: Time) - extends User - - final case class RefreshToken(value: String) - final case class PermissionsToken(value: String) - - final case class PasswordHash(value: String) - - final case class AuthCredentials(identifier: String, password: String) -} diff --git a/src/main/scala/xyz/driver/core/generators.scala b/src/main/scala/xyz/driver/core/generators.scala deleted file mode 100644 index d00b6dd..0000000 --- a/src/main/scala/xyz/driver/core/generators.scala +++ /dev/null @@ -1,143 +0,0 @@ -package xyz.driver.core - -import enumeratum._ -import java.math.MathContext -import java.time.{Instant, LocalDate, ZoneOffset} -import java.util.UUID - -import xyz.driver.core.time.{Time, TimeOfDay, TimeRange} -import xyz.driver.core.date.{Date, DayOfWeek} - -import scala.reflect.ClassTag -import scala.util.Random -import eu.timepit.refined.refineV -import eu.timepit.refined.api.Refined -import eu.timepit.refined.collection._ - -object generators { - - private val random = new Random - import random._ - private val secureRandom = new java.security.SecureRandom() - - private val DefaultMaxLength = 10 - private val StringLetters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ".toSet - private val NonAmbigiousCharacters = "abcdefghijkmnpqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ23456789" - private val Numbers = "0123456789" - - private def nextTokenString(length: Int, chars: IndexedSeq[Char]): String = { - val builder = new StringBuilder - for (_ <- 0 until length) { - builder += chars(secureRandom.nextInt(chars.length)) - } - builder.result() - } - - /** Creates a random invitation token. - * - * This token is meant fo human input and avoids using ambiguous characters such as 'O' and '0'. It - * therefore contains less entropy and is not meant to be used as a cryptographic secret. */ - @deprecated( - "The term 'token' is too generic and security and readability conventions are not well defined. " + - "Services should implement their own version that suits their security requirements.", - "1.11.0" - ) - def nextToken(length: Int): String = nextTokenString(length, NonAmbigiousCharacters) - - @deprecated( - "The term 'token' is too generic and security and readability conventions are not well defined. " + - "Services should implement their own version that suits their security requirements.", - "1.11.0" - ) - def nextNumericToken(length: Int): String = nextTokenString(length, Numbers) - - def nextInt(maxValue: Int, minValue: Int = 0): Int = random.nextInt(maxValue - minValue) + minValue - - def nextBoolean(): Boolean = random.nextBoolean() - - def nextDouble(): Double = random.nextDouble() - - def nextId[T](): Id[T] = Id[T](nextUuid().toString) - - def nextId[T](maxLength: Int): Id[T] = Id[T](nextString(maxLength)) - - def nextNumericId[T](): Id[T] = Id[T](nextLong.abs.toString) - - def nextNumericId[T](maxValue: Int): Id[T] = Id[T](nextInt(maxValue).toString) - - def nextName[T](maxLength: Int = DefaultMaxLength): Name[T] = Name[T](nextString(maxLength)) - - def nextNonEmptyName[T](maxLength: Int = DefaultMaxLength): NonEmptyName[T] = - NonEmptyName[T](nextNonEmptyString(maxLength)) - - def nextUuid(): UUID = java.util.UUID.randomUUID - - def nextRevision[T](): Revision[T] = Revision[T](nextUuid().toString) - - def nextString(maxLength: Int = DefaultMaxLength): String = - (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString - - def nextNonEmptyString(maxLength: Int = DefaultMaxLength): String Refined NonEmpty = { - refineV[NonEmpty]( - (oneOf[Char](StringLetters) +: arrayOf(oneOf[Char](StringLetters), maxLength - 1)).mkString - ).right.get - } - - def nextOption[T](value: => T): Option[T] = if (nextBoolean()) Option(value) else None - - def nextPair[L, R](left: => L, right: => R): (L, R) = (left, right) - - def nextTriad[F, S, T](first: => F, second: => S, third: => T): (F, S, T) = (first, second, third) - - def nextInstant(): Instant = Instant.ofEpochMilli(math.abs(nextLong() % System.currentTimeMillis)) - - def nextTime(): Time = nextInstant() - - def nextTimeOfDay: TimeOfDay = TimeOfDay(java.time.LocalTime.MIN.plusSeconds(nextLong), java.util.TimeZone.getDefault) - - def nextTimeRange(): TimeRange = { - val oneTime = nextTime() - val anotherTime = nextTime() - - TimeRange( - Time(scala.math.min(oneTime.millis, anotherTime.millis)), - Time(scala.math.max(oneTime.millis, anotherTime.millis))) - } - - def nextDate(): Date = nextTime().toDate(java.util.TimeZone.getTimeZone("UTC")) - - def nextLocalDate(): LocalDate = nextInstant().atZone(ZoneOffset.UTC).toLocalDate - - def nextDayOfWeek(): DayOfWeek = oneOf(DayOfWeek.All) - - def nextBigDecimal(multiplier: Double = 1000000.00, precision: Int = 2): BigDecimal = - BigDecimal(multiplier * nextDouble, new MathContext(precision)) - - def oneOf[T](items: T*): T = oneOf(items.toSet) - - def oneOf[T](items: Set[T]): T = items.toSeq(nextInt(items.size)) - - def oneOf[T <: EnumEntry](enum: Enum[T]): T = oneOf(enum.values: _*) - - def arrayOf[T: ClassTag](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Array[T] = - Array.fill(nextInt(maxLength, minLength))(generator) - - def seqOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Seq[T] = - Seq.fill(nextInt(maxLength, minLength))(generator) - - def vectorOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Vector[T] = - Vector.fill(nextInt(maxLength, minLength))(generator) - - def listOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): List[T] = - List.fill(nextInt(maxLength, minLength))(generator) - - def setOf[T](generator: => T, maxLength: Int = DefaultMaxLength, minLength: Int = 0): Set[T] = - seqOf(generator, maxLength, minLength).toSet - - def mapOf[K, V]( - keyGenerator: => K, - valueGenerator: => V, - maxLength: Int = DefaultMaxLength, - minLength: Int = 0): Map[K, V] = - seqOf(nextPair(keyGenerator, valueGenerator), maxLength, minLength).toMap -} diff --git a/src/main/scala/xyz/driver/core/json.scala b/src/main/scala/xyz/driver/core/json.scala deleted file mode 100644 index edc2347..0000000 --- a/src/main/scala/xyz/driver/core/json.scala +++ /dev/null @@ -1,398 +0,0 @@ -package xyz.driver.core - -import java.net.InetAddress -import java.time.format.DateTimeFormatter -import java.time.{Instant, LocalDate} -import java.util.{TimeZone, UUID} - -import akka.http.scaladsl.unmarshalling.Unmarshaller -import com.neovisionaries.i18n.{CountryCode, CurrencyCode} -import enumeratum._ -import eu.timepit.refined.api.{Refined, Validate} -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.refineV -import spray.json._ -import xyz.driver.core.auth.AuthCredentials -import xyz.driver.core.date.{Date, DayOfWeek, Month} -import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.rest.directives.{PathMatchers, Unmarshallers} -import xyz.driver.core.rest.errors._ -import xyz.driver.core.time.{Time, TimeOfDay} - -import scala.reflect.runtime.universe._ -import scala.reflect.{ClassTag, classTag} -import scala.util.Try -import scala.util.control.NonFatal - -object json extends PathMatchers with Unmarshallers { - import DefaultJsonProtocol._ - - implicit def idFormat[T]: RootJsonFormat[Id[T]] = new RootJsonFormat[Id[T]] { - def write(id: Id[T]) = JsString(id.value) - - def read(value: JsValue): Id[T] = value match { - case JsString(id) if Try(UUID.fromString(id)).isSuccess => Id[T](id.toLowerCase) - case JsString(id) => Id[T](id) - case _ => throw DeserializationException("Id expects string") - } - } - - implicit def taggedFormat[F, T](implicit underlying: JsonFormat[F], convert: F => F @@ T = null): JsonFormat[F @@ T] = - new JsonFormat[F @@ T] { - import tagging._ - - private val transformReadValue = Option(convert).getOrElse((_: F).tagged[T]) - - override def write(obj: F @@ T): JsValue = underlying.write(obj) - - override def read(json: JsValue): F @@ T = transformReadValue(underlying.read(json)) - } - - implicit def nameFormat[T] = new RootJsonFormat[Name[T]] { - def write(name: Name[T]) = JsString(name.value) - - def read(value: JsValue): Name[T] = value match { - case JsString(name) => Name[T](name) - case _ => throw DeserializationException("Name expects string") - } - } - - implicit val timeFormat: RootJsonFormat[Time] = new RootJsonFormat[Time] { - def write(time: Time) = JsObject("timestamp" -> JsNumber(time.millis)) - - def read(value: JsValue): Time = Time(instantFormat.read(value)) - } - - implicit val instantFormat: JsonFormat[Instant] = new JsonFormat[Instant] { - def write(instant: Instant): JsValue = JsString(instant.toString) - - def read(value: JsValue): Instant = value match { - case JsObject(fields) => - fields - .get("timestamp") - .flatMap { - case JsNumber(millis) => Some(Instant.ofEpochMilli(millis.longValue())) - case _ => None - } - .getOrElse(deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}")) - case JsNumber(millis) => Instant.ofEpochMilli(millis.longValue()) - case JsString(str) => - try Instant.parse(str) - catch { case NonFatal(_) => deserializationError(s"Instant expects ISO timestamp but got $str") } - case _ => deserializationError(s"Instant expects ISO timestamp but got ${value.compactPrint}") - } - } - - implicit object localTimeFormat extends JsonFormat[java.time.LocalTime] { - private val formatter = TimeOfDay.getFormatter - def read(json: JsValue): java.time.LocalTime = json match { - case JsString(chars) => - java.time.LocalTime.parse(chars) - case _ => deserializationError(s"Expected time string got ${json.toString}") - } - - def write(obj: java.time.LocalTime): JsValue = { - JsString(obj.format(formatter)) - } - } - - implicit object timeZoneFormat extends JsonFormat[java.util.TimeZone] { - override def write(obj: TimeZone): JsValue = { - JsString(obj.getID()) - } - - override def read(json: JsValue): TimeZone = json match { - case JsString(chars) => - java.util.TimeZone.getTimeZone(chars) - case _ => deserializationError(s"Expected time zone string got ${json.toString}") - } - } - - implicit val timeOfDayFormat: RootJsonFormat[TimeOfDay] = jsonFormat2(TimeOfDay.apply) - - implicit val dayOfWeekFormat: JsonFormat[DayOfWeek] = new enumeratum.EnumJsonFormat(DayOfWeek) - - implicit val dateFormat = new RootJsonFormat[Date] { - def write(date: Date) = JsString(date.toString) - def read(value: JsValue): Date = value match { - case JsString(dateString) => - Date - .fromString(dateString) - .getOrElse( - throw DeserializationException(s"Misformated ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.")) - case _ => throw DeserializationException(s"Date expects a string, but got $value.") - } - } - - implicit val localDateFormat = new RootJsonFormat[LocalDate] { - val format = DateTimeFormatter.ISO_LOCAL_DATE - - def write(date: LocalDate): JsValue = JsString(date.format(format)) - def read(value: JsValue): LocalDate = value match { - case JsString(dateString) => - try LocalDate.parse(dateString, format) - catch { - case NonFatal(_) => - throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got $dateString.") - } - case _ => - throw deserializationError(s"Malformed ISO 8601 Date. Expected YYYY-MM-DD, but got ${value.compactPrint}.") - } - } - - implicit val monthFormat = new RootJsonFormat[Month] { - def write(month: Month) = JsNumber(month) - def read(value: JsValue): Month = value match { - case JsNumber(month) if 0 <= month && month <= 11 => Month(month.toInt) - case _ => throw DeserializationException("Expected a number from 0 to 11") - } - } - - implicit def revisionFormat[T]: RootJsonFormat[Revision[T]] = new RootJsonFormat[Revision[T]] { - def write(revision: Revision[T]) = JsString(revision.id.toString) - - def read(value: JsValue): Revision[T] = value match { - case JsString(revision) => Revision[T](revision) - case _ => throw DeserializationException("Revision expects uuid string") - } - } - - implicit val base64Format = new RootJsonFormat[Base64] { - def write(base64Value: Base64) = JsString(base64Value.value) - - def read(value: JsValue): Base64 = value match { - case JsString(base64Value) => Base64(base64Value) - case _ => throw DeserializationException("Base64 format expects string") - } - } - - implicit val emailFormat = new RootJsonFormat[Email] { - def write(email: Email) = JsString(email.username + "@" + email.domain) - def read(json: JsValue): Email = json match { - - case JsString(value) => - Email.parse(value).getOrElse { - deserializationError("Expected '@' symbol in email string as Email, but got " + json.toString) - } - - case _ => - deserializationError("Expected string as Email, but got " + json.toString) - } - } - - implicit object phoneNumberFormat extends RootJsonFormat[PhoneNumber] { - - private val basicFormat = jsonFormat3(PhoneNumber.apply) - - def write(obj: PhoneNumber): JsValue = basicFormat.write(obj) - - def read(json: JsValue): PhoneNumber = { - val maybePhone = json match { - case JsString(number) => PhoneNumber.parse(number) - case obj: JsObject => PhoneNumber.parse(basicFormat.read(obj).toString) - case _ => None - } - maybePhone.getOrElse(deserializationError("Invalid phone number")) - } - } - - implicit val authCredentialsFormat = new RootJsonFormat[AuthCredentials] { - override def read(json: JsValue): AuthCredentials = { - json match { - case JsObject(fields) => - val emailField = fields.get("email") - val identifierField = fields.get("identifier") - val passwordField = fields.get("password") - - (emailField, identifierField, passwordField) match { - case (_, _, None) => - deserializationError("password field must be set") - case (Some(JsString(em)), _, Some(JsString(pw))) => - val email = Email.parse(em).getOrElse(throw deserializationError(s"failed to parse email $em")) - AuthCredentials(email.toString, pw) - case (_, Some(JsString(id)), Some(JsString(pw))) => AuthCredentials(id.toString, pw.toString) - case (None, None, _) => deserializationError("identifier must be provided") - case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") - } - case _ => deserializationError(s"failed to deserialize ${json.prettyPrint}") - } - } - - override def write(obj: AuthCredentials): JsValue = JsObject( - "identifier" -> JsString(obj.identifier), - "password" -> JsString(obj.password) - ) - } - - implicit object inetAddressFormat extends JsonFormat[InetAddress] { - override def read(json: JsValue): InetAddress = json match { - case JsString(ipString) => - Try(InetAddress.getByName(ipString)) - .getOrElse(deserializationError(s"Invalid IP Address: $ipString")) - case _ => deserializationError(s"Expected string for IP Address, got $json") - } - - override def write(obj: InetAddress): JsValue = - JsString(obj.getHostAddress) - } - - implicit val countryCodeFormat: JsonFormat[CountryCode] = javaEnumFormat[CountryCode] - - implicit val currencyCodeFormat: JsonFormat[CurrencyCode] = javaEnumFormat[CurrencyCode] - - object enumeratum { - - def enumUnmarshaller[T <: EnumEntry](enum: Enum[T]): Unmarshaller[String, T] = - Unmarshaller.strict { value => - enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) - } - - trait HasJsonFormat[T <: EnumEntry] { enum: Enum[T] => - - implicit val format: JsonFormat[T] = new EnumJsonFormat(enum) - - implicit val unmarshaller: Unmarshaller[String, T] = - Unmarshaller.strict { value => - enum.withNameOption(value).getOrElse(unrecognizedValue(value, enum.values)) - } - } - - class EnumJsonFormat[T <: EnumEntry](enum: Enum[T]) extends JsonFormat[T] { - override def read(json: JsValue): T = json match { - case JsString(name) => enum.withNameOption(name).getOrElse(unrecognizedValue(name, enum.values)) - case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) - } - - override def write(obj: T): JsValue = JsString(obj.entryName) - } - - private def unrecognizedValue(value: String, possibleValues: Seq[Any]): Nothing = - deserializationError(s"Unexpected value $value. Expected one of: ${possibleValues.mkString("[", ", ", "]")}") - } - - class EnumJsonFormat[T](mapping: (String, T)*) extends RootJsonFormat[T] { - private val map = mapping.toMap - - override def write(value: T): JsValue = { - map.find(_._2 == value).map(_._1) match { - case Some(name) => JsString(name) - case _ => serializationError(s"Value $value is not found in the mapping $map") - } - } - - override def read(json: JsValue): T = json match { - case JsString(name) => - map.getOrElse(name, throw DeserializationException(s"Value $name is not found in the mapping $map")) - case _ => deserializationError("Expected string as enumeration value, but got " + json.toString) - } - } - - def javaEnumFormat[T <: java.lang.Enum[_]: ClassTag]: JsonFormat[T] = { - val values = classTag[T].runtimeClass.asInstanceOf[Class[T]].getEnumConstants - new EnumJsonFormat[T](values.map(v => v.name() -> v): _*) - } - - class ValueClassFormat[T: TypeTag](writeValue: T => BigDecimal, create: BigDecimal => T) extends JsonFormat[T] { - def write(valueClass: T) = JsNumber(writeValue(valueClass)) - def read(json: JsValue): T = json match { - case JsNumber(value) => create(value) - case _ => deserializationError(s"Expected number as ${typeOf[T].getClass.getName}, but got " + json.toString) - } - } - - class GadtJsonFormat[T: TypeTag]( - typeField: String, - typeValue: PartialFunction[T, String], - jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) - extends RootJsonFormat[T] { - - def write(value: T): JsValue = { - - val valueType = typeValue.applyOrElse(value, { v: T => - deserializationError(s"No Value type for this type of ${typeOf[T].getClass.getName}: " + v.toString) - }) - - val valueFormat = - jsonFormat.applyOrElse(valueType, { f: String => - deserializationError(s"No Json format for this type of $valueType") - }) - - valueFormat.asInstanceOf[JsonFormat[T]].write(value) match { - case JsObject(fields) => JsObject(fields ++ Map(typeField -> JsString(valueType))) - case _ => serializationError(s"${typeOf[T].getClass.getName} serialized not to a JSON object") - } - } - - def read(json: JsValue): T = json match { - case JsObject(fields) => - val valueJson = JsObject(fields.filterNot(_._1 == typeField)) - fields(typeField) match { - case JsString(valueType) => - val valueFormat = jsonFormat.applyOrElse(valueType, { t: String => - deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") - }) - valueFormat.read(valueJson) - case _ => - deserializationError(s"Unknown ${typeOf[T].getClass.getName} type ${fields(typeField)}") - } - case _ => - deserializationError(s"Expected Json Object as ${typeOf[T].getClass.getName}, but got " + json.toString) - } - } - - object GadtJsonFormat { - - def create[T: TypeTag](typeField: String)(typeValue: PartialFunction[T, String])( - jsonFormat: PartialFunction[String, JsonFormat[_ <: T]]) = { - - new GadtJsonFormat[T](typeField, typeValue, jsonFormat) - } - } - - /** - * Provides the JsonFormat for the Refined types provided by the Refined library. - * - * @see https://github.com/fthomas/refined - */ - implicit def refinedJsonFormat[T, Predicate]( - implicit valueFormat: JsonFormat[T], - validate: Validate[T, Predicate]): JsonFormat[Refined[T, Predicate]] = - new JsonFormat[Refined[T, Predicate]] { - def write(x: T Refined Predicate): JsValue = valueFormat.write(x.value) - def read(value: JsValue): T Refined Predicate = { - refineV[Predicate](valueFormat.read(value))(validate) match { - case Right(refinedValue) => refinedValue - case Left(refinementError) => deserializationError(refinementError) - } - } - } - - implicit def nonEmptyNameFormat[T](implicit nonEmptyStringFormat: JsonFormat[Refined[String, NonEmpty]]) = - new RootJsonFormat[NonEmptyName[T]] { - def write(name: NonEmptyName[T]) = JsString(name.value.value) - - def read(value: JsValue): NonEmptyName[T] = - NonEmptyName[T](nonEmptyStringFormat.read(value)) - } - - implicit val serviceExceptionFormat: RootJsonFormat[ServiceException] = - GadtJsonFormat.create[ServiceException]("type") { - case _: InvalidInputException => "InvalidInputException" - case _: InvalidActionException => "InvalidActionException" - case _: UnauthorizedException => "UnauthorizedException" - case _: ResourceNotFoundException => "ResourceNotFoundException" - case _: ExternalServiceException => "ExternalServiceException" - case _: ExternalServiceTimeoutException => "ExternalServiceTimeoutException" - case _: DatabaseException => "DatabaseException" - } { - case "InvalidInputException" => jsonFormat(InvalidInputException, "message") - case "InvalidActionException" => jsonFormat(InvalidActionException, "message") - case "UnauthorizedException" => jsonFormat(UnauthorizedException, "message") - case "ResourceNotFoundException" => jsonFormat(ResourceNotFoundException, "message") - case "ExternalServiceException" => - jsonFormat(ExternalServiceException, "serviceName", "serviceMessage", "serviceException") - case "ExternalServiceTimeoutException" => jsonFormat(ExternalServiceTimeoutException, "message") - case "DatabaseException" => jsonFormat(DatabaseException, "message") - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala b/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala deleted file mode 100644 index 87946e4..0000000 --- a/src/main/scala/xyz/driver/core/rest/DnsDiscovery.scala +++ /dev/null @@ -1,11 +0,0 @@ -package xyz.driver.core -package rest - -class DnsDiscovery(transport: HttpRestServiceTransport, overrides: Map[String, String]) { - - def discover[A](implicit descriptor: ServiceDescriptor[A]): A = { - val url = overrides.getOrElse(descriptor.name, s"https://${descriptor.name}") - descriptor.connect(transport, url) - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala b/src/main/scala/xyz/driver/core/rest/DriverRoute.scala deleted file mode 100644 index 911e306..0000000 --- a/src/main/scala/xyz/driver/core/rest/DriverRoute.scala +++ /dev/null @@ -1,122 +0,0 @@ -package xyz.driver.core.rest - -import java.sql.SQLException - -import akka.http.scaladsl.model.headers.CacheDirectives.`no-cache` -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{StatusCodes, _} -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import com.typesafe.scalalogging.Logger -import org.slf4j.MDC -import xyz.driver.core.rest -import xyz.driver.core.rest.errors._ - -import scala.compat.Platform.ConcurrentModificationException - -trait DriverRoute { - def log: Logger - - def route: Route - - def routeWithDefaults: Route = { - (defaultResponseHeaders & handleExceptions(ExceptionHandler(exceptionHandler))) { - route - } - } - - protected def defaultResponseHeaders: Directive0 = { - extractRequest flatMap { request => - // Needs to happen before any request processing, so all the log messages - // associated with processing of this request are having this `trackingId` - val trackingId = rest.extractTrackingId(request) - val tracingHeader = RawHeader(ContextHeaders.TrackingIdHeader, trackingId) - MDC.put("trackingId", trackingId) - - respondWithHeaders(tracingHeader +: DriverRoute.DefaultHeaders: _*) - } - } - - /** - * Override me for custom exception handling - * - * @return Exception handling route for exception type - */ - protected def exceptionHandler: PartialFunction[Throwable, Route] = { - case serviceException: ServiceException => - serviceExceptionHandler(serviceException) - - case is: IllegalStateException => - ctx => - log.warn(s"Request is not allowed to ${ctx.request.method} ${ctx.request.uri}", is) - errorResponse(StatusCodes.BadRequest, message = is.getMessage, is)(ctx) - - case cm: ConcurrentModificationException => - ctx => - log.warn(s"Concurrent modification of the resource ${ctx.request.method} ${ctx.request.uri}", cm) - errorResponse(StatusCodes.Conflict, "Resource was changed concurrently, try requesting a newer version", cm)( - ctx) - - case se: SQLException => - ctx => - log.warn(s"Database exception for the resource ${ctx.request.method} ${ctx.request.uri}", se) - errorResponse(StatusCodes.InternalServerError, "Data access error", se)(ctx) - - case t: Exception => - ctx => - log.warn(s"Request to ${ctx.request.method} ${ctx.request.uri} could not be handled normally", t) - errorResponse(StatusCodes.InternalServerError, t.getMessage, t)(ctx) - } - - protected def serviceExceptionHandler(serviceException: ServiceException): Route = { - val statusCode = serviceException match { - case e: InvalidInputException => - log.info("Invalid client input error", e) - StatusCodes.BadRequest - case e: InvalidActionException => - log.info("Invalid client action error", e) - StatusCodes.Forbidden - case e: UnauthorizedException => - log.info("Unauthorized user error", e) - StatusCodes.Unauthorized - case e: ResourceNotFoundException => - log.info("Resource not found error", e) - StatusCodes.NotFound - case e: ExternalServiceException => - log.error("Error while calling another service", e) - StatusCodes.InternalServerError - case e: ExternalServiceTimeoutException => - log.error("Service timeout error", e) - StatusCodes.GatewayTimeout - case e: DatabaseException => - log.error("Database error", e) - StatusCodes.InternalServerError - } - - { (ctx: RequestContext) => - import xyz.driver.core.json.serviceExceptionFormat - val entity = - HttpEntity(ContentTypes.`application/json`, serviceExceptionFormat.write(serviceException).toString()) - errorResponse(statusCode, entity, serviceException)(ctx) - } - } - - protected def errorResponse[T <: Exception](statusCode: StatusCode, message: String, exception: T): Route = - errorResponse(statusCode, HttpEntity(message), exception) - - protected def errorResponse[T <: Exception](statusCode: StatusCode, entity: ResponseEntity, exception: T): Route = { - complete(HttpResponse(statusCode, entity = entity)) - } - -} - -object DriverRoute { - val DefaultHeaders: List[HttpHeader] = List( - // This header will eliminate the risk of envoy trying to reuse a connection - // that already timed out on the server side by completely rejecting keep-alive - Connection("close"), - // These 2 headers are the simplest way to prevent IE from caching GET requests - RawHeader("Pragma", "no-cache"), - `Cache-Control`(List(`no-cache`(Nil))) - ) -} diff --git a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala b/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala deleted file mode 100644 index e31635b..0000000 --- a/src/main/scala/xyz/driver/core/rest/HttpRestServiceTransport.scala +++ /dev/null @@ -1,103 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.RawHeader -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.TcpIdleTimeoutException -import org.slf4j.MDC -import xyz.driver.core.Name -import xyz.driver.core.reporting.Reporter -import xyz.driver.core.rest.errors.{ExternalServiceException, ExternalServiceTimeoutException} - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.{Failure, Success} - -class HttpRestServiceTransport( - applicationName: Name[App], - applicationVersion: String, - val actorSystem: ActorSystem, - val executionContext: ExecutionContext, - reporter: Reporter) - extends ServiceTransport { - - protected implicit val execution: ExecutionContext = executionContext - - protected val httpClient: HttpClient = new SingleRequestHttpClient(applicationName, applicationVersion, actorSystem) - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] = { - val tags = Map( - // open tracing semantic tags - "span.kind" -> "client", - "service" -> applicationName.value, - "http.url" -> requestStub.uri.toString, - "http.method" -> requestStub.method.value, - "peer.hostname" -> requestStub.uri.authority.host.toString, - // google's tracing console provides extra search features if we define these tags - "/http/path" -> requestStub.uri.path.toString, - "/http/method" -> requestStub.method.value.toString, - "/http/url" -> requestStub.uri.toString - ) - reporter.traceAsync(s"http_call_rpc", tags) { implicit span => - val requestTime = System.currentTimeMillis() - - val request = requestStub - .withHeaders(context.contextHeaders.toSeq.map { - case (ContextHeaders.TrackingIdHeader, _) => - RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId) - case (ContextHeaders.StacktraceHeader, _) => - RawHeader( - ContextHeaders.StacktraceHeader, - Option(MDC.get("stack")) - .orElse(context.contextHeaders.get(ContextHeaders.StacktraceHeader)) - .getOrElse("")) - case (header, headerValue) => RawHeader(header, headerValue) - }: _*) - - reporter.debug(s"Sending request to ${request.method} ${request.uri}") - - val response = httpClient.makeRequest(request) - - response.onComplete { - case Success(r) => - val responseLatency = System.currentTimeMillis() - requestTime - reporter.debug( - s"Response from ${request.uri} to request $requestStub is successful in $responseLatency ms: $r") - - case Failure(t: Throwable) => - val responseLatency = System.currentTimeMillis() - requestTime - reporter.warn( - s"Failed to receive response from ${request.method.value} ${request.uri} in $responseLatency ms", - t) - }(executionContext) - - response.recoverWith { - case _: TcpIdleTimeoutException => - val serviceCalled = s"${requestStub.method.value} ${requestStub.uri}" - Future.failed(ExternalServiceTimeoutException(serviceCalled)) - case t: Throwable => Future.failed(t) - } - }(context.spanContext) - } - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] = { - - sendRequestGetResponse(context)(requestStub) flatMap { response => - if (response.status == StatusCodes.NotFound) { - Future.successful(Unmarshal(HttpEntity.Empty: ResponseEntity)) - } else if (response.status.isFailure()) { - val serviceCalled = s"${requestStub.method} ${requestStub.uri}" - Unmarshal(response.entity).to[String] flatMap { errorString => - import spray.json._ - import xyz.driver.core.json._ - val serviceException = util.Try(serviceExceptionFormat.read(errorString.parseJson)).toOption - Future.failed(ExternalServiceException(serviceCalled, errorString, serviceException)) - } - } else { - Future.successful(Unmarshal(response.entity)) - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala b/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala deleted file mode 100644 index f33bf9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/PatchDirectives.scala +++ /dev/null @@ -1,104 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.javadsl.server.Rejections -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.{ContentTypeRange, HttpCharsets, MediaType} -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.{FromEntityUnmarshaller, Unmarshaller} -import spray.json._ - -import scala.concurrent.Future -import scala.util.{Failure, Success, Try} - -trait PatchDirectives extends Directives with SprayJsonSupport { - - /** Media type for patches to JSON values, as specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ - val `application/merge-patch+json`: MediaType.WithFixedCharset = - MediaType.applicationWithFixedCharset("merge-patch+json", HttpCharsets.`UTF-8`) - - /** Wraps a JSON value that represents a patch. - * The patch must given in the format specified in [[https://tools.ietf.org/html/rfc7396 RFC 7396]]. */ - case class PatchValue(value: JsValue) { - - /** Applies this patch to a given original JSON value. In other words, merges the original with this "diff". */ - def applyTo(original: JsValue): JsValue = mergeJsValues(original, value) - } - - /** Witness that the given patch may be applied to an original domain value. - * @tparam A type of the domain value - * @param patch the patch that may be applied to a domain value - * @param format a JSON format that enables serialization and deserialization of a domain value */ - case class Patchable[A](patch: PatchValue, format: RootJsonFormat[A]) { - - /** Applies the patch to a given domain object. The result will be a combination - * of the original value, updates with the fields specified in this witness' patch. */ - def applyTo(original: A): A = { - val serialized = format.write(original) - val merged = patch.applyTo(serialized) - val deserialized = format.read(merged) - deserialized - } - } - - implicit def patchValueUnmarshaller: FromEntityUnmarshaller[PatchValue] = - Unmarshaller.byteStringUnmarshaller - .andThen(sprayJsValueByteStringUnmarshaller) - .forContentTypes(ContentTypeRange(`application/merge-patch+json`)) - .map(js => PatchValue(js)) - - implicit def patchableUnmarshaller[A]( - implicit patchUnmarshaller: FromEntityUnmarshaller[PatchValue], - format: RootJsonFormat[A]): FromEntityUnmarshaller[Patchable[A]] = { - patchUnmarshaller.map(patch => Patchable[A](patch, format)) - } - - protected def mergeObjects(oldObj: JsObject, newObj: JsObject, maxLevels: Option[Int] = None): JsObject = { - JsObject((oldObj.fields.keys ++ newObj.fields.keys).map({ key => - val oldValue = oldObj.fields.getOrElse(key, JsNull) - val newValue = newObj.fields.get(key).fold(oldValue)(mergeJsValues(oldValue, _, maxLevels.map(_ - 1))) - key -> newValue - })(collection.breakOut): _*) - } - - protected def mergeJsValues(oldValue: JsValue, newValue: JsValue, maxLevels: Option[Int] = None): JsValue = { - def mergeError(typ: String): Nothing = - deserializationError(s"Expected $typ value, got $newValue") - - if (maxLevels.exists(_ < 0)) oldValue - else { - (oldValue, newValue) match { - case (_: JsString, newString @ (JsString(_) | JsNull)) => newString - case (_: JsString, _) => mergeError("string") - case (_: JsNumber, newNumber @ (JsNumber(_) | JsNull)) => newNumber - case (_: JsNumber, _) => mergeError("number") - case (_: JsBoolean, newBool @ (JsBoolean(_) | JsNull)) => newBool - case (_: JsBoolean, _) => mergeError("boolean") - case (_: JsArray, newArr @ (JsArray(_) | JsNull)) => newArr - case (_: JsArray, _) => mergeError("array") - case (oldObj: JsObject, newObj: JsObject) => mergeObjects(oldObj, newObj) - case (_: JsObject, JsNull) => JsNull - case (_: JsObject, _) => mergeError("object") - case (JsNull, _) => newValue - } - } - } - - def mergePatch[T](patchable: Patchable[T], retrieve: => Future[Option[T]]): Directive1[T] = - Directive { inner => requestCtx => - onSuccess(retrieve)({ - case Some(oldT) => - Try(patchable.applyTo(oldT)) - .transform[Route]( - mergedT => scala.util.Success(inner(Tuple1(mergedT))), { - case jsonException: DeserializationException => - Success(reject(Rejections.malformedRequestContent(jsonException.getMessage, jsonException))) - case t => Failure(t) - } - ) - .get // intentionally re-throw all other errors - case None => reject() - })(requestCtx) - } -} - -object PatchDirectives extends PatchDirectives diff --git a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala b/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala deleted file mode 100644 index 2854257..0000000 --- a/src/main/scala/xyz/driver/core/rest/PooledHttpClient.scala +++ /dev/null @@ -1,67 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.model.headers.`User-Agent` -import akka.http.scaladsl.model.{HttpRequest, HttpResponse, Uri} -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.stream.scaladsl.{Keep, Sink, Source} -import akka.stream.{ActorMaterializer, OverflowStrategy, QueueOfferResult, ThrottleMode} -import xyz.driver.core.Name - -import scala.concurrent.{ExecutionContext, Future, Promise} -import scala.concurrent.duration._ -import scala.util.{Failure, Success} - -class PooledHttpClient( - baseUri: Uri, - applicationName: Name[App], - applicationVersion: String, - requestRateLimit: Int = 64, - requestQueueSize: Int = 1024)(implicit actorSystem: ActorSystem, executionContext: ExecutionContext) - extends HttpClient { - - private val host = baseUri.authority.host.toString() - private val port = baseUri.effectivePort - private val scheme = baseUri.scheme - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - private val pool = if (scheme.equalsIgnoreCase("https")) { - Http().cachedHostConnectionPoolHttps[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } else { - Http().cachedHostConnectionPool[Promise[HttpResponse]](host, port, settings = connectionPoolSettings) - } - - private val queue = Source - .queue[(HttpRequest, Promise[HttpResponse])](requestQueueSize, OverflowStrategy.dropNew) - .via(pool) - .throttle(requestRateLimit, 1.second, maximumBurst = requestRateLimit, ThrottleMode.shaping) - .toMat(Sink.foreach({ - case ((Success(resp), p)) => p.success(resp) - case ((Failure(e), p)) => p.failure(e) - }))(Keep.left) - .run - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - val responsePromise = Promise[HttpResponse]() - - queue.offer(request -> responsePromise).flatMap { - case QueueOfferResult.Enqueued => - responsePromise.future - case QueueOfferResult.Dropped => - Future.failed(new Exception(s"Request queue to the host $host is overflown")) - case QueueOfferResult.Failure(ex) => - Future.failed(ex) - case QueueOfferResult.QueueClosed => - Future.failed(new Exception("Queue was closed (pool shut down) while running the request")) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala b/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala deleted file mode 100644 index c0e9f99..0000000 --- a/src/main/scala/xyz/driver/core/rest/ProxyRoute.scala +++ /dev/null @@ -1,26 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.server.{RequestContext, Route, RouteResult} -import com.typesafe.config.Config -import xyz.driver.core.Name - -import scala.concurrent.ExecutionContext - -trait ProxyRoute extends DriverRoute { - implicit val executionContext: ExecutionContext - val config: Config - val httpClient: HttpClient - - protected def proxyToService(serviceName: Name[Service]): Route = { ctx: RequestContext => - val httpScheme = config.getString(s"services.${serviceName.value}.httpScheme") - val baseUrl = config.getString(s"services.${serviceName.value}.baseUrl") - - val originalUri = ctx.request.uri - val originalRequest = ctx.request - - val newUri = originalUri.withScheme(httpScheme).withHost(baseUrl) - val newRequest = originalRequest.withUri(newUri) - - httpClient.makeRequest(newRequest).map(RouteResult.Complete) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/RestService.scala b/src/main/scala/xyz/driver/core/rest/RestService.scala deleted file mode 100644 index 09d98b8..0000000 --- a/src/main/scala/xyz/driver/core/rest/RestService.scala +++ /dev/null @@ -1,86 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model._ -import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} -import akka.stream.Materializer - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.{ListT, OptionT} - -trait RestService extends Service { - - import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ - import spray.json._ - - protected implicit val exec: ExecutionContext - protected implicit val materializer: Materializer - - implicit class ResponseEntityFoldable(entity: Unmarshal[ResponseEntity]) { - def fold[T](default: => T)(implicit um: Unmarshaller[ResponseEntity, T]): Future[T] = - if (entity.value.isKnownEmpty()) Future.successful[T](default) else entity.to[T] - } - - protected def unitResponse(request: Future[Unmarshal[ResponseEntity]]): OptionT[Future, Unit] = - OptionT[Future, Unit](request.flatMap(_.to[String]).map(_ => Option(()))) - - protected def optionalResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, Option[T]]): OptionT[Future, T] = - OptionT[Future, T](request.flatMap(_.fold(Option.empty[T]))) - - protected def listResponse[T](request: Future[Unmarshal[ResponseEntity]])( - implicit um: Unmarshaller[ResponseEntity, List[T]]): ListT[Future, T] = - ListT[Future, T](request.flatMap(_.fold(List.empty[T]))) - - protected def jsonEntity(json: JsValue): RequestEntity = - HttpEntity(ContentTypes.`application/json`, json.compactPrint) - - protected def mergePatchJsonEntity(json: JsValue): RequestEntity = - HttpEntity(PatchDirectives.`application/merge-patch+json`, json.compactPrint) - - protected def get(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.GET, endpointUri(baseUri, path, query)) - - protected def post(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = httpEntity) - - protected def postJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.POST, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def put(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = httpEntity) - - protected def putJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PUT, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def patch(baseUri: Uri, path: String, httpEntity: RequestEntity) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = httpEntity) - - protected def patchJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = jsonEntity(json)) - - protected def mergePatchJson(baseUri: Uri, path: String, json: JsValue) = - HttpRequest(HttpMethods.PATCH, endpointUri(baseUri, path), entity = mergePatchJsonEntity(json)) - - protected def delete(baseUri: Uri, path: String, query: Seq[(String, String)] = Seq.empty) = - HttpRequest(HttpMethods.DELETE, endpointUri(baseUri, path, query)) - - protected def endpointUri(baseUri: Uri, path: String): Uri = - baseUri.withPath(Uri.Path(path)) - - protected def endpointUri(baseUri: Uri, path: String, query: Seq[(String, String)]): Uri = - baseUri.withPath(Uri.Path(path)).withQuery(Uri.Query(query: _*)) - - protected def responseToListResponse[T: JsonFormat](pagination: Option[Pagination])( - response: HttpResponse): Future[ListResponse[T]] = { - import DefaultJsonProtocol._ - val resourceCount = response.headers - .find(_.name() equalsIgnoreCase ContextHeaders.ResourceCount) - .map(_.value().toInt) - .getOrElse(0) - val meta = ListResponse.Meta(resourceCount, pagination.getOrElse(Pagination(resourceCount max 1, 1))) - Unmarshal(response.entity).to[List[T]].map(ListResponse(_, meta)) - } - - protected def responseToListResponse[T: JsonFormat](pagination: Pagination)( - response: HttpResponse): Future[ListResponse[T]] = responseToListResponse(Some(pagination))(response) -} diff --git a/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala b/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala deleted file mode 100644 index 646fae8..0000000 --- a/src/main/scala/xyz/driver/core/rest/ServiceDescriptor.scala +++ /dev/null @@ -1,16 +0,0 @@ -package xyz.driver.core -package rest -import scala.annotation.implicitNotFound - -@implicitNotFound( - "Don't know how to communicate with service ${S}. Make sure an implicit ServiceDescriptor is" + - "available. A good place to put one is in the service's companion object.") -trait ServiceDescriptor[S] { - - /** The service's name. Must be unique among all services. */ - def name: String - - /** Get an instance of the service. */ - def connect(transport: HttpRestServiceTransport, url: String): S - -} diff --git a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala b/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala deleted file mode 100644 index 964a5a2..0000000 --- a/src/main/scala/xyz/driver/core/rest/SingleRequestHttpClient.scala +++ /dev/null @@ -1,29 +0,0 @@ -package xyz.driver.core.rest - -import akka.actor.ActorSystem -import akka.http.scaladsl.Http -import akka.http.scaladsl.model.headers.`User-Agent` -import akka.http.scaladsl.model.{HttpRequest, HttpResponse} -import akka.http.scaladsl.settings.{ClientConnectionSettings, ConnectionPoolSettings} -import akka.stream.ActorMaterializer -import xyz.driver.core.Name - -import scala.concurrent.Future - -class SingleRequestHttpClient(applicationName: Name[App], applicationVersion: String, actorSystem: ActorSystem) - extends HttpClient { - - protected implicit val materializer: ActorMaterializer = ActorMaterializer()(actorSystem) - private val client = Http()(actorSystem) - - private val clientConnectionSettings: ClientConnectionSettings = - ClientConnectionSettings(actorSystem).withUserAgentHeader( - Option(`User-Agent`(applicationName.value + "/" + applicationVersion))) - - private val connectionPoolSettings: ConnectionPoolSettings = ConnectionPoolSettings(actorSystem) - .withConnectionSettings(clientConnectionSettings) - - def makeRequest(request: HttpRequest): Future[HttpResponse] = { - client.singleRequest(request, settings = connectionPoolSettings) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/Swagger.scala b/src/main/scala/xyz/driver/core/rest/Swagger.scala deleted file mode 100644 index 5ceac54..0000000 --- a/src/main/scala/xyz/driver/core/rest/Swagger.scala +++ /dev/null @@ -1,144 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.{ContentType, ContentTypes, HttpEntity} -import akka.http.scaladsl.server.Route -import akka.http.scaladsl.server.directives.FileAndResourceDirectives.ResourceFile -import akka.stream.ActorAttributes -import akka.stream.scaladsl.{Framing, StreamConverters} -import akka.util.ByteString -import com.github.swagger.akka.SwaggerHttpService -import com.github.swagger.akka.model._ -import com.typesafe.config.Config -import com.typesafe.scalalogging.Logger -import io.swagger.models.Scheme -import io.swagger.models.auth.{ApiKeyAuthDefinition, In} -import io.swagger.util.Json - -import scala.util.control.NonFatal - -class Swagger( - override val host: String, - accessSchemes: List[String], - version: String, - override val apiClasses: Set[Class[_]], - val config: Config, - val logger: Logger) - extends SwaggerHttpService { - - override val schemes = accessSchemes.map { s => - Scheme.forValue(s) - } - - // Note that the reason for overriding this is a subtle chain of causality: - // - // 1. Some of our endpoints require a single trailing slash and will not - // function if it is omitted - // 2. Swagger omits trailing slashes in its generated api doc - // 3. To work around that, a space is added after the trailing slash in the - // swagger Path annotations - // 4. This space is removed manually in the code below - // - // TODO: Ideally we'd like to drop this custom override and fix the issue in - // 1, by dropping the slash requirement and accepting api endpoints with and - // without trailing slashes. This will require inspecting and potentially - // fixing all service endpoints. - override def generateSwaggerJson: String = { - import io.swagger.models.{Swagger => JSwagger} - - import scala.collection.JavaConverters._ - try { - val swagger: JSwagger = reader.read(apiClasses.asJava) - - val paths = if (swagger.getPaths == null) { - Map.empty - } else { - swagger.getPaths.asScala - } - - // Removing trailing spaces - val fixedPaths = paths.map { - case (key, path) => - key.trim -> path - } - - swagger.setPaths(fixedPaths.asJava) - - Json.pretty().writeValueAsString(swagger) - } catch { - case NonFatal(t) => - logger.error("Issue with creating swagger.json", t) - throw t - } - } - - override val securitySchemeDefinitions = Map( - "token" -> { - val definition = new ApiKeyAuthDefinition("Authorization", In.HEADER) - definition.setDescription("Authentication token") - definition - } - ) - - override val basePath: String = config.getString("swagger.basePath") - override val apiDocsPath: String = config.getString("swagger.docsPath") - - override val info = Info( - config.getString("swagger.apiInfo.description"), - version, - config.getString("swagger.apiInfo.title"), - config.getString("swagger.apiInfo.termsOfServiceUrl"), - contact = Some( - Contact( - config.getString("swagger.apiInfo.contact.name"), - config.getString("swagger.apiInfo.contact.url"), - config.getString("swagger.apiInfo.contact.email") - )), - license = Some( - License( - config.getString("swagger.apiInfo.license"), - config.getString("swagger.apiInfo.licenseUrl") - )), - vendorExtensions = Map.empty[String, AnyRef] - ) - - /** A very simple templating extractor. Gets a resource from the classpath and subsitutes any `{{key}}` with a value. */ - private def getTemplatedResource( - resourceName: String, - contentType: ContentType, - substitution: (String, String)): Route = get { - Option(this.getClass.getClassLoader.getResource(resourceName)) flatMap ResourceFile.apply match { - case Some(ResourceFile(url, length @ _, _)) => - extractSettings { settings => - val stream = StreamConverters - .fromInputStream(() => url.openStream()) - .withAttributes(ActorAttributes.dispatcher(settings.fileIODispatcher)) - .via(Framing.delimiter(ByteString("\n"), 4096, true).map(_.utf8String)) - .map { line => - line.replaceAll(s"\\{\\{${substitution._1}\\}\\}", substitution._2) - } - .map(line => ByteString(line + "\n")) - complete( - HttpEntity(contentType, stream) - ) - } - case None => reject - } - } - - def swaggerUI: Route = - pathEndOrSingleSlash { - getTemplatedResource( - "swagger-ui/index.html", - ContentTypes.`text/html(UTF-8)`, - "title" -> config.getString("swagger.apiInfo.title")) - } ~ getFromResourceDirectory("swagger-ui") - - def swaggerUINew: Route = - pathEndOrSingleSlash { - getTemplatedResource( - "swagger-ui-dist/index.html", - ContentTypes.`text/html(UTF-8)`, - "title" -> config.getString("swagger.apiInfo.title")) - } ~ getFromResourceDirectory("swagger-ui-dist") - -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala deleted file mode 100644 index 5007774..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AlwaysAllowAuthorization.scala +++ /dev/null @@ -1,14 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future - -class AlwaysAllowAuthorization[U <: User] extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - val permissionsMap = permissions.map(_ -> true).toMap - Future.successful(AuthorizationResult(authorized = permissionsMap, ctx.permissionsToken)) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala deleted file mode 100644 index e1a94e1..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AuthProvider.scala +++ /dev/null @@ -1,75 +0,0 @@ -package xyz.driver.core.rest.auth - -import akka.http.scaladsl.server.directives.Credentials -import com.typesafe.scalalogging.Logger -import scalaz.OptionT -import xyz.driver.core.auth.{AuthToken, Permission, User} -import xyz.driver.core.rest.errors.{ExternalServiceException, UnauthorizedException} -import xyz.driver.core.rest.{AuthorizedServiceRequestContext, ContextHeaders, ServiceRequestContext, serviceContext} - -import scala.concurrent.{ExecutionContext, Future} - -abstract class AuthProvider[U <: User]( - val authorization: Authorization[U], - log: Logger, - val realm: String -)(implicit execution: ExecutionContext) { - - import akka.http.scaladsl.server._ - import Directives.{authorize => akkaAuthorize, _} - - def this(authorization: Authorization[U], log: Logger)(implicit executionContext: ExecutionContext) = - this(authorization, log, "driver.xyz") - - /** - * Specific implementation on how to extract user from request context, - * can either need to do a network call to auth server or extract everything from self-contained token - * - * @param ctx set of request values which can be relevant to authenticate user - * @return authenticated user - */ - def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, U] - - protected def authenticator(context: ServiceRequestContext): AsyncAuthenticator[U] = { - case Credentials.Missing => - log.info(s"Request (${context.trackingId}) missing authentication credentials") - Future.successful(None) - case Credentials.Provided(authToken) => - authenticatedUser(context.withAuthToken(AuthToken(authToken))).run.recover({ - case ExternalServiceException(_, _, Some(UnauthorizedException(_))) => None - }) - } - - /** - * Verifies that a user agent is properly authenticated, and (optionally) authorized with the specified permissions - */ - def authorize( - context: ServiceRequestContext, - permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - authenticateOAuth2Async[U](realm, authenticator(context)) flatMap { authenticatedUser => - val authCtx = context.withAuthenticatedUser(context.authToken.get, authenticatedUser) - onSuccess(authorization.userHasPermissions(authenticatedUser, permissions)(authCtx)) flatMap { - case AuthorizationResult(authorized, token) => - val allAuthorized = permissions.forall(authorized.getOrElse(_, false)) - akkaAuthorize(allAuthorized) tflatMap { _ => - val cachedPermissionsCtx = token.fold(authCtx)(authCtx.withPermissionsToken) - provide(cachedPermissionsCtx) - } - } - } - } - - /** - * Verifies if request is authenticated and authorized to have `permissions` - */ - def authorize(permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - serviceContext flatMap (authorize(_, permissions: _*)) - } -} - -object AuthProvider { - val AuthenticationTokenHeader: String = ContextHeaders.AuthenticationTokenHeader - val PermissionsTokenHeader: String = ContextHeaders.PermissionsTokenHeader - val SetAuthenticationTokenHeader: String = "set-authorization" - val SetPermissionsTokenHeader: String = "set-permissions" -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala b/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala deleted file mode 100644 index 1a5e9be..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/Authorization.scala +++ /dev/null @@ -1,11 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future - -trait Authorization[U <: User] { - def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala b/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala deleted file mode 100644 index efe28c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/AuthorizationResult.scala +++ /dev/null @@ -1,22 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, PermissionsToken} - -import scalaz.Scalaz.mapMonoid -import scalaz.Semigroup -import scalaz.syntax.semigroup._ - -final case class AuthorizationResult(authorized: Map[Permission, Boolean], token: Option[PermissionsToken]) -object AuthorizationResult { - val unauthorized: AuthorizationResult = AuthorizationResult(authorized = Map.empty, None) - - implicit val authorizationSemigroup: Semigroup[AuthorizationResult] = new Semigroup[AuthorizationResult] { - private implicit val authorizedBooleanSemigroup = Semigroup.instance[Boolean](_ || _) - private implicit val permissionsTokenSemigroup = - Semigroup.instance[Option[PermissionsToken]]((a, b) => b.orElse(a)) - - override def append(a: AuthorizationResult, b: => AuthorizationResult): AuthorizationResult = { - AuthorizationResult(a.authorized |+| b.authorized, a.token |+| b.token) - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala deleted file mode 100644 index 66de4ef..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/CachedTokenAuthorization.scala +++ /dev/null @@ -1,55 +0,0 @@ -package xyz.driver.core.rest.auth - -import java.nio.file.{Files, Path} -import java.security.{KeyFactory, PublicKey} -import java.security.spec.X509EncodedKeySpec - -import pdi.jwt.{Jwt, JwtAlgorithm} -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future -import scalaz.syntax.std.boolean._ - -class CachedTokenAuthorization[U <: User](publicKey: => PublicKey, issuer: String) extends Authorization[U] { - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - import spray.json._ - - def extractPermissionsFromTokenJSON(tokenObject: JsObject): Option[Map[String, Boolean]] = - tokenObject.fields.get("permissions").collect { - case JsObject(fields) => - fields.collect { - case (key, JsBoolean(value)) => key -> value - } - } - - val result = for { - token <- ctx.permissionsToken - jwt <- Jwt.decode(token.value, publicKey, Seq(JwtAlgorithm.RS256)).toOption - jwtJson = jwt.parseJson.asJsObject - - // Ensure jwt is for the currently authenticated user and the correct issuer, otherwise return None - _ <- jwtJson.fields.get("sub").contains(JsString(user.id.value)).option(()) - _ <- jwtJson.fields.get("iss").contains(JsString(issuer)).option(()) - - permissionsMap <- extractPermissionsFromTokenJSON(jwtJson) - - authorized = permissions.map(p => p -> permissionsMap.getOrElse(p.toString, false)).toMap - } yield AuthorizationResult(authorized, Some(token)) - - Future.successful(result.getOrElse(AuthorizationResult.unauthorized)) - } -} - -object CachedTokenAuthorization { - def apply[U <: User](publicKeyFile: Path, issuer: String): CachedTokenAuthorization[U] = { - lazy val publicKey: PublicKey = { - val publicKeyBase64Encoded = new String(Files.readAllBytes(publicKeyFile)).trim - val publicKeyBase64Decoded = java.util.Base64.getDecoder.decode(publicKeyBase64Encoded) - val spec = new X509EncodedKeySpec(publicKeyBase64Decoded) - KeyFactory.getInstance("RSA").generatePublic(spec) - } - new CachedTokenAuthorization[U](publicKey, issuer) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala b/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala deleted file mode 100644 index 131e7fc..0000000 --- a/src/main/scala/xyz/driver/core/rest/auth/ChainedAuthorization.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest.auth - -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.{ExecutionContext, Future} -import scalaz.Scalaz.{futureInstance, listInstance} -import scalaz.syntax.semigroup._ -import scalaz.syntax.traverse._ - -class ChainedAuthorization[U <: User](authorizations: Authorization[U]*)(implicit execution: ExecutionContext) - extends Authorization[U] { - - override def userHasPermissions(user: U, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - def allAuthorized(permissionsMap: Map[Permission, Boolean]): Boolean = - permissions.forall(permissionsMap.getOrElse(_, false)) - - authorizations.toList.foldLeftM[Future, AuthorizationResult](AuthorizationResult.unauthorized) { - (authResult, authorization) => - if (allAuthorized(authResult.authorized)) Future.successful(authResult) - else { - authorization.userHasPermissions(user, permissions).map(authResult |+| _) - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala deleted file mode 100644 index ff3424d..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/AuthDirectives.scala +++ /dev/null @@ -1,19 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import akka.http.scaladsl.server.{Directive1, Directives => AkkaDirectives} -import xyz.driver.core.auth.{Permission, User} -import xyz.driver.core.rest.auth.AuthProvider - -/** Authentication and authorization directives. */ -trait AuthDirectives extends AkkaDirectives { - - /** Authenticate a user based on service request headers and check if they have all given permissions. */ - def authenticateAndAuthorize[U <: User]( - authProvider: AuthProvider[U], - permissions: Permission*): Directive1[AuthorizedServiceRequestContext[U]] = { - authProvider.authorize(permissions: _*) - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala b/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala deleted file mode 100644 index 5a6bbfd..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/CorsDirectives.scala +++ /dev/null @@ -1,72 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import akka.http.scaladsl.model.HttpMethods._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{HttpResponse, StatusCodes} -import akka.http.scaladsl.server.{Route, Directives => AkkaDirectives} - -/** Directives to handle Cross-Origin Resource Sharing (CORS). */ -trait CorsDirectives extends AkkaDirectives { - - /** Route handler that injects Cross-Origin Resource Sharing (CORS) headers depending on the request - * origin. - * - * In a microservice environment, it can be difficult to know in advance the exact origin - * from which requests may be issued [1]. For example, the request may come from a web page served from - * any of the services, on any namespace or from other documentation sites. In general, only a set - * of domain suffixes can be assumed to be known in advance. Unfortunately however, browsers that - * implement CORS require exact specification of allowed origins, including full host name and scheme, - * in order to send credentials and headers with requests to other origins. - * - * This route wrapper provides a simple way alleviate CORS' exact allowed-origin requirement by - * dynamically echoing the origin as an allowed origin if and only if its domain is whitelisted. - * - * Note that the simplicity of this implementation comes with two notable drawbacks: - * - * - All OPTION requests are "hijacked" and will not be passed to the inner route of this wrapper. - * - * - Allowed methods and headers can not be customized on a per-request basis. All standard - * HTTP methods are allowed, and allowed headers are specified for all inner routes. - * - * This handler is not suited for cases where more fine-grained control of responses is required. - * - * [1] Assuming browsers communicate directly with the services and that requests aren't proxied through - * a common gateway. - * - * @param allowedSuffixes The set of domain suffixes (e.g. internal.example.org, example.org) of allowed - * origins. - * @param allowedHeaders Header names that will be set in `Access-Control-Allow-Headers`. - * @param inner Route into which CORS headers will be injected. - */ - def cors(allowedSuffixes: Set[String], allowedHeaders: Seq[String])(inner: Route): Route = { - optionalHeaderValueByType[Origin](()) { maybeOrigin => - val allowedOrigins: HttpOriginRange = maybeOrigin match { - // Note that this is not a security issue: the client will never send credentials if the allowed - // origin is set to *. This case allows us to deal with clients that do not send an origin header. - case None => HttpOriginRange.* - case Some(requestOrigin) => - val allowedOrigin = requestOrigin.origins.find(origin => - allowedSuffixes.exists(allowed => origin.host.host.address endsWith allowed)) - allowedOrigin.map(HttpOriginRange(_)).getOrElse(HttpOriginRange.*) - } - - respondWithHeaders( - `Access-Control-Allow-Origin`.forRange(allowedOrigins), - `Access-Control-Allow-Credentials`(true), - `Access-Control-Allow-Headers`(allowedHeaders: _*), - `Access-Control-Expose-Headers`(allowedHeaders: _*) - ) { - options { // options is used during preflight check - complete( - HttpResponse(StatusCodes.OK) - .withHeaders(`Access-Control-Allow-Methods`(OPTIONS, POST, PUT, GET, DELETE, PATCH, TRACE))) - } ~ inner // in case of non-preflight check we don't do anything special - } - } - } - -} - -object CorsDirectives extends CorsDirectives diff --git a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala b/src/main/scala/xyz/driver/core/rest/directives/Directives.scala deleted file mode 100644 index 0cd4ef1..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/Directives.scala +++ /dev/null @@ -1,6 +0,0 @@ -package xyz.driver.core -package rest -package directives - -trait Directives extends AuthDirectives with CorsDirectives with PathMatchers with Unmarshallers -object Directives extends Directives diff --git a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala b/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala deleted file mode 100644 index 218c9ae..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/PathMatchers.scala +++ /dev/null @@ -1,85 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import java.time.Instant -import java.util.UUID - -import akka.http.scaladsl.model.Uri.Path -import akka.http.scaladsl.server.PathMatcher.{Matched, Unmatched} -import akka.http.scaladsl.server.{PathMatcher, PathMatcher1, PathMatchers => AkkaPathMatchers} -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.refineV -import xyz.driver.core.domain.PhoneNumber -import xyz.driver.core.time.Time - -import scala.util.control.NonFatal - -/** Akka-HTTP path matchers for custom core types. */ -trait PathMatchers { - - private def UuidInPath[T]: PathMatcher1[Id[T]] = - AkkaPathMatchers.JavaUUID.map((id: UUID) => Id[T](id.toString.toLowerCase)) - - def IdInPath[T]: PathMatcher1[Id[T]] = UuidInPath[T] | new PathMatcher1[Id[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Id[T](segment))) - case _ => Unmatched - } - } - - def NameInPath[T]: PathMatcher1[Name[T]] = new PathMatcher1[Name[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => Matched(tail, Tuple1(Name[T](segment))) - case _ => Unmatched - } - } - - private def timestampInPath: PathMatcher1[Long] = - PathMatcher("""[+-]?\d*""".r) flatMap { string => - try Some(string.toLong) - catch { case _: IllegalArgumentException => None } - } - - def InstantInPath: PathMatcher1[Instant] = - new PathMatcher1[Instant] { - def apply(path: Path): PathMatcher.Matching[Tuple1[Instant]] = path match { - case Path.Segment(head, tail) => - try Matched(tail, Tuple1(Instant.parse(head))) - catch { - case NonFatal(_) => Unmatched - } - case _ => Unmatched - } - } | timestampInPath.map(Instant.ofEpochMilli) - - def TimeInPath: PathMatcher1[Time] = InstantInPath.map(instant => Time(instant.toEpochMilli)) - - def NonEmptyNameInPath[T]: PathMatcher1[NonEmptyName[T]] = new PathMatcher1[NonEmptyName[T]] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - refineV[NonEmpty](segment) match { - case Left(_) => Unmatched - case Right(nonEmptyString) => Matched(tail, Tuple1(NonEmptyName[T](nonEmptyString))) - } - case _ => Unmatched - } - } - - def RevisionInPath[T]: PathMatcher1[Revision[T]] = - PathMatcher("""[\da-fA-F]{8}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{4}-[\da-fA-F]{12}""".r) flatMap { string => - Some(Revision[T](string)) - } - - def PhoneInPath: PathMatcher1[PhoneNumber] = new PathMatcher1[PhoneNumber] { - def apply(path: Path) = path match { - case Path.Segment(segment, tail) => - PhoneNumber - .parse(segment) - .map(parsed => Matched(tail, Tuple1(parsed))) - .getOrElse(Unmatched) - case _ => Unmatched - } - } - -} diff --git a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala b/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala deleted file mode 100644 index 6c45d15..0000000 --- a/src/main/scala/xyz/driver/core/rest/directives/Unmarshallers.scala +++ /dev/null @@ -1,40 +0,0 @@ -package xyz.driver.core -package rest -package directives - -import java.util.UUID - -import akka.http.scaladsl.marshalling.{Marshaller, Marshalling} -import akka.http.scaladsl.unmarshalling.Unmarshaller -import spray.json.{JsString, JsValue, JsonParser, JsonReader, JsonWriter} - -/** Akka-HTTP unmarshallers for custom core types. */ -trait Unmarshallers { - - implicit def idUnmarshaller[A]: Unmarshaller[String, Id[A]] = - Unmarshaller.strict[String, Id[A]] { str => - Id[A](UUID.fromString(str).toString) - } - - implicit def paramUnmarshaller[T](implicit reader: JsonReader[T]): Unmarshaller[String, T] = - Unmarshaller.firstOf( - Unmarshaller.strict((JsString(_: String)) andThen reader.read), - stringToValueUnmarshaller[T] - ) - - implicit def revisionFromStringUnmarshaller[T]: Unmarshaller[String, Revision[T]] = - Unmarshaller.strict[String, Revision[T]](Revision[T]) - - val jsValueToStringMarshaller: Marshaller[JsValue, String] = - Marshaller.strict[JsValue, String](value => Marshalling.Opaque[String](() => value.compactPrint)) - - def valueToStringMarshaller[T](implicit jsonFormat: JsonWriter[T]): Marshaller[T, String] = - jsValueToStringMarshaller.compose[T](jsonFormat.write) - - val stringToJsValueUnmarshaller: Unmarshaller[String, JsValue] = - Unmarshaller.strict[String, JsValue](value => JsonParser(value)) - - def stringToValueUnmarshaller[T](implicit jsonFormat: JsonReader[T]): Unmarshaller[String, T] = - stringToJsValueUnmarshaller.map[T](jsonFormat.read) - -} diff --git a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala b/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala deleted file mode 100644 index f2962c9..0000000 --- a/src/main/scala/xyz/driver/core/rest/errors/serviceException.scala +++ /dev/null @@ -1,27 +0,0 @@ -package xyz.driver.core.rest.errors - -sealed abstract class ServiceException(val message: String) extends Exception(message) - -final case class InvalidInputException(override val message: String = "Invalid input") extends ServiceException(message) - -final case class InvalidActionException(override val message: String = "This action is not allowed") - extends ServiceException(message) - -final case class UnauthorizedException( - override val message: String = "The user's authentication credentials are invalid or missing") - extends ServiceException(message) - -final case class ResourceNotFoundException(override val message: String = "Resource not found") - extends ServiceException(message) - -final case class ExternalServiceException( - serviceName: String, - serviceMessage: String, - serviceException: Option[ServiceException]) - extends ServiceException(s"Error while calling '$serviceName': $serviceMessage") - -final case class ExternalServiceTimeoutException(serviceName: String) - extends ServiceException(s"$serviceName took too long to respond") - -final case class DatabaseException(override val message: String = "Database access error") - extends ServiceException(message) diff --git a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala b/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala deleted file mode 100644 index 866476d..0000000 --- a/src/main/scala/xyz/driver/core/rest/headers/Traceparent.scala +++ /dev/null @@ -1,33 +0,0 @@ -package xyz.driver.core -package rest -package headers - -import akka.http.scaladsl.model.headers.{ModeledCustomHeader, ModeledCustomHeaderCompanion} -import xyz.driver.core.reporting.SpanContext - -import scala.util.Try - -/** Encapsulates a trace context in an HTTP header for propagation across services. - * - * This implementation corresponds to the W3C editor's draft specification (as of 2018-08-28) - * https://w3c.github.io/distributed-tracing/report-trace-context.html. The 'flags' field is - * ignored. - */ -final case class Traceparent(spanContext: SpanContext) extends ModeledCustomHeader[Traceparent] { - override def renderInRequests = true - override def renderInResponses = true - override val companion: Traceparent.type = Traceparent - override def value: String = f"01-${spanContext.traceId}-${spanContext.spanId}-00" -} -object Traceparent extends ModeledCustomHeaderCompanion[Traceparent] { - override val name = "traceparent" - override def parse(value: String) = Try { - val Array(version, traceId, spanId, _) = value.split("-") - require( - version == "01", - s"Found unsupported version '$version' in traceparent header. Only version '01' is supported.") - new Traceparent( - new SpanContext(traceId, spanId) - ) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala deleted file mode 100644 index 34a4a9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ /dev/null @@ -1,323 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.Flow -import akka.util.ByteString -import scalaz.Scalaz.{intInstance, stringInstance} -import scalaz.syntax.equal._ -import scalaz.{Functor, OptionT} -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.errors.ExternalServiceException -import xyz.driver.core.rest.headers.Traceparent -import xyz.driver.tracing.TracingDirectives - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.Try - -trait Service - -object Service - -trait HttpClient { - def makeRequest(request: HttpRequest): Future[HttpResponse] -} - -trait ServiceTransport { - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] -} - -sealed trait SortingOrder -object SortingOrder { - case object Asc extends SortingOrder - case object Desc extends SortingOrder -} - -final case class SortingField(name: String, sortingOrder: SortingOrder) -final case class Sorting(sortingFields: Seq[SortingField]) - -final case class Pagination(pageSize: Int, pageNumber: Int) { - require(pageSize > 0, "Page size must be greater than zero") - require(pageNumber > 0, "Page number must be greater than zero") - - def offset: Int = pageSize * (pageNumber - 1) -} - -final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) - -object ListResponse { - - def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = - ListResponse( - items = items, - meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) - - final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) - - object Meta { - def apply(itemsCount: Int, pagination: Pagination): Meta = - Meta(itemsCount, pagination.pageNumber, pagination.pageSize) - } - -} - -object `package` { - - implicit class FutureExtensions[T](future: Future[T]) { - def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = - future.transform(identity, { - case ExternalServiceException(_, _, Some(e)) => e - case t: Throwable => t - }) - } - - implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { - def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( - implicit F: Functor[Future], - em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { - optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) - } - } - - object ContextHeaders { - val AuthenticationTokenHeader: String = "Authorization" - val PermissionsTokenHeader: String = "Permissions" - val AuthenticationHeaderPrefix: String = "Bearer" - val ClientFingerprintHeader: String = "X-Client-Fingerprint" - val TrackingIdHeader: String = "X-Trace" - val StacktraceHeader: String = "X-Stacktrace" - val OriginatingIpHeader: String = "X-Forwarded-For" - val ResourceCount: String = "X-Resource-Count" - val PageCount: String = "X-Page-Count" - val TraceHeaderName: String = TracingDirectives.TraceHeaderName - val SpanHeaderName: String = TracingDirectives.SpanHeaderName - } - - val AllowedHeaders: Seq[String] = - Seq( - "Origin", - "X-Requested-With", - "Content-Type", - "Content-Length", - "Accept", - "X-Trace", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Origin", - "Access-Control-Allow-Headers", - "Server", - "Date", - ContextHeaders.ClientFingerprintHeader, - ContextHeaders.TrackingIdHeader, - ContextHeaders.TraceHeaderName, - ContextHeaders.SpanHeaderName, - ContextHeaders.StacktraceHeader, - ContextHeaders.AuthenticationTokenHeader, - ContextHeaders.OriginatingIpHeader, - ContextHeaders.ResourceCount, - ContextHeaders.PageCount, - "X-Frame-Options", - "X-Content-Type-Options", - "Strict-Transport-Security", - AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader, - "Traceparent" - ) - - def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - - def serviceContext: Directive1[ServiceRequestContext] = { - def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { - headers.map({ header => - if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( - ContextHeaders.AuthenticationHeaderPrefix)) { - Authorization(OAuth2BearerToken(header.value)) - } else header - })(collection.breakOut) - } - extractClientIP flatMap { remoteAddress => - mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => - extract(ctx => extractServiceContext(ctx.request, remoteAddress)) - } - } - } - - def respondWithCorsAllowedHeaders: Directive0 = { - respondWithHeaders( - List[HttpHeader]( - `Access-Control-Allow-Headers`(AllowedHeaders: _*), - `Access-Control-Expose-Headers`(AllowedHeaders: _*) - )) - } - - def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { - respondWithHeader { - `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) - } - } - - def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { - respondWithHeaders( - List[HttpHeader]( - Allow(methods.to[collection.immutable.Seq]), - `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) - )) - } - - def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = - new ServiceRequestContext( - extractTrackingId(request), - extractOriginatingIP(request, remoteAddress), - extractContextHeaders(request)) - - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name === ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } - - def extractFingerprintHash(request: HttpRequest): Option[String] = { - request.headers - .find(_.name === ContextHeaders.ClientFingerprintHeader) - .map(_.value()) - } - - def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { - request.headers - .find(_.name === ContextHeaders.OriginatingIpHeader) - .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) - .orElse(remoteAddress.toOption) - } - - def extractStacktrace(request: HttpRequest): Array[String] = - request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") - - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers - .filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || - h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || - h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader || - h.name === Traceparent.name - } - .map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value - } - } - .toMap - } - - private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { - @annotation.tailrec - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse - else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) - } - } - } - - val indices = dirtyIndices(0, Nil) - - indices.headOption.fold(byteString) { head => - val builder = ByteString.newBuilder - builder ++= byteString.take(head) - - (indices :+ byteString.length).sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - case Seq(_) => // Should not match; sliding on at least 2 elements - assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") - } - builder.result - } - } - - val sanitizeRequestEntity: Directive0 = { - mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) - } - - val paginated: Directive1[Pagination] = - parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) - - private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = - (pageSizeOpt, pageNumberOpt) match { - case (Some(size), Some(number)) => Option(Pagination(size, number)) - case (None, None) => Option.empty[Pagination] - case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") - } - - val optionalPagination: Directive1[Option[Pagination]] = - parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) - - def paginationQuery(pagination: Pagination) = - Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) - - def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])( - implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = { - optionalPagination { pagination => - onSuccess(handler(pagination)) { - case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) => - val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1) - val headers = List( - RawHeader(ContextHeaders.ResourceCount, count.toString), - RawHeader(ContextHeaders.PageCount, pageCount.toString)) - - respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart))) - } - } - } - - private def extractSorting(sortingString: Option[String]): Sorting = { - val sortingFields = sortingString.fold(Seq.empty[SortingField])( - _.split(",") - .filter(_.length > 0) - .map { sortingParam => - if (sortingParam.startsWith("-")) { - SortingField(sortingParam.substring(1), SortingOrder.Desc) - } else { - val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam - SortingField(fieldName, SortingOrder.Asc) - } - } - .toSeq) - - Sorting(sortingFields) - } - - val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) - - def sortingQuery(sorting: Sorting): Seq[(String, String)] = { - val sortingString = sorting.sortingFields - .map { sortingField => - sortingField.sortingOrder match { - case SortingOrder.Asc => sortingField.name - case SortingOrder.Desc => s"-${sortingField.name}" - } - } - .mkString(",") - Seq("sort" -> sortingString) - } -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala b/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala deleted file mode 100644 index 55f1a2e..0000000 --- a/src/main/scala/xyz/driver/core/rest/serviceDiscovery.scala +++ /dev/null @@ -1,24 +0,0 @@ -package xyz.driver.core.rest - -import xyz.driver.core.Name - -trait ServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T -} - -trait SavingUsedServiceDiscovery { - private val usedServices = new scala.collection.mutable.HashSet[String]() - - def saveServiceUsage(serviceName: Name[Service]): Unit = usedServices.synchronized { - usedServices += serviceName.value - } - - def getUsedServices: Set[String] = usedServices.synchronized { usedServices.toSet } -} - -class NoServiceDiscovery extends ServiceDiscovery with SavingUsedServiceDiscovery { - - def discover[T <: Service](serviceName: Name[Service]): T = - throw new IllegalArgumentException(s"Service with name $serviceName is unknown") -} diff --git a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala b/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala deleted file mode 100644 index d2e4bc3..0000000 --- a/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala +++ /dev/null @@ -1,87 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import xyz.driver.core.auth.{AuthToken, PermissionsToken, User} -import xyz.driver.core.generators -import scalaz.Scalaz.{mapEqual, stringInstance} -import scalaz.syntax.equal._ -import xyz.driver.core.reporting.SpanContext -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.headers.Traceparent - -import scala.util.Try - -class ServiceRequestContext( - val trackingId: String = generators.nextUuid().toString, - val originatingIp: Option[InetAddress] = None, - val contextHeaders: Map[String, String] = Map.empty[String, String]) { - def authToken: Option[AuthToken] = - contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply) - - def permissionsToken: Option[PermissionsToken] = - contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply) - - def withAuthToken(authToken: AuthToken): ServiceRequestContext = - new ServiceRequestContext( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value) - ) - - def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value), - user - ) - - override def hashCode(): Int = - Seq[Any](trackingId, originatingIp, contextHeaders) - .foldLeft(31)((result, obj) => 31 * result + obj.hashCode()) - - override def equals(obj: Any): Boolean = obj match { - case ctx: ServiceRequestContext => - trackingId === ctx.trackingId && - originatingIp == originatingIp && - contextHeaders === ctx.contextHeaders - case _ => false - } - - def spanContext: SpanContext = { - val validHeader = Try { - contextHeaders(Traceparent.name) - }.flatMap { value => - Traceparent.parse(value) - } - validHeader.map(_.spanContext).getOrElse(SpanContext.fresh()) - } - - override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)" -} - -class AuthorizedServiceRequestContext[U <: User]( - override val trackingId: String = generators.nextUuid().toString, - override val originatingIp: Option[InetAddress] = None, - override val contextHeaders: Map[String, String] = Map.empty[String, String], - val authenticatedUser: U) - extends ServiceRequestContext { - - def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] = - new AuthorizedServiceRequestContext[U]( - trackingId, - originatingIp, - contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value), - authenticatedUser) - - override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode() - - override def equals(obj: Any): Boolean = obj match { - case ctx: AuthorizedServiceRequestContext[U] => super.equals(ctx) && ctx.authenticatedUser == authenticatedUser - case _ => false - } - - override def toString: String = - s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)" -} diff --git a/src/test/scala/xyz/driver/core/AuthTest.scala b/src/test/scala/xyz/driver/core/AuthTest.scala deleted file mode 100644 index 2e772fb..0000000 --- a/src/test/scala/xyz/driver/core/AuthTest.scala +++ /dev/null @@ -1,165 +0,0 @@ -package xyz.driver.core - -import akka.http.scaladsl.model.headers.{ - HttpChallenges, - OAuth2BearerToken, - RawHeader, - Authorization => AkkaAuthorization -} -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import akka.http.scaladsl.testkit.ScalatestRouteTest -import org.scalatest.{FlatSpec, Matchers} -import pdi.jwt.{Jwt, JwtAlgorithm} -import xyz.driver.core.auth._ -import xyz.driver.core.domain.Email -import xyz.driver.core.logging._ -import xyz.driver.core.rest._ -import xyz.driver.core.rest.auth._ -import xyz.driver.core.time.Time - -import scala.concurrent.Future -import scalaz.OptionT - -class AuthTest extends FlatSpec with Matchers with ScalatestRouteTest { - - case object TestRoleAllowedPermission extends Permission - case object TestRoleAllowedByTokenPermission extends Permission - case object TestRoleNotAllowedPermission extends Permission - - val TestRole = Role(Id("1"), Name("testRole")) - - val (publicKey, privateKey) = { - import java.security.KeyPairGenerator - - val keygen = KeyPairGenerator.getInstance("RSA") - keygen.initialize(2048) - - val keyPair = keygen.generateKeyPair() - (keyPair.getPublic, keyPair.getPrivate) - } - - val basicAuthorization: Authorization[User] = new Authorization[User] { - - override def userHasPermissions(user: User, permissions: Seq[Permission])( - implicit ctx: ServiceRequestContext): Future[AuthorizationResult] = { - val authorized = permissions.map(p => p -> (p === TestRoleAllowedPermission)).toMap - Future.successful(AuthorizationResult(authorized, ctx.permissionsToken)) - } - } - - val tokenIssuer = "users" - val tokenAuthorization = new CachedTokenAuthorization[User](publicKey, tokenIssuer) - - val authorization = new ChainedAuthorization[User](tokenAuthorization, basicAuthorization) - - val authStatusService = new AuthProvider[User](authorization, NoLogger) { - override def authenticatedUser(implicit ctx: ServiceRequestContext): OptionT[Future, User] = - OptionT.optionT[Future] { - if (ctx.contextHeaders.keySet.contains(AuthProvider.AuthenticationTokenHeader)) { - Future.successful( - Some( - AuthTokenUserInfo( - Id[User]("1"), - Email("foo", "bar"), - emailVerified = true, - audience = "driver", - roles = Set(TestRole), - expirationTime = Time(1000000L) - ))) - } else { - Future.successful(Option.empty[User]) - } - } - } - - import authStatusService._ - - "'authorize' directive" should "throw error if auth token is not in the request" in { - - Get("/naive/attempt") ~> - authorize(TestRoleAllowedPermission) { user => - complete("Never going to be here") - } ~> - check { - // handled shouldBe false - rejections should contain( - AuthenticationFailedRejection( - AuthenticationFailedRejection.CredentialsMissing, - HttpChallenges.oAuth2(authStatusService.realm))) - } - } - - it should "throw error if authorized user does not have the requested permission" in { - - val referenceAuthToken = AuthToken("I am a test role's token") - val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) - - Post("/administration/attempt").addHeader( - referenceAuthHeader - ) ~> - authorize(TestRoleNotAllowedPermission) { user => - complete("Never going to get here") - } ~> - check { - handled shouldBe false - rejections should contain(AuthorizationFailedRejection) - } - } - - it should "pass and retrieve the token to client code, if token is in request and user has permission" in { - val referenceAuthToken = AuthToken("I am token") - val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) - - Get("/valid/attempt/?a=2&b=5").addHeader( - referenceAuthHeader - ) ~> - authorize(TestRoleAllowedPermission) { ctx => - complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") - } ~> - check { - handled shouldBe true - responseAs[String] shouldBe "Alright, user 1 is authorized" - } - } - - it should "authenticate correctly even without the 'Bearer' prefix on the Authorization header" in { - val referenceAuthToken = AuthToken("unprefixed_token") - - Get("/valid/attempt/?a=2&b=5").addHeader( - RawHeader(ContextHeaders.AuthenticationTokenHeader, referenceAuthToken.value) - ) ~> - authorize(TestRoleAllowedPermission) { ctx => - complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized") - } ~> - check { - handled shouldBe true - responseAs[String] shouldBe "Alright, user 1 is authorized" - } - } - - it should "authorize permission found in permissions token" in { - import spray.json._ - - val claim = JsObject( - Map( - "iss" -> JsString(tokenIssuer), - "sub" -> JsString("1"), - "permissions" -> JsObject(Map(TestRoleAllowedByTokenPermission.toString -> JsBoolean(true))) - )).prettyPrint - val permissionsToken = PermissionsToken(Jwt.encode(claim, privateKey, JwtAlgorithm.RS256)) - val referenceAuthToken = AuthToken("I am token") - val referenceAuthHeader = AkkaAuthorization(OAuth2BearerToken(referenceAuthToken.value)) - - Get("/alic/attempt/?a=2&b=5") - .addHeader(referenceAuthHeader) - .addHeader(RawHeader(AuthProvider.PermissionsTokenHeader, permissionsToken.value)) ~> - authorize(TestRoleAllowedByTokenPermission) { ctx => - complete(s"Alright, user ${ctx.authenticatedUser.id} is authorized by permissions token") - } ~> - check { - handled shouldBe true - responseAs[String] shouldBe "Alright, user 1 is authorized by permissions token" - } - } -} diff --git a/src/test/scala/xyz/driver/core/GeneratorsTest.scala b/src/test/scala/xyz/driver/core/GeneratorsTest.scala deleted file mode 100644 index 7e740a4..0000000 --- a/src/test/scala/xyz/driver/core/GeneratorsTest.scala +++ /dev/null @@ -1,264 +0,0 @@ -package xyz.driver.core - -import org.scalatest.{Assertions, FlatSpec, Matchers} - -import scala.collection.immutable.IndexedSeq - -class GeneratorsTest extends FlatSpec with Matchers with Assertions { - import generators._ - - "Generators" should "be able to generate com.drivergrp.core.Id identifiers" in { - - val generatedId1 = nextId[String]() - val generatedId2 = nextId[String]() - val generatedId3 = nextId[Long]() - - generatedId1.length should be >= 0 - generatedId2.length should be >= 0 - generatedId3.length should be >= 0 - generatedId1 should not be generatedId2 - generatedId2 should !==(generatedId3) - } - - it should "be able to generate com.drivergrp.core.Id identifiers with max value" in { - - val generatedLimitedId1 = nextId[String](5) - val generatedLimitedId2 = nextId[String](4) - val generatedLimitedId3 = nextId[Long](3) - - generatedLimitedId1.length should be >= 0 - generatedLimitedId1.length should be < 6 - generatedLimitedId2.length should be >= 0 - generatedLimitedId2.length should be < 5 - generatedLimitedId3.length should be >= 0 - generatedLimitedId3.length should be < 4 - generatedLimitedId1 should not be generatedLimitedId2 - generatedLimitedId2 should !==(generatedLimitedId3) - } - - it should "be able to generate com.drivergrp.core.Name names" in { - - Seq.fill(10)(nextName[String]()).distinct.size should be > 1 - nextName[String]().value.length should be >= 0 - - val fixedLengthName = nextName[String](10) - fixedLengthName.length should be <= 10 - assert(!fixedLengthName.value.exists(_.isControl)) - } - - it should "be able to generate com.drivergrp.core.NonEmptyName with non empty strings" in { - - assert(nextNonEmptyName[String]().value.value.nonEmpty) - } - - it should "be able to generate proper UUIDs" in { - - nextUuid() should not be nextUuid() - nextUuid().toString.length should be(36) - } - - it should "be able to generate new Revisions" in { - - nextRevision[String]() should not be nextRevision[String]() - nextRevision[String]().id.length should be > 0 - } - - it should "be able to generate strings" in { - - nextString() should not be nextString() - nextString().length should be >= 0 - - val fixedLengthString = nextString(20) - fixedLengthString.length should be <= 20 - assert(!fixedLengthString.exists(_.isControl)) - } - - it should "be able to generate strings non-empty strings whic are non empty" in { - - assert(nextNonEmptyString().value.nonEmpty) - } - - it should "be able to generate options which are sometimes have values and sometimes not" in { - - val generatedOption = nextOption("2") - - generatedOption should not contain "1" - assert(generatedOption === Some("2") || generatedOption === None) - } - - it should "be able to generate a pair of two generated values" in { - - val constantPair = nextPair("foo", 1L) - constantPair._1 should be("foo") - constantPair._2 should be(1L) - - val generatedPair = nextPair(nextId[Int](), nextName[Int]()) - - generatedPair._1.length should be > 0 - generatedPair._2.length should be > 0 - - nextPair(nextId[Int](), nextName[Int]()) should not be - nextPair(nextId[Int](), nextName[Int]()) - } - - it should "be able to generate a triad of two generated values" in { - - val constantTriad = nextTriad("foo", "bar", 1L) - constantTriad._1 should be("foo") - constantTriad._2 should be("bar") - constantTriad._3 should be(1L) - - val generatedTriad = nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) - - generatedTriad._1.length should be > 0 - generatedTriad._2.length should be > 0 - generatedTriad._3 should be >= BigDecimal(0.00) - - nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) should not be - nextTriad(nextId[Int](), nextName[Int](), nextBigDecimal()) - } - - it should "be able to generate a time value" in { - - val generatedTime = nextTime() - val currentTime = System.currentTimeMillis() - - generatedTime.millis should be >= 0L - generatedTime.millis should be <= currentTime - } - - it should "be able to generate a time range value" in { - - val generatedTimeRange = nextTimeRange() - val currentTime = System.currentTimeMillis() - - generatedTimeRange.start.millis should be >= 0L - generatedTimeRange.start.millis should be <= currentTime - generatedTimeRange.end.millis should be >= 0L - generatedTimeRange.end.millis should be <= currentTime - generatedTimeRange.start.millis should be <= generatedTimeRange.end.millis - } - - it should "be able to generate a BigDecimal value" in { - - val defaultGeneratedBigDecimal = nextBigDecimal() - - defaultGeneratedBigDecimal should be >= BigDecimal(0.00) - defaultGeneratedBigDecimal should be <= BigDecimal(1000000.00) - defaultGeneratedBigDecimal.precision should be(2) - - val unitIntervalBigDecimal = nextBigDecimal(1.00, 8) - - unitIntervalBigDecimal should be >= BigDecimal(0.00) - unitIntervalBigDecimal should be <= BigDecimal(1.00) - unitIntervalBigDecimal.precision should be(8) - } - - it should "be able to generate a specific value from a set of values" in { - - val possibleOptions = Set(1, 3, 5, 123, 0, 9) - - val pick1 = generators.oneOf(possibleOptions) - val pick2 = generators.oneOf(possibleOptions) - val pick3 = generators.oneOf(possibleOptions) - - possibleOptions should contain(pick1) - possibleOptions should contain(pick2) - possibleOptions should contain(pick3) - - val pick4 = generators.oneOf(1, 3, 5, 123, 0, 9) - val pick5 = generators.oneOf(1, 3, 5, 123, 0, 9) - val pick6 = generators.oneOf(1, 3, 5, 123, 0, 9) - - possibleOptions should contain(pick4) - possibleOptions should contain(pick5) - possibleOptions should contain(pick6) - - Set(pick1, pick2, pick3, pick4, pick5, pick6).size should be >= 1 - } - - it should "be able to generate a specific value from an enumeratum enum" in { - - import enumeratum._ - sealed trait TestEnumValue extends EnumEntry - object TestEnum extends Enum[TestEnumValue] { - case object Value1 extends TestEnumValue - case object Value2 extends TestEnumValue - case object Value3 extends TestEnumValue - case object Value4 extends TestEnumValue - val values: IndexedSeq[TestEnumValue] = findValues - } - - val picks = (1 to 100).map(_ => generators.oneOf(TestEnum)) - - TestEnum.values should contain allElementsOf picks - picks.toSet.size should be >= 1 - } - - it should "be able to generate array with values generated by generators" in { - - val arrayOfTimes = arrayOf(nextTime(), 16) - arrayOfTimes.length should be <= 16 - - val arrayOfBigDecimals = arrayOf(nextBigDecimal(), 8) - arrayOfBigDecimals.length should be <= 8 - } - - it should "be able to generate seq with values generated by generators" in { - - val seqOfTimes = seqOf(nextTime(), 16) - seqOfTimes.size should be <= 16 - - val seqOfBigDecimals = seqOf(nextBigDecimal(), 8) - seqOfBigDecimals.size should be <= 8 - } - - it should "be able to generate vector with values generated by generators" in { - - val vectorOfTimes = vectorOf(nextTime(), 16) - vectorOfTimes.size should be <= 16 - - val vectorOfStrings = seqOf(nextString(), 8) - vectorOfStrings.size should be <= 8 - } - - it should "be able to generate list with values generated by generators" in { - - val listOfTimes = listOf(nextTime(), 16) - listOfTimes.size should be <= 16 - - val listOfBigDecimals = seqOf(nextBigDecimal(), 8) - listOfBigDecimals.size should be <= 8 - } - - it should "be able to generate set with values generated by generators" in { - - val setOfTimes = vectorOf(nextTime(), 16) - setOfTimes.size should be <= 16 - - val setOfBigDecimals = seqOf(nextBigDecimal(), 8) - setOfBigDecimals.size should be <= 8 - } - - it should "be able to generate maps with keys and values generated by generators" in { - - val generatedConstantMap = mapOf("key", 123, 10) - generatedConstantMap.size should be <= 1 - assert(generatedConstantMap.keys.forall(_ == "key")) - assert(generatedConstantMap.values.forall(_ == 123)) - - val generatedMap = mapOf(nextString(10), nextBigDecimal(), 10) - assert(generatedMap.keys.forall(_.length <= 10)) - assert(generatedMap.values.forall(_ >= BigDecimal(0.00))) - } - - it should "compose deeply" in { - - val generatedNestedMap = mapOf(nextString(10), nextPair(nextBigDecimal(), nextOption(123)), 10) - - generatedNestedMap.size should be <= 10 - generatedNestedMap.keySet.size should be <= 10 - generatedNestedMap.values.size should be <= 10 - assert(generatedNestedMap.values.forall(value => !value._2.exists(_ != 123))) - } -} diff --git a/src/test/scala/xyz/driver/core/JsonTest.scala b/src/test/scala/xyz/driver/core/JsonTest.scala deleted file mode 100644 index fd693f9..0000000 --- a/src/test/scala/xyz/driver/core/JsonTest.scala +++ /dev/null @@ -1,521 +0,0 @@ -package xyz.driver.core - -import java.net.InetAddress -import java.time.{Instant, LocalDate} - -import akka.http.scaladsl.model.Uri -import akka.http.scaladsl.server.PathMatcher -import akka.http.scaladsl.server.PathMatcher.Matched -import com.neovisionaries.i18n.{CountryCode, CurrencyCode} -import enumeratum._ -import eu.timepit.refined.collection.NonEmpty -import eu.timepit.refined.numeric.Positive -import eu.timepit.refined.refineMV -import org.scalatest.{Inspectors, Matchers, WordSpec} -import spray.json._ -import xyz.driver.core.TestTypes.CustomGADT -import xyz.driver.core.auth.AuthCredentials -import xyz.driver.core.domain.{Email, PhoneNumber} -import xyz.driver.core.json._ -import xyz.driver.core.json.enumeratum.HasJsonFormat -import xyz.driver.core.tagging._ -import xyz.driver.core.time.provider.SystemTimeProvider -import xyz.driver.core.time.{Time, TimeOfDay} - -import scala.collection.immutable.IndexedSeq -import scala.language.postfixOps - -class JsonTest extends WordSpec with Matchers with Inspectors { - import DefaultJsonProtocol._ - - "Json format for Id" should { - "read and write correct JSON" in { - - val referenceId = Id[String]("1312-34A") - - val writtenJson = json.idFormat.write(referenceId) - writtenJson.prettyPrint should be("\"1312-34A\"") - - val parsedId = json.idFormat.read(writtenJson) - parsedId should be(referenceId) - } - } - - "Json format for @@" should { - "read and write correct JSON" in { - trait Irrelevant - val reference = Id[JsonTest]("SomeID").tagged[Irrelevant] - - val format = json.taggedFormat[Id[JsonTest], Irrelevant] - - val writtenJson = format.write(reference) - writtenJson shouldBe JsString("SomeID") - - val parsedId: Id[JsonTest] @@ Irrelevant = format.read(writtenJson) - parsedId shouldBe reference - } - - "read and write correct JSON when there's an implicit conversion defined" in { - val input = " some string " - - JsString(input).convertTo[String @@ Trimmed] shouldBe input.trim() - - val trimmed: String @@ Trimmed = input - trimmed.toJson shouldBe JsString(trimmed) - } - } - - "Json format for Name" should { - "read and write correct JSON" in { - - val referenceName = Name[String]("Homer") - - val writtenJson = json.nameFormat.write(referenceName) - writtenJson.prettyPrint should be("\"Homer\"") - - val parsedName = json.nameFormat.read(writtenJson) - parsedName should be(referenceName) - } - - "read and write correct JSON for Name @@ Trimmed" in { - trait Irrelevant - JsString(" some name ").convertTo[Name[Irrelevant] @@ Trimmed] shouldBe Name[Irrelevant]("some name") - - val trimmed: Name[Irrelevant] @@ Trimmed = Name(" some name ") - trimmed.toJson shouldBe JsString("some name") - } - } - - "Json format for NonEmptyName" should { - "read and write correct JSON" in { - - val jsonFormat = json.nonEmptyNameFormat[String] - - val referenceNonEmptyName = NonEmptyName[String](refineMV[NonEmpty]("Homer")) - - val writtenJson = jsonFormat.write(referenceNonEmptyName) - writtenJson.prettyPrint should be("\"Homer\"") - - val parsedNonEmptyName = jsonFormat.read(writtenJson) - parsedNonEmptyName should be(referenceNonEmptyName) - } - } - - "Json format for Time" should { - "read and write correct JSON" in { - - val referenceTime = new SystemTimeProvider().currentTime() - - val writtenJson = json.timeFormat.write(referenceTime) - writtenJson.prettyPrint should be("{\n \"timestamp\": " + referenceTime.millis + "\n}") - - val parsedTime = json.timeFormat.read(writtenJson) - parsedTime should be(referenceTime) - } - - "read from inputs compatible with Instant" in { - val referenceTime = new SystemTimeProvider().currentTime() - - val jsons = Seq(JsNumber(referenceTime.millis), JsString(Instant.ofEpochMilli(referenceTime.millis).toString)) - - forAll(jsons) { json => - json.convertTo[Time] shouldBe referenceTime - } - } - } - - "Json format for TimeOfDay" should { - "read and write correct JSON" in { - val utcTimeZone = java.util.TimeZone.getTimeZone("UTC") - val referenceTimeOfDay = TimeOfDay.parseTimeString(utcTimeZone)("08:00:00") - val writtenJson = json.timeOfDayFormat.write(referenceTimeOfDay) - writtenJson should be("""{"localTime":"08:00:00","timeZone":"UTC"}""".parseJson) - val parsed = json.timeOfDayFormat.read(writtenJson) - parsed should be(referenceTimeOfDay) - } - } - - "Json format for Date" should { - "read and write correct JSON" in { - import date._ - - val referenceDate = Date(1941, Month.DECEMBER, 7) - - val writtenJson = json.dateFormat.write(referenceDate) - writtenJson.prettyPrint should be("\"1941-12-07\"") - - val parsedDate = json.dateFormat.read(writtenJson) - parsedDate should be(referenceDate) - } - } - - "Json format for java.time.Instant" should { - - val isoString = "2018-08-08T08:08:08.888Z" - val instant = Instant.parse(isoString) - - "read correct JSON when value is an epoch milli number" in { - JsNumber(instant.toEpochMilli).convertTo[Instant] shouldBe instant - } - - "read correct JSON when value is an ISO timestamp string" in { - JsString(isoString).convertTo[Instant] shouldBe instant - } - - "read correct JSON when value is an object with nested 'timestamp'/millis field" in { - val json = JsObject( - "timestamp" -> JsNumber(instant.toEpochMilli) - ) - - json.convertTo[Instant] shouldBe instant - } - - "write correct JSON" in { - instant.toJson shouldBe JsString(isoString) - } - } - - "Path matcher for Instant" should { - - val isoString = "2018-08-08T08:08:08.888Z" - val instant = Instant.parse(isoString) - - val matcher = PathMatcher("foo") / InstantInPath / - - "read instant from millis" in { - matcher(Uri.Path("foo") / ("+" + instant.toEpochMilli) / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) - } - - "read instant from ISO timestamp string" in { - matcher(Uri.Path("foo") / isoString / "bar") shouldBe Matched(Uri.Path("bar"), Tuple1(instant)) - } - } - - "Json format for java.time.LocalDate" should { - - "read and write correct JSON" in { - val dateString = "2018-08-08" - val date = LocalDate.parse(dateString) - - date.toJson shouldBe JsString(dateString) - JsString(dateString).convertTo[LocalDate] shouldBe date - } - } - - "Json format for Revision" should { - "read and write correct JSON" in { - - val referenceRevision = Revision[String]("037e2ec0-8901-44ac-8e53-6d39f6479db4") - - val writtenJson = json.revisionFormat.write(referenceRevision) - writtenJson.prettyPrint should be("\"" + referenceRevision.id + "\"") - - val parsedRevision = json.revisionFormat.read(writtenJson) - parsedRevision should be(referenceRevision) - } - } - - "Json format for Email" should { - "read and write correct JSON" in { - - val referenceEmail = Email("test", "drivergrp.com") - - val writtenJson = json.emailFormat.write(referenceEmail) - writtenJson should be("\"test@drivergrp.com\"".parseJson) - - val parsedEmail = json.emailFormat.read(writtenJson) - parsedEmail should be(referenceEmail) - } - } - - "Json format for PhoneNumber" should { - "read and write correct JSON" in { - - val referencePhoneNumber = PhoneNumber("1", "4243039608") - - val writtenJson = json.phoneNumberFormat.write(referencePhoneNumber) - writtenJson should be("""{"countryCode":"1","number":"4243039608"}""".parseJson) - - val parsedPhoneNumber = json.phoneNumberFormat.read(writtenJson) - parsedPhoneNumber should be(referencePhoneNumber) - } - - "reject an invalid phone number" in { - val phoneJson = """{"countryCode":"1","number":"111-111-1113"}""".parseJson - - intercept[DeserializationException] { - json.phoneNumberFormat.read(phoneJson) - }.getMessage shouldBe "Invalid phone number" - } - - "parse phone number from string" in { - JsString("+14243039608").convertTo[PhoneNumber] shouldBe PhoneNumber("1", "4243039608") - } - } - - "Path matcher for PhoneNumber" should { - "read valid phone number" in { - val string = "+14243039608x23" - val phone = PhoneNumber("1", "4243039608", Some("23")) - - val matcher = PathMatcher("foo") / PhoneInPath - - matcher(Uri.Path("foo") / string / "bar") shouldBe Matched(Uri.Path./("bar"), Tuple1(phone)) - } - } - - "Json format for ADT mappings" should { - "read and write correct JSON" in { - - sealed trait EnumVal - case object Val1 extends EnumVal - case object Val2 extends EnumVal - case object Val3 extends EnumVal - - val format = new EnumJsonFormat[EnumVal]("a" -> Val1, "b" -> Val2, "c" -> Val3) - - val referenceEnumValue1 = Val2 - val referenceEnumValue2 = Val3 - - val writtenJson1 = format.write(referenceEnumValue1) - writtenJson1.prettyPrint should be("\"b\"") - - val writtenJson2 = format.write(referenceEnumValue2) - writtenJson2.prettyPrint should be("\"c\"") - - val parsedEnumValue1 = format.read(writtenJson1) - val parsedEnumValue2 = format.read(writtenJson2) - - parsedEnumValue1 should be(referenceEnumValue1) - parsedEnumValue2 should be(referenceEnumValue2) - } - } - - "Json format for Enums (external)" should { - "read and write correct JSON" in { - - sealed trait MyEnum extends EnumEntry - object MyEnum extends Enum[MyEnum] { - case object Val1 extends MyEnum - case object `Val 2` extends MyEnum - case object `Val/3` extends MyEnum - - val values: IndexedSeq[MyEnum] = findValues - } - - val format = new enumeratum.EnumJsonFormat(MyEnum) - - val referenceEnumValue1 = MyEnum.`Val 2` - val referenceEnumValue2 = MyEnum.`Val/3` - - val writtenJson1 = format.write(referenceEnumValue1) - writtenJson1 shouldBe JsString("Val 2") - - val writtenJson2 = format.write(referenceEnumValue2) - writtenJson2 shouldBe JsString("Val/3") - - val parsedEnumValue1 = format.read(writtenJson1) - val parsedEnumValue2 = format.read(writtenJson2) - - parsedEnumValue1 shouldBe referenceEnumValue1 - parsedEnumValue2 shouldBe referenceEnumValue2 - - intercept[DeserializationException] { - format.read(JsString("Val4")) - }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" - } - } - - "Json format for Enums (automatic)" should { - "read and write correct JSON and not require import" in { - - sealed trait MyEnum extends EnumEntry - object MyEnum extends Enum[MyEnum] with HasJsonFormat[MyEnum] { - case object Val1 extends MyEnum - case object `Val 2` extends MyEnum - case object `Val/3` extends MyEnum - - val values: IndexedSeq[MyEnum] = findValues - } - - val referenceEnumValue1: MyEnum = MyEnum.`Val 2` - val referenceEnumValue2: MyEnum = MyEnum.`Val/3` - - val writtenJson1 = referenceEnumValue1.toJson - writtenJson1 shouldBe JsString("Val 2") - - val writtenJson2 = referenceEnumValue2.toJson - writtenJson2 shouldBe JsString("Val/3") - - import spray.json._ - - val parsedEnumValue1 = writtenJson1.prettyPrint.parseJson.convertTo[MyEnum] - val parsedEnumValue2 = writtenJson2.prettyPrint.parseJson.convertTo[MyEnum] - - parsedEnumValue1 should be(referenceEnumValue1) - parsedEnumValue2 should be(referenceEnumValue2) - - intercept[DeserializationException] { - JsString("Val4").convertTo[MyEnum] - }.getMessage shouldBe "Unexpected value Val4. Expected one of: [Val1, Val 2, Val/3]" - } - } - - // Should be defined outside of case to have a TypeTag - case class CustomWrapperClass(value: Int) - - "Json format for Value classes" should { - "read and write correct JSON" in { - - val format = new ValueClassFormat[CustomWrapperClass](v => BigDecimal(v.value), d => CustomWrapperClass(d.toInt)) - - val referenceValue1 = CustomWrapperClass(-2) - val referenceValue2 = CustomWrapperClass(10) - - val writtenJson1 = format.write(referenceValue1) - writtenJson1.prettyPrint should be("-2") - - val writtenJson2 = format.write(referenceValue2) - writtenJson2.prettyPrint should be("10") - - val parsedValue1 = format.read(writtenJson1) - val parsedValue2 = format.read(writtenJson2) - - parsedValue1 should be(referenceValue1) - parsedValue2 should be(referenceValue2) - } - } - - "Json format for classes GADT" should { - "read and write correct JSON" in { - - import CustomGADT._ - import DefaultJsonProtocol._ - implicit val case1Format = jsonFormat1(GadtCase1) - implicit val case2Format = jsonFormat1(GadtCase2) - implicit val case3Format = jsonFormat1(GadtCase3) - - val format = GadtJsonFormat.create[CustomGADT]("gadtTypeField") { - case _: CustomGADT.GadtCase1 => "case1" - case _: CustomGADT.GadtCase2 => "case2" - case _: CustomGADT.GadtCase3 => "case3" - } { - case "case1" => case1Format - case "case2" => case2Format - case "case3" => case3Format - } - - val referenceValue1 = CustomGADT.GadtCase1("4") - val referenceValue2 = CustomGADT.GadtCase2("Hi!") - - val writtenJson1 = format.write(referenceValue1) - writtenJson1 should be("{\n \"field\": \"4\",\n\"gadtTypeField\": \"case1\"\n}".parseJson) - - val writtenJson2 = format.write(referenceValue2) - writtenJson2 should be("{\"field\":\"Hi!\",\"gadtTypeField\":\"case2\"}".parseJson) - - val parsedValue1 = format.read(writtenJson1) - val parsedValue2 = format.read(writtenJson2) - - parsedValue1 should be(referenceValue1) - parsedValue2 should be(referenceValue2) - } - } - - "Json format for a Refined value" should { - "read and write correct JSON" in { - - val jsonFormat = json.refinedJsonFormat[Int, Positive] - - val referenceRefinedNumber = refineMV[Positive](42) - - val writtenJson = jsonFormat.write(referenceRefinedNumber) - writtenJson should be("42".parseJson) - - val parsedRefinedNumber = jsonFormat.read(writtenJson) - parsedRefinedNumber should be(referenceRefinedNumber) - } - } - - "InetAddress format" should { - "read and write correct JSON" in { - val address = InetAddress.getByName("127.0.0.1") - val json = inetAddressFormat.write(address) - - json shouldBe JsString("127.0.0.1") - - val parsed = inetAddressFormat.read(json) - parsed shouldBe address - } - - "throw a DeserializationException for an invalid IP Address" in { - assertThrows[DeserializationException] { - val invalidAddress = JsString("foobar:") - inetAddressFormat.read(invalidAddress) - } - } - } - - "AuthCredentials format" should { - "read and write correct JSON" in { - val email = Email("someone", "noehere.com") - val phoneId = PhoneNumber.parse("1 207 8675309") - val password = "nopassword" - - phoneId.isDefined should be(true) // test this real quick - - val emailAuth = AuthCredentials(email.toString, password) - val pnAuth = AuthCredentials(phoneId.get.toString, password) - - val emailWritten = authCredentialsFormat.write(emailAuth) - emailWritten should be("""{"identifier":"someone@noehere.com","password":"nopassword"}""".parseJson) - - val phoneWritten = authCredentialsFormat.write(pnAuth) - phoneWritten should be("""{"identifier":"+1 2078675309","password":"nopassword"}""".parseJson) - - val identifierEmailParsed = - authCredentialsFormat.read("""{"identifier":"someone@nowhere.com","password":"nopassword"}""".parseJson) - var written = authCredentialsFormat.write(identifierEmailParsed) - written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) - - val emailEmailParsed = - authCredentialsFormat.read("""{"email":"someone@nowhere.com","password":"nopassword"}""".parseJson) - written = authCredentialsFormat.write(emailEmailParsed) - written should be("{\"identifier\":\"someone@nowhere.com\",\"password\":\"nopassword\"}".parseJson) - - } - } - - "CountryCode format" should { - "read and write correct JSON" in { - val samples = Seq( - "US" -> CountryCode.US, - "CN" -> CountryCode.CN, - "AT" -> CountryCode.AT - ) - - forAll(samples) { - case (serialized, enumValue) => - countryCodeFormat.write(enumValue) shouldBe JsString(serialized) - countryCodeFormat.read(JsString(serialized)) shouldBe enumValue - } - } - } - - "CurrencyCode format" should { - "read and write correct JSON" in { - val samples = Seq( - "USD" -> CurrencyCode.USD, - "CNY" -> CurrencyCode.CNY, - "EUR" -> CurrencyCode.EUR - ) - - forAll(samples) { - case (serialized, enumValue) => - currencyCodeFormat.write(enumValue) shouldBe JsString(serialized) - currencyCodeFormat.read(JsString(serialized)) shouldBe enumValue - } - } - } - -} diff --git a/src/test/scala/xyz/driver/core/TestTypes.scala b/src/test/scala/xyz/driver/core/TestTypes.scala deleted file mode 100644 index bb25deb..0000000 --- a/src/test/scala/xyz/driver/core/TestTypes.scala +++ /dev/null @@ -1,14 +0,0 @@ -package xyz.driver.core - -object TestTypes { - - sealed trait CustomGADT { - val field: String - } - - object CustomGADT { - final case class GadtCase1(field: String) extends CustomGADT - final case class GadtCase2(field: String) extends CustomGADT - final case class GadtCase3(field: String) extends CustomGADT - } -} diff --git a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala b/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala deleted file mode 100644 index 324c8d8..0000000 --- a/src/test/scala/xyz/driver/core/rest/DriverAppTest.scala +++ /dev/null @@ -1,89 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{HttpMethod, StatusCodes} -import akka.http.scaladsl.server.{Directives, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import com.typesafe.config.ConfigFactory -import org.scalatest.{AsyncFlatSpec, Matchers} -import xyz.driver.core.app.{DriverApp, SimpleModule} - -class DriverAppTest extends AsyncFlatSpec with ScalatestRouteTest with Matchers with Directives { - val config = ConfigFactory.parseString(""" - |application { - | cors { - | allowedOrigins: ["example.com"] - | } - |} - """.stripMargin).withFallback(ConfigFactory.load) - - val origin = Origin(HttpOrigin("https", Host("example.com"))) - val allowedOrigins = Set(HttpOrigin("https", Host("example.com"))) - val allowedMethods: collection.immutable.Seq[HttpMethod] = { - import akka.http.scaladsl.model.HttpMethods._ - collection.immutable.Seq(GET, PUT, POST, PATCH, DELETE, OPTIONS, TRACE) - } - - import scala.reflect.runtime.universe.typeOf - class TestApp(testRoute: Route) - extends DriverApp( - appName = "test-app", - version = "0.0.1", - gitHash = "deadb33f", - modules = Seq(new SimpleModule("test-module", theRoute = testRoute, routeType = typeOf[DriverApp])), - config = config, - log = xyz.driver.core.logging.NoLogger - ) - - it should "respond with the correct CORS headers for the swagger OPTIONS route" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Options(s"/api-docs/swagger.json").withHeaders(origin) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - header[`Access-Control-Allow-Methods`].get.methods should contain theSameElementsAs allowedMethods - } - } - - it should "respond with the correct CORS headers for the test route" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - } - } - - it should "respond with the correct CORS headers for a concatenated route" in { - val route = new TestApp(get(complete(StatusCodes.OK)) ~ post(complete(StatusCodes.OK))) - Post(s"/api/v1/test").withHeaders(origin) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange(allowedOrigins.toSeq: _*))) - } - } - - it should "allow subdomains of allowed origin suffixes" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") - .withHeaders(Origin(HttpOrigin("https", Host("foo.example.com")))) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOrigin("https", Host("foo.example.com")))) - } - } - - it should "respond with default domains for invalid origins" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") - .withHeaders(Origin(HttpOrigin("https", Host("invalid.foo.bar.com")))) ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - headers should contain(`Access-Control-Allow-Origin`(HttpOriginRange.*)) - } - } - - it should "respond with Pragma and Cache-Control (no-cache) headers" in { - val route = new TestApp(get(complete(StatusCodes.OK))) - Get(s"/api/v1/test") ~> route.appRoute ~> check { - status shouldBe StatusCodes.OK - header("Pragma").map(_.value()) should contain("no-cache") - header[`Cache-Control`].map(_.value()) should contain("no-cache") - } - } -} diff --git a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala b/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala deleted file mode 100644 index cc0019a..0000000 --- a/src/test/scala/xyz/driver/core/rest/DriverRouteTest.scala +++ /dev/null @@ -1,121 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model.StatusCodes -import akka.http.scaladsl.model.headers.Connection -import akka.http.scaladsl.server.Directives.{complete => akkaComplete} -import akka.http.scaladsl.server.{Directives, RejectionHandler, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import com.typesafe.scalalogging.Logger -import org.scalatest.{AsyncFlatSpec, Matchers} -import xyz.driver.core.json.serviceExceptionFormat -import xyz.driver.core.logging.NoLogger -import xyz.driver.core.rest.errors._ - -import scala.concurrent.Future - -class DriverRouteTest - extends AsyncFlatSpec with ScalatestRouteTest with SprayJsonSupport with Matchers with Directives { - class TestRoute(override val route: Route) extends DriverRoute { - override def log: Logger = NoLogger - } - - "DriverRoute" should "respond with 200 OK for a basic route" in { - val route = new TestRoute(akkaComplete(StatusCodes.OK)) - - Get("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.OK - } - } - - it should "respond with a 401 for an InvalidInputException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](InvalidInputException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - responseAs[ServiceException] shouldBe InvalidInputException() - } - } - - it should "respond with a 403 for InvalidActionException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](InvalidActionException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.Forbidden - responseAs[ServiceException] shouldBe InvalidActionException() - } - } - - it should "respond with a 404 for ResourceNotFoundException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](ResourceNotFoundException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.NotFound - responseAs[ServiceException] shouldBe ResourceNotFoundException() - } - } - - it should "respond with a 500 for ExternalServiceException" in { - val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", None) - val route = new TestRoute(akkaComplete(Future.failed[String](error))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.InternalServerError - responseAs[ServiceException] shouldBe error - } - } - - it should "allow pass-through of external service exceptions" in { - val innerError = InvalidInputException() - val error = ExternalServiceException("GET /api/v1/users/", "Permission denied", Some(innerError)) - val future = Future.failed[String](error) - val route = new TestRoute(akkaComplete(future.passThroughExternalServiceException)) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - responseAs[ServiceException] shouldBe innerError - } - } - - it should "respond with a 503 for ExternalServiceTimeoutException" in { - val error = ExternalServiceTimeoutException("GET /api/v1/users/") - val route = new TestRoute(akkaComplete(Future.failed[String](error))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.GatewayTimeout - responseAs[ServiceException] shouldBe error - } - } - - it should "respond with a 500 for DatabaseException" in { - val route = new TestRoute(akkaComplete(Future.failed[String](DatabaseException()))) - - Post("/api/v1/foo/bar") ~> route.routeWithDefaults ~> check { - handled shouldBe true - status shouldBe StatusCodes.InternalServerError - responseAs[ServiceException] shouldBe DatabaseException() - } - } - - it should "add a `Connection: close` header to avoid clashing with envoy's timeouts" in { - val rejectionHandler = RejectionHandler.newBuilder().handleNotFound(complete(StatusCodes.NotFound)).result() - val route = new TestRoute(handleRejections(rejectionHandler)((get & path("foo"))(complete("OK")))) - - Get("/foo") ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.OK - headers should contain(Connection("close")) - } - - Get("/bar") ~> route.routeWithDefaults ~> check { - status shouldBe StatusCodes.NotFound - headers should contain(Connection("close")) - } - } -} diff --git a/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala b/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala deleted file mode 100644 index 987717d..0000000 --- a/src/test/scala/xyz/driver/core/rest/PatchDirectivesTest.scala +++ /dev/null @@ -1,101 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.`Content-Type` -import akka.http.scaladsl.server.{Directives, Route} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import org.scalatest.{FlatSpec, Matchers} -import spray.json._ -import xyz.driver.core.{Id, Name} -import xyz.driver.core.json._ - -import scala.concurrent.Future - -class PatchDirectivesTest - extends FlatSpec with Matchers with ScalatestRouteTest with SprayJsonSupport with DefaultJsonProtocol - with Directives with PatchDirectives { - case class Bar(name: Name[Bar], size: Int) - case class Foo(id: Id[Foo], name: Name[Foo], rank: Int, bar: Option[Bar]) - implicit val barFormat: RootJsonFormat[Bar] = jsonFormat2(Bar) - implicit val fooFormat: RootJsonFormat[Foo] = jsonFormat4(Foo) - - val testFoo: Foo = Foo(Id("1"), Name(s"Foo"), 1, Some(Bar(Name("Bar"), 10))) - - def route(retrieve: => Future[Option[Foo]]): Route = - Route.seal(path("api" / "v1" / "foos" / IdInPath[Foo]) { fooId => - entity(as[Patchable[Foo]]) { fooPatchable => - mergePatch(fooPatchable, retrieve) { updatedFoo => - complete(updatedFoo) - } - } - }) - - val MergePatchContentType = ContentType(`application/merge-patch+json`) - val ContentTypeHeader = `Content-Type`(MergePatchContentType) - def jsonEntity(json: String, contentType: ContentType.NonBinary = MergePatchContentType): RequestEntity = - HttpEntity(contentType, json) - - "PatchSupport" should "allow partial updates to an existing object" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4) - } - } - - it should "merge deeply nested objects" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4, "bar": {"name": "My Bar"}}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(rank = 4, bar = Some(Bar(Name("My Bar"), 10))) - } - } - - it should "return a 404 if the object is not found" in { - val fooRetrieve = Future.successful(None) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - status shouldBe StatusCodes.NotFound - } - } - - it should "handle nulls on optional values correctly" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"bar": null}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(bar = None) - } - } - - it should "handle optional values correctly when old value is null" in { - val fooRetrieve = Future.successful(Some(testFoo.copy(bar = None))) - - Patch("/api/v1/foos/1", jsonEntity("""{"bar": {"name": "My Bar","size":10}}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - responseAs[Foo] shouldBe testFoo.copy(bar = Some(Bar(Name("My Bar"), 10))) - } - } - - it should "return a 400 for nulls on non-optional values" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": null}""")) ~> route(fooRetrieve) ~> check { - handled shouldBe true - status shouldBe StatusCodes.BadRequest - } - } - - it should "return a 415 for incorrect Content-Type" in { - val fooRetrieve = Future.successful(Some(testFoo)) - - Patch("/api/v1/foos/1", jsonEntity("""{"rank": 4}""", ContentTypes.`application/json`)) ~> route(fooRetrieve) ~> check { - status shouldBe StatusCodes.UnsupportedMediaType - responseAs[String] should include("application/merge-patch+json") - } - } -} diff --git a/src/test/scala/xyz/driver/core/rest/RestTest.scala b/src/test/scala/xyz/driver/core/rest/RestTest.scala deleted file mode 100644 index 19e4ed1..0000000 --- a/src/test/scala/xyz/driver/core/rest/RestTest.scala +++ /dev/null @@ -1,151 +0,0 @@ -package xyz.driver.core.rest - -import akka.http.scaladsl.model.StatusCodes -import akka.http.scaladsl.server.{Directives, Route, ValidationRejection} -import akka.http.scaladsl.testkit.ScalatestRouteTest -import akka.util.ByteString -import org.scalatest.{Matchers, WordSpec} -import xyz.driver.core.rest - -import scala.concurrent.Future -import scala.util.Random - -class RestTest extends WordSpec with Matchers with ScalatestRouteTest with Directives { - "`escapeScriptTags` function" should { - "escape script tags properly" in { - val dirtyString = "</sc----</sc----</sc" - val cleanString = "--------------------" - - (escapeScriptTags(ByteString(dirtyString)).utf8String) should be(dirtyString.replace("</sc", "< /sc")) - - (escapeScriptTags(ByteString(cleanString)).utf8String) should be(cleanString) - } - } - - "paginated directive" should { - val route: Route = rest.paginated { paginated => - complete(StatusCodes.OK -> s"${paginated.pageNumber},${paginated.pageSize}") - } - "accept a pagination" in { - Get("/?pageNumber=2&pageSize=42") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "2,42") - } - } - "provide a default pagination" in { - Get("/") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "1,100") - } - } - "provide default values for a partial pagination" in { - Get("/?pageSize=2") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "1,2") - } - } - "reject an invalid pagination" in { - Get("/?pageNumber=-1") ~> route ~> check { - assert(rejection.isInstanceOf[ValidationRejection]) - } - } - } - - "optional paginated directive" should { - val route: Route = rest.optionalPagination { paginated => - complete(StatusCodes.OK -> paginated.map(p => s"${p.pageNumber},${p.pageSize}").getOrElse("no pagination")) - } - "accept a pagination" in { - Get("/?pageNumber=2&pageSize=42") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "2,42") - } - } - "without pagination" in { - Get("/") ~> route ~> check { - assert(status == StatusCodes.OK) - assert(entityAs[String] == "no pagination") - } - } - "reject an invalid pagination" in { - Get("/?pageNumber=1") ~> route ~> check { - assert(rejection.isInstanceOf[ValidationRejection]) - } - } - } - - "completeWithPagination directive" when { - import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ - import spray.json.DefaultJsonProtocol._ - - val data = Seq.fill(103)(Random.alphanumeric.take(10).mkString) - val route: Route = - parameter('empty.as[Boolean] ? false) { isEmpty => - completeWithPagination[String] { - case Some(pagination) if isEmpty => - Future.successful(ListResponse(Seq(), 0, Some(pagination))) - case Some(pagination) => - val filtered = data.slice(pagination.offset, pagination.offset + pagination.pageSize) - Future.successful(ListResponse(filtered, data.size, Some(pagination))) - case None if isEmpty => Future.successful(ListResponse(Seq(), 0, None)) - case None => Future.successful(ListResponse(data, data.size, None)) - } - } - - "pagination is defined" should { - "return a response with pagination headers" in { - Get("/?pageNumber=2&pageSize=10") ~> route ~> check { - responseAs[Seq[String]] shouldBe data.slice(10, 20) - header(ContextHeaders.ResourceCount).map(_.value) should contain("103") - header(ContextHeaders.PageCount).map(_.value) should contain("11") - } - } - - "disallow pageSize <= 0" in { - Get("/?pageNumber=2&pageSize=0") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - - Get("/?pageNumber=2&pageSize=-1") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - } - - "disallow pageNumber <= 0" in { - Get("/?pageNumber=0&pageSize=10") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - - Get("/?pageNumber=-1&pageSize=10") ~> route ~> check { - rejection shouldBe a[ValidationRejection] - } - } - - "return PageCount == 0 if returning an empty list" in { - Get("/?empty=true&pageNumber=2&pageSize=10") ~> route ~> check { - responseAs[Seq[String]] shouldBe empty - header(ContextHeaders.ResourceCount).map(_.value) should contain("0") - header(ContextHeaders.PageCount).map(_.value) should contain("0") - } - } - } - - "pagination is not defined" should { - "return a response with pagination headers and PageCount == 1" in { - Get("/") ~> route ~> check { - responseAs[Seq[String]] shouldBe data - header(ContextHeaders.ResourceCount).map(_.value) should contain("103") - header(ContextHeaders.PageCount).map(_.value) should contain("1") - } - } - - "return PageCount == 0 if returning an empty list" in { - Get("/?empty=true") ~> route ~> check { - responseAs[Seq[String]] shouldBe empty - header(ContextHeaders.ResourceCount).map(_.value) should contain("0") - header(ContextHeaders.PageCount).map(_.value) should contain("0") - } - } - } - } -} |