aboutsummaryrefslogtreecommitdiff
path: root/core-rest/src/main/scala/xyz/driver/core/rest/serviceRequestContext.scala
blob: d2e4bc35a7aad8f41674ce5c9217f2613216d1e8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package xyz.driver.core.rest

import java.net.InetAddress

import xyz.driver.core.auth.{AuthToken, PermissionsToken, User}
import xyz.driver.core.generators
import scalaz.Scalaz.{mapEqual, stringInstance}
import scalaz.syntax.equal._
import xyz.driver.core.reporting.SpanContext
import xyz.driver.core.rest.auth.AuthProvider
import xyz.driver.core.rest.headers.Traceparent

import scala.util.Try

/** Context attached to a service request: a tracking id, the optional originating
  * IP address and a map of context headers.
  *
  * Instances are not mutated by the methods below; the `with*` methods return new
  * contexts built from this one.
  *
  * @param trackingId     identifier of the request; defaults to a freshly generated UUID string
  * @param originatingIp  address the request originated from, if known
  * @param contextHeaders header name to header value map carried with the request
  */
class ServiceRequestContext(
    val trackingId: String = generators.nextUuid().toString,
    val originatingIp: Option[InetAddress] = None,
    val contextHeaders: Map[String, String] = Map.empty[String, String]) {

  /** The value of the `AuthProvider.AuthenticationTokenHeader` header wrapped as an
    * [[AuthToken]], or `None` when that header is not present.
    */
  def authToken: Option[AuthToken] =
    contextHeaders.get(AuthProvider.AuthenticationTokenHeader).map(AuthToken.apply)

  /** The value of the `AuthProvider.PermissionsTokenHeader` header wrapped as a
    * [[PermissionsToken]], or `None` when that header is not present.
    */
  def permissionsToken: Option[PermissionsToken] =
    contextHeaders.get(AuthProvider.PermissionsTokenHeader).map(PermissionsToken.apply)

  /** Returns a new context with the same tracking id and originating IP whose
    * authentication token header is set (or replaced) with `authToken.value`.
    */
  def withAuthToken(authToken: AuthToken): ServiceRequestContext =
    new ServiceRequestContext(
      trackingId,
      originatingIp,
      contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value)
    )

  /** Returns an [[AuthorizedServiceRequestContext]] carrying `user`, with the same
    * tracking id and originating IP and with the authentication token header set
    * (or replaced) with `authToken.value`.
    *
    * @tparam U concrete user type
    */
  def withAuthenticatedUser[U <: User](authToken: AuthToken, user: U): AuthorizedServiceRequestContext[U] =
    new AuthorizedServiceRequestContext(
      trackingId,
      originatingIp,
      contextHeaders.updated(AuthProvider.AuthenticationTokenHeader, authToken.value),
      user
    )

  /** Hash over the same three fields that [[equals]] compares. */
  override def hashCode(): Int =
    Seq[Any](trackingId, originatingIp, contextHeaders)
      .foldLeft(31)((result, obj) => 31 * result + obj.hashCode())

  /** Two contexts are equal when their tracking id, originating IP and context
    * headers are all equal.
    *
    * NOTE(review): instances of subclasses also match the pattern below, so
    * `base == authorized` can be true while `authorized == base` is false.
    */
  override def equals(obj: Any): Boolean = obj match {
    case ctx: ServiceRequestContext =>
      trackingId === ctx.trackingId &&
        // Compare against the OTHER context's address; comparing the field with
        // itself was always true and made equals disagree with hashCode.
        originatingIp == ctx.originatingIp &&
        contextHeaders === ctx.contextHeaders
    case _ => false
  }

  /** Span context derived from the `Traceparent` header.
    *
    * Falls back to `SpanContext.fresh()` when the header is missing or when
    * parsing its value fails.
    */
  def spanContext: SpanContext = {
    // The lookup throws when the header is absent; Try turns that into a Failure
    // so both "missing" and "unparsable" take the same fallback path.
    val validHeader = Try {
      contextHeaders(Traceparent.name)
    }.flatMap { value =>
      Traceparent.parse(value)
    }
    validHeader.map(_.spanContext).getOrElse(SpanContext.fresh())
  }

  /** Renders the tracking id and context headers; the originating IP is not included. */
  override def toString: String = s"ServiceRequestContext($trackingId, $contextHeaders)"
}

/** A [[ServiceRequestContext]] that additionally carries the authenticated user.
  *
  * @param trackingId        identifier of the request; defaults to a freshly generated UUID string
  * @param originatingIp     address the request originated from, if known
  * @param contextHeaders    header name to header value map carried with the request
  * @param authenticatedUser the user associated with this request
  * @tparam U concrete user type
  */
class AuthorizedServiceRequestContext[U <: User](
    override val trackingId: String = generators.nextUuid().toString,
    override val originatingIp: Option[InetAddress] = None,
    override val contextHeaders: Map[String, String] = Map.empty[String, String],
    val authenticatedUser: U)
    // Forward the values to the superclass so it does not evaluate its own default
    // arguments (including a throw-away UUID) on every construction.
    extends ServiceRequestContext(trackingId, originatingIp, contextHeaders) {

  /** Returns a new authorized context for the same user, tracking id and originating
    * IP whose permissions token header is set (or replaced) with `permissionsToken.value`.
    */
  def withPermissionsToken(permissionsToken: PermissionsToken): AuthorizedServiceRequestContext[U] =
    new AuthorizedServiceRequestContext[U](
      trackingId,
      originatingIp,
      contextHeaders.updated(AuthProvider.PermissionsTokenHeader, permissionsToken.value),
      authenticatedUser)

  /** Combines the superclass hash with the hash of the authenticated user. */
  override def hashCode(): Int = 31 * super.hashCode() + authenticatedUser.hashCode()

  /** Equal to another authorized context when the superclass fields are equal and
    * the authenticated users are equal.
    */
  override def equals(obj: Any): Boolean = obj match {
    // The type argument is erased at runtime, so match on the wildcard rather than
    // on `U`; user equality is checked explicitly below.
    case ctx: AuthorizedServiceRequestContext[_] =>
      super.equals(ctx) && ctx.authenticatedUser == authenticatedUser
    case _ => false
  }

  /** Renders the tracking id, context headers and authenticated user. */
  override def toString: String =
    s"AuthorizedServiceRequestContext($trackingId, $contextHeaders, $authenticatedUser)"
}