diff options
Diffstat (limited to 'src/main/scala/xyz/driver/core/rest/package.scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/rest/package.scala | 323 |
1 files changed, 0 insertions, 323 deletions
diff --git a/src/main/scala/xyz/driver/core/rest/package.scala b/src/main/scala/xyz/driver/core/rest/package.scala deleted file mode 100644 index 34a4a9d..0000000 --- a/src/main/scala/xyz/driver/core/rest/package.scala +++ /dev/null @@ -1,323 +0,0 @@ -package xyz.driver.core.rest - -import java.net.InetAddress - -import akka.http.scaladsl.marshalling.{ToEntityMarshaller, ToResponseMarshallable} -import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.server.Directives._ -import akka.http.scaladsl.server._ -import akka.http.scaladsl.unmarshalling.Unmarshal -import akka.stream.Materializer -import akka.stream.scaladsl.Flow -import akka.util.ByteString -import scalaz.Scalaz.{intInstance, stringInstance} -import scalaz.syntax.equal._ -import scalaz.{Functor, OptionT} -import xyz.driver.core.rest.auth.AuthProvider -import xyz.driver.core.rest.errors.ExternalServiceException -import xyz.driver.core.rest.headers.Traceparent -import xyz.driver.tracing.TracingDirectives - -import scala.concurrent.{ExecutionContext, Future} -import scala.util.Try - -trait Service - -object Service - -trait HttpClient { - def makeRequest(request: HttpRequest): Future[HttpResponse] -} - -trait ServiceTransport { - - def sendRequestGetResponse(context: ServiceRequestContext)(requestStub: HttpRequest): Future[HttpResponse] - - def sendRequest(context: ServiceRequestContext)(requestStub: HttpRequest)( - implicit mat: Materializer): Future[Unmarshal[ResponseEntity]] -} - -sealed trait SortingOrder -object SortingOrder { - case object Asc extends SortingOrder - case object Desc extends SortingOrder -} - -final case class SortingField(name: String, sortingOrder: SortingOrder) -final case class Sorting(sortingFields: Seq[SortingField]) - -final case class Pagination(pageSize: Int, pageNumber: Int) { - require(pageSize > 0, "Page size must be greater than zero") - require(pageNumber > 0, "Page number must be greater than zero") - - def offset: Int = pageSize * (pageNumber - 1) -} - -final case class ListResponse[+T](items: Seq[T], meta: ListResponse.Meta) - -object ListResponse { - - def apply[T](items: Seq[T], size: Int, pagination: Option[Pagination]): ListResponse[T] = - ListResponse( - items = items, - meta = ListResponse.Meta(size, pagination.fold(1)(_.pageNumber), pagination.fold(size)(_.pageSize))) - - final case class Meta(itemsCount: Int, pageNumber: Int, pageSize: Int) - - object Meta { - def apply(itemsCount: Int, pagination: Pagination): Meta = - Meta(itemsCount, pagination.pageNumber, pagination.pageSize) - } - -} - -object `package` { - - implicit class FutureExtensions[T](future: Future[T]) { - def passThroughExternalServiceException(implicit executionContext: ExecutionContext): Future[T] = - future.transform(identity, { - case ExternalServiceException(_, _, Some(e)) => e - case t: Throwable => t - }) - } - - implicit class OptionTRestAdditions[T](optionT: OptionT[Future, T]) { - def responseOrNotFound(successCode: StatusCodes.Success = StatusCodes.OK)( - implicit F: Functor[Future], - em: ToEntityMarshaller[T]): Future[ToResponseMarshallable] = { - optionT.fold[ToResponseMarshallable](successCode -> _, StatusCodes.NotFound -> None) - } - } - - object ContextHeaders { - val AuthenticationTokenHeader: String = "Authorization" - val PermissionsTokenHeader: String = "Permissions" - val AuthenticationHeaderPrefix: String = "Bearer" - val ClientFingerprintHeader: String = "X-Client-Fingerprint" - val TrackingIdHeader: String = "X-Trace" - val StacktraceHeader: String = "X-Stacktrace" - val OriginatingIpHeader: String = "X-Forwarded-For" - val ResourceCount: String = "X-Resource-Count" - val PageCount: String = "X-Page-Count" - val TraceHeaderName: String = TracingDirectives.TraceHeaderName - val SpanHeaderName: String = TracingDirectives.SpanHeaderName - } - - val AllowedHeaders: Seq[String] = - Seq( - "Origin", - "X-Requested-With", - "Content-Type", - "Content-Length", - "Accept", - "X-Trace", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Origin", - "Access-Control-Allow-Headers", - "Server", - "Date", - ContextHeaders.ClientFingerprintHeader, - ContextHeaders.TrackingIdHeader, - ContextHeaders.TraceHeaderName, - ContextHeaders.SpanHeaderName, - ContextHeaders.StacktraceHeader, - ContextHeaders.AuthenticationTokenHeader, - ContextHeaders.OriginatingIpHeader, - ContextHeaders.ResourceCount, - ContextHeaders.PageCount, - "X-Frame-Options", - "X-Content-Type-Options", - "Strict-Transport-Security", - AuthProvider.SetAuthenticationTokenHeader, - AuthProvider.SetPermissionsTokenHeader, - "Traceparent" - ) - - def allowOrigin(originHeader: Option[Origin]): `Access-Control-Allow-Origin` = - `Access-Control-Allow-Origin`( - originHeader.fold[HttpOriginRange](HttpOriginRange.*)(h => HttpOriginRange(h.origins: _*))) - - def serviceContext: Directive1[ServiceRequestContext] = { - def fixAuthorizationHeader(headers: Seq[HttpHeader]): collection.immutable.Seq[HttpHeader] = { - headers.map({ header => - if (header.name === ContextHeaders.AuthenticationTokenHeader && !header.value.startsWith( - ContextHeaders.AuthenticationHeaderPrefix)) { - Authorization(OAuth2BearerToken(header.value)) - } else header - })(collection.breakOut) - } - extractClientIP flatMap { remoteAddress => - mapRequest(req => req.withHeaders(fixAuthorizationHeader(req.headers))) tflatMap { _ => - extract(ctx => extractServiceContext(ctx.request, remoteAddress)) - } - } - } - - def respondWithCorsAllowedHeaders: Directive0 = { - respondWithHeaders( - List[HttpHeader]( - `Access-Control-Allow-Headers`(AllowedHeaders: _*), - `Access-Control-Expose-Headers`(AllowedHeaders: _*) - )) - } - - def respondWithCorsAllowedOriginHeaders(origin: Origin): Directive0 = { - respondWithHeader { - `Access-Control-Allow-Origin`(HttpOriginRange(origin.origins: _*)) - } - } - - def respondWithCorsAllowedMethodHeaders(methods: Set[HttpMethod]): Directive0 = { - respondWithHeaders( - List[HttpHeader]( - Allow(methods.to[collection.immutable.Seq]), - `Access-Control-Allow-Methods`(methods.to[collection.immutable.Seq]) - )) - } - - def extractServiceContext(request: HttpRequest, remoteAddress: RemoteAddress): ServiceRequestContext = - new ServiceRequestContext( - extractTrackingId(request), - extractOriginatingIP(request, remoteAddress), - extractContextHeaders(request)) - - def extractTrackingId(request: HttpRequest): String = { - request.headers - .find(_.name === ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } - - def extractFingerprintHash(request: HttpRequest): Option[String] = { - request.headers - .find(_.name === ContextHeaders.ClientFingerprintHeader) - .map(_.value()) - } - - def extractOriginatingIP(request: HttpRequest, remoteAddress: RemoteAddress): Option[InetAddress] = { - request.headers - .find(_.name === ContextHeaders.OriginatingIpHeader) - .flatMap(ipName => Try(InetAddress.getByName(ipName.value)).toOption) - .orElse(remoteAddress.toOption) - } - - def extractStacktrace(request: HttpRequest): Array[String] = - request.headers.find(_.name == ContextHeaders.StacktraceHeader).fold("")(_.value()).split("->") - - def extractContextHeaders(request: HttpRequest): Map[String, String] = { - request.headers - .filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader || - h.name === ContextHeaders.PermissionsTokenHeader || h.name === ContextHeaders.StacktraceHeader || - h.name === ContextHeaders.TraceHeaderName || h.name === ContextHeaders.SpanHeaderName || - h.name === ContextHeaders.OriginatingIpHeader || h.name === ContextHeaders.ClientFingerprintHeader || - h.name === Traceparent.name - } - .map { header => - if (header.name === ContextHeaders.AuthenticationTokenHeader) { - header.name -> header.value.stripPrefix(ContextHeaders.AuthenticationHeaderPrefix).trim - } else { - header.name -> header.value - } - } - .toMap - } - - private[rest] def escapeScriptTags(byteString: ByteString): ByteString = { - @annotation.tailrec - def dirtyIndices(from: Int, descIndices: List[Int]): List[Int] = { - val index = byteString.indexOf('/', from) - if (index === -1) descIndices.reverse - else { - val (init, tail) = byteString.splitAt(index) - if ((init endsWith "<") && (tail startsWith "/sc")) { - dirtyIndices(index + 1, index :: descIndices) - } else { - dirtyIndices(index + 1, descIndices) - } - } - } - - val indices = dirtyIndices(0, Nil) - - indices.headOption.fold(byteString) { head => - val builder = ByteString.newBuilder - builder ++= byteString.take(head) - - (indices :+ byteString.length).sliding(2).foreach { - case Seq(start, end) => - builder += ' ' - builder ++= byteString.slice(start, end) - case Seq(_) => // Should not match; sliding on at least 2 elements - assert(indices.nonEmpty, s"Indices should have been nonEmpty: $indices") - } - builder.result - } - } - - val sanitizeRequestEntity: Directive0 = { - mapRequest(request => request.mapEntity(entity => entity.transformDataBytes(Flow.fromFunction(escapeScriptTags)))) - } - - val paginated: Directive1[Pagination] = - parameters(("pageSize".as[Int] ? 100, "pageNumber".as[Int] ? 1)).as(Pagination) - - private def extractPagination(pageSizeOpt: Option[Int], pageNumberOpt: Option[Int]): Option[Pagination] = - (pageSizeOpt, pageNumberOpt) match { - case (Some(size), Some(number)) => Option(Pagination(size, number)) - case (None, None) => Option.empty[Pagination] - case (_, _) => throw new IllegalArgumentException("Pagination's parameters are incorrect") - } - - val optionalPagination: Directive1[Option[Pagination]] = - parameters(("pageSize".as[Int].?, "pageNumber".as[Int].?)).as(extractPagination) - - def paginationQuery(pagination: Pagination) = - Seq("pageNumber" -> pagination.pageNumber.toString, "pageSize" -> pagination.pageSize.toString) - - def completeWithPagination[T](handler: Option[Pagination] => Future[ListResponse[T]])( - implicit marshaller: ToEntityMarshaller[Seq[T]]): Route = { - optionalPagination { pagination => - onSuccess(handler(pagination)) { - case ListResponse(resultPart, ListResponse.Meta(count, _, pageSize)) => - val pageCount = if (pageSize == 0) 0 else (count / pageSize) + (if (count % pageSize == 0) 0 else 1) - val headers = List( - RawHeader(ContextHeaders.ResourceCount, count.toString), - RawHeader(ContextHeaders.PageCount, pageCount.toString)) - - respondWithHeaders(headers)(complete(ToResponseMarshallable(resultPart))) - } - } - } - - private def extractSorting(sortingString: Option[String]): Sorting = { - val sortingFields = sortingString.fold(Seq.empty[SortingField])( - _.split(",") - .filter(_.length > 0) - .map { sortingParam => - if (sortingParam.startsWith("-")) { - SortingField(sortingParam.substring(1), SortingOrder.Desc) - } else { - val fieldName = if (sortingParam.startsWith("+")) sortingParam.substring(1) else sortingParam - SortingField(fieldName, SortingOrder.Asc) - } - } - .toSeq) - - Sorting(sortingFields) - } - - val sorting: Directive1[Sorting] = parameter("sort".as[String].?).as(extractSorting) - - def sortingQuery(sorting: Sorting): Seq[(String, String)] = { - val sortingString = sorting.sortingFields - .map { sortingField => - sortingField.sortingOrder match { - case SortingOrder.Asc => sortingField.name - case SortingOrder.Desc => s"-${sortingField.name}" - } - } - .mkString(",") - Seq("sort" -> sortingString) - } -} |