package xyz.driver.core.rest

import java.net.InetAddress

import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable}
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server._
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.Materializer
import akka.stream.scaladsl.Flow
import akka.util.ByteString
import scalaz.Scalaz.{intInstance, stringInstance}
import scalaz.syntax.equal._
import scalaz.{Functor, OptionT}
import xyz.driver.core.rest.auth.AuthProvider
import xyz.driver.tracing.TracingDirectives

import scala.concurrent.Future
import scala.util.Try

/** Marker trait; declares no members. */
trait Service

object Service

/** Abstraction over something able to turn an [[HttpRequest]] into a future [[HttpResponse]]. */
trait HttpClient {
  def makeRequest(request: HttpRequest): Future[HttpResponse]
}

/** Sends request stubs on behalf of a [[ServiceRequestContext]].
  *
  * How the context is applied to the stub is up to the implementation; it is not visible here.
  */
trait ServiceTransport {

  /** Sends the request and yields the raw response. */
  def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse]

  /** Sends the request and yields an [[Unmarshal]] handle over the response entity. */
  def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)(
      implicit mat: Materializer): Future[Unmarshal[ResponseEntity]]
}

/** Direction of a single sorting criterion. */
sealed trait SortingOrder

object SortingOrder {
  case object Asc  extends SortingOrder
  case object Desc extends SortingOrder
}

/** One sorting criterion: a field name plus the direction to sort it in. */
final case class SortingField(name: String, sortingOrder: SortingOrder)

/** Ordered list of sorting criteria. */
final case class Sorting(sortingFields: Seq[SortingField])

/** Page request with a 1-based page number.
  *
  * Construction fails with an `IllegalArgumentException` (via `require`) when either value is not
  * strictly positive.
  */
final case class Pagination(pageSize: Int, pageNumber: Int) {
  require(pageSize > 0, "Page size must be greater than zero")
  require(pageNumber > 0, "Page number must be greater than zero")

  /** Number of items preceding the requested page. */
  def offset: Int = pageSize * (pageNumber - 1)
}

/** A page of items together with the metadata describing it. */
final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta)

object ListResponse {

  /** Builds a response from an optional pagination.
    *
    * Without pagination the metadata reports page number 1 and a page size equal to `size`.
    */
  def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] =
    ListResponse(
      items = items,
      meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize)))

  final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int)

  object Meta {
    def apply(itemsCount: Int, pagination: Pagination): Meta =
      Meta(itemsCount, pagination.pageNumber, pagination.pageSize)
  }
}

object `package` {

  implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) {

    /** Completes with `successCode` and the value when present, or with `404 Not Found` otherwise. */
    def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)(
        implicit F: Functor[Future],
        em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = {
      optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None)
    }
  }

  /** Names of the headers this package reads from, or writes to, HTTP messages. */
  object ContextHeaders {
    val AuthenticationTokenHeader: String  = "Authorization"
    val PermissionsTokenHeader: String     = "Permissions"
    val AuthenticationHeaderPrefix: String = "Bearer"
    val ClientFingerprintHeader: String    = "X-Client-Fingerprint"
    val TrackingIdHeader: String           = "X-Trace"
    val StacktraceHeader: String           = "X-Stacktrace"
    val OriginatingIpHeader: String        = "X-Forwarded-For"
    val ResourceCount: String              = "X-Resource-Count"
    val PageCount: String                  = "X-Page-Count"
    val TraceHeaderName: String            = TracingDirectives.TraceHeaderName
    val SpanHeaderName: String             = TracingDirectives.SpanHeaderName
  }

  /** Header names advertised in the CORS `Access-Control-Allow-Headers` and
    * `Access-Control-Expose-Headers` response headers.
    *
    * Several entries can resolve to the same name (e.g. the literal "X-Trace" and
    * `ContextHeaders.TrackingIdHeader`), so the list is deduplicated, keeping first occurrences.
    */
  val AllowedHeaders: Seq[String] =
    Seq(
      "Origin",
      "X-Requested-With",
      "Content-Type",
      "Content-Length",
      "Accept",
      "X-Trace",
      "Access-Control-Allow-Methods",
      "Access-Control-Allow-Origin",
      "Access-Control-Allow-Headers",
      "Server",
      "Date",
      ContextHeaders.ClientFingerprintHeader,
      ContextHeaders.TrackingIdHeader,
      ContextHeaders.TraceHeaderName,
      ContextHeaders.SpanHeaderName,
      ContextHeaders.StacktraceHeader,
      ContextHeaders.AuthenticationTokenHeader,
      ContextHeaders.OriginatingIpHeader,
      ContextHeaders.ResourceCount,
      ContextHeaders.PageCount,
      "X-Frame-Options",
      "X-Content-Type-Options",
      "Strict-Transport-Security",
      AuthProvider.SetAuthenticationTokenHeader,
      AuthProvider.SetPermissionsTokenHeader,
      "Traceparent"
    ).distinct

  /** Tells whether `header` carries the given name.
    *
    * HTTP field names are case-insensitive, and headers that akka-http keeps as raw headers retain
    * whatever casing the client sent, so the comparison is done on lower-cased names.
    */
  private def headerHasName(header: HttpHeader, name: String): Boolean =
    header.lowercaseName === name.toLowerCase(java.util.Locale.ROOT)

  /** Builds `Access-Control-Allow-Origin` from the request's `Origin`, or `*` when there is none. */
  def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` =
    `Access-Control-Allow-Origin`(
      originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*)))

  /** Extracts a [[ServiceRequestContext]] from the current request.
    *
    * Before extraction, an `Authorization` header whose value does not start with the "Bearer"
    * prefix is replaced by an OAuth2 bearer `Authorization` header wrapping that value.
    */
  def serviceContext: Directive1[ServiceRequestContext] = {
    def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = {
      headers.map({ header =>
        if (headerHasName(header, ContextHeaders.AuthenticationTokenHeader) &&
            !header.value.startsWith(ContextHeaders.AuthenticationHeaderPrefix)) {
          Authorization(OAuth2BearerToken(header.value))
        } else header
      })(collection.breakOut)
    }

    extractClientIP flatMap { remoteAddress =>
      mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ =>
        extract(ctx => extractServiceContext(ctx.request, remoteAddress))
      }
    }
  }

  /** Adds `Access-Control-Allow-Headers` and `Access-Control-Expose-Headers` for [[AllowedHeaders]]. */
  def respondWithCorsAllowedHeaders: Directive0 = {
    respondWithHeaders(
      List[HttpHeader](
        `Access-Control-Allow-Headers`(AllowedHeaders: _*),
        `Access-Control-Expose-Headers`(AllowedHeaders: _*)
      ))
  }

  /** Adds `Access-Control-Allow-Origin` listing exactly the origins of the given header. */
  def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = {
    respondWithHeader {
      `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*))
    }
  }

  /** Adds `Allow` and `Access-Control-Allow-Methods` for the given methods. */
  def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = {
    respondWithHeaders(
      List[HttpHeader](
        Allow(methods.to[collection.immutable.Seq]),
        `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq])
      ))
  }

  /** Assembles the request context from the tracking id, originating IP and context headers. */
  def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext =
    new ServiceRequestContext(
      extractTrackingId(request),
      extractOriginatingIP(request, remoteAddress),
      extractContextHeaders(request))

  /** Returns the value of the tracking id header, or a freshly generated random UUID when absent. */
  def extractTrackingId(request: HttpRequest): String = {
    request.headers
      .find(headerHasName(_, ContextHeaders.TrackingIdHeader))
      .fold(java.util.UUID.randomUUID.toString)(_.value)
  }

  /** Returns the value of the client fingerprint header, if present. */
  def extractFingerprintHash(request: HttpRequest): Option[String] = {
    request.headers
      .find(headerHasName(_, ContextHeaders.ClientFingerprintHeader))
      .map(_.value)
  }

  /** Resolves the caller's address from the originating-IP header, falling back to `remoteAddress`.
    *
    * The fallback is used when the header is absent or when its value cannot be resolved by
    * `InetAddress.getByName`.
    */
  def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = {
    // NOTE(review): the whole header value is handed to getByName; a value holding several
    // comma-separated addresses presumably fails to resolve and ends up in the fallback - confirm.
    request.headers
      .find(headerHasName(_, ContextHeaders.OriginatingIpHeader))
      .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption)
      .orElse(remoteAddress.toOption)
  }

  /** Splits the stacktrace header on "->".
    *
    * @return the header's segments, or an empty array when the header is absent or empty
    */
  def extractStacktrace(request: HttpRequest): Array[String] =
    request.headers
      .find(headerHasName(_, ContextHeaders.StacktraceHeader))
      .map(_.value)
      // "".split("->") would yield Array(""), i.e. one bogus empty segment.
      .filter(_.nonEmpty)
      .fold(Array.empty[String])(_.split("->"))

  /** Canonical names of the headers copied into the request context by [[extractContextHeaders]]. */
  private val contextHeaderNames: Seq[String] = Seq(
    ContextHeaders.AuthenticationTokenHeader,
    ContextHeaders.TrackingIdHeader,
    ContextHeaders.PermissionsTokenHeader,
    ContextHeaders.StacktraceHeader,
    ContextHeaders.TraceHeaderName,
    ContextHeaders.SpanHeaderName,
    ContextHeaders.OriginatingIpHeader,
    ContextHeaders.ClientFingerprintHeader
  )

  /** Collects the context-relevant headers of the request into a map.
    *
    * Keys are the canonical names declared in [[ContextHeaders]], regardless of the casing the
    * client used. The authentication header's value has the "Bearer" prefix stripped and is
    * trimmed; all other values are copied as they are. When several headers map to the same key,
    * the last one in the request wins.
    */
  def extractContextHeaders(request: HttpRequest): Map[String, String] = {
    request.headers.flatMap { header =>
      contextHeaderNames
        .find(headerHasName(header, _))
        .map { canonicalName =>
          if (canonicalName === ContextHeaders.AuthenticationTokenHeader) {
            canonicalName -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim
          } else {
            canonicalName -> header.value
          }
        }
        .toList
    }.toMap
  }

  /** Inserts a space between `<` and `/` wherever the bytes contain `<` directly followed by `/sc`,
    * so that e.g. `</script` becomes `< /script`.
    *
    * The input is returned untouched when no such position exists.
    */
  private[rest] def escapeScriptTags(byteString: ByteString): ByteString = {
    // Collects, in ascending order, the index of every '/' that is preceded by "<" and starts "/sc".
    @annotation.tailrec
    def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = {
      val index = byteString.indexOf('/', from)
      if (index === -1) descIndices.reverse
      else {
        val (init, tail) = byteString.splitAt(index)
        if ((init endsWith "<") && (tail startsWith "/sc")) {
          dirtyIndices(index + 1, index :: descIndices)
        } else {
          dirtyIndices(index + 1, descIndices)
        }
      }
    }

    val indices = dirtyIndices(0, Nil)

    indices.headOption.fold(byteString) { head =>
      val builder = ByteString.newBuilder
      builder ++= byteString.take(head)

      // Each window is [start of a dirty '/', start of the next one or end of input): emit a space,
      // then the window's bytes.
      (indices :+ byteString.length).sliding(2).foreach {
        case Seq(start, end) =>
          builder += ' '
          builder ++= byteString.slice(start, end)
        case Seq(_) => // Should not match; sliding on at least 2 elements
          assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices")
      }
      builder.result
    }
  }

  /** Runs every chunk of the request entity through `escapeScriptTags`.
    *
    * Chunks are transformed one by one, so a pattern split across two chunks is not rewritten.
    */
  val sanitizeRequestEntity: Directive0 = {
    mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags))))
  }

  /** Reads `pageSize` (default 100) and `pageNumber` (default 1) query parameters. */
  val paginated: Directive1[Pagination] =
    parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination)

  /** Combines the two optional parameters: both present, both absent, or an error.
    *
    * @throws IllegalArgumentException when only one of the two parameters is given
    */
  private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] =
    (pageSizeOpt, pageNumberOpt) match {
      case (Some(size), Some(number)) => Option(Pagination(size, number))
      case (None, None)               => Option.empty[Pagination]
      case (_, _)                     => throw new IllegalArgumentException("Pagination's parameters are incorrect")
    }

  /** Reads optional `pageSize` and `pageNumber` query parameters; see `extractPagination`. */
  val optionalPagination: Directive1[Option[Pagination]] =
    parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination)

  /** Renders a pagination as `pageNumber` / `pageSize` query parameters. */
  def paginationQuery(pagination: Pagination): Seq[(String, String)] =
    Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString)

  /** Completes with the items produced by `handler`, adding resource-count and page-count headers.
    *
    * The page count is the item count divided by the page size, rounded up; it is 0 when the page
    * size is 0.
    */
  def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])(
      implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = {
    optionalPagination { pagination =>
      onSuccess(handler(pagination)) {
        case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) =>
          val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1)
          val headers = List(
            RawHeader(ContextHeaders.ResourceCount, count.toString),
            RawHeader(ContextHeaders.PageCount, pageCount.toString))

          respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart)))
      }
    }
  }

  /** Parses a comma-separated sort specification.
    *
    * Empty segments are skipped. A leading `-` means descending; a leading `+`, or no prefix, means
    * ascending. The prefix is not part of the field name.
    */
  private def extractSorting(sortingString: Option[String]): Sorting = {
    val sortingFields = sortingString.fold(Seq.empty[SortingField])(
      _.split(",")
        .filter(_.length > 0)
        .map { sortingParam =>
          if (sortingParam.startsWith("-")) {
            SortingField(sortingParam.substring(1), SortingOrder.Desc)
          } else {
            val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam
            SortingField(fieldName, SortingOrder.Asc)
          }
        }
        .toSeq)

    Sorting(sortingFields)
  }

  /** Reads the optional `sort` query parameter; see `extractSorting` for the accepted syntax. */
  val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting)

  /** Renders a sorting as a single `sort` query parameter, prefixing descending fields with `-`. */
  def sortingQuery(sorting: Sorting): Seq[(String, String)] = {
    val sortingString = sorting.sortingFields
      .map { sortingField =>
        sortingField.sortingOrder match {
          case SortingOrder.Asc  => sortingField.name
          case SortingOrder.Desc => s"-${sortingField.name}"
        }
      }
      .mkString(",")
    Seq("sort" -> sortingString)
  }
}