aboutsummaryrefslogtreecommitdiff
path: root/metrics/brave-backend/src/main/scala/com/softwaremill/sttp/brave/BraveBackend.scala
blob: 3e366161c3791b6ebdf840e241282c3694527ca5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package com.softwaremill.sttp.brave

import brave.http.{HttpClientAdapter, HttpClientHandler, HttpTracing}
import brave.propagation.{Propagation, TraceContext}
import brave.{Span, Tracing}
import com.softwaremill.sttp.brave.BraveBackend._
import com.softwaremill.sttp.{FollowRedirectsBackend, MonadError, Request, Response, SttpBackend}
import zipkin2.Endpoint

import scala.language.higherKinds

/**
  * Backend wrapper which records a Brave client span for every request sent through `delegate`.
  *
  * For each request the span is either a child of a [[TraceContext]] attached via
  * `BraveBackend.RichRequest.tagWithTraceContext`, or otherwise the span returned by Brave's
  * `HttpClientHandler.nextSpan`. The span's context is propagated by adding headers to the outgoing request,
  * and the span is handed back to the handler once the delegate produces a response or fails.
  *
  * Instances are created through the companion object's `apply` methods.
  */
class BraveBackend[R[_], S] private (delegate: SttpBackend[R, S], httpTracing: HttpTracing) extends SttpBackend[R, S] {

  // .asInstanceOf as the create method lacks generics in its return type
  private val handler = HttpClientHandler
    .create(httpTracing, SttpHttpClientAdapter)
    .asInstanceOf[HttpClientHandler[AnyRequest, AnyResponse]]

  private val tracer = httpTracing.tracing().tracer()

  /** Sends `request` through the delegate, tracing it as a client span. */
  override def send[T](request: Request[T, S]): R[Response[T]] = {
    val span = createSpan(request)
    val tracedRequest = injectTracing(span, request)
    // The trace headers are already present on `tracedRequest`, so the handler gets an injector that does nothing.
    val startedSpan =
      handler.handleSend(NoopInjector, tracedRequest, tracedRequest, span)

    sendAndHandleReceive(startedSpan, tracedRequest)
  }

  override def close(): Unit = delegate.close()

  override def responseMonad: MonadError[R] = delegate.responseMonad

  /**
    * Creates the span for `request`: a child of the trace context stored under the trace-context tag when the
    * request carries one, otherwise whatever `handler.nextSpan` returns.
    *
    * A tag value which is not a [[TraceContext]] is ignored rather than failing with a `ClassCastException`.
    */
  private def createSpan(request: AnyRequest): Span =
    request
      .tag(TraceContextRequestTag)
      .collect { case traceContext: TraceContext => traceContext }
      .fold(handler.nextSpan(request))(tracer.newChild(_))

  /**
    * Invokes the delegate with `span` in scope and reports the outcome (response or exception) to the handler.
    *
    * The scope is opened and closed on the calling thread. With Brave's default thread-local
    * `CurrentTraceContext`, a scope must be closed on the thread that opened it. Closing it inside the
    * `map`/`handleError` callbacks, which asynchronous backends may run on another thread, would leave `span`
    * in scope on this thread and overwrite the current context of the callback thread.
    *
    * Only `Exception`s (not other throwables) are reported to the handler; the failure is re-raised in the effect.
    */
  private def sendAndHandleReceive[T](span: Span, request: Request[T, S]): R[Response[T]] = {
    val spanInScope = tracer.withSpanInScope(span)
    try {
      responseMonad.handleError(
        responseMonad.map(delegate.send(request)) { response =>
          handler.handleReceive(response, null, span)
          response
        }
      ) {
        case e: Exception =>
          handler.handleReceive(null, e, span)
          responseMonad.error(e)
      }
    } finally {
      spanInScope.close()
    }
  }

  /**
    * Returns a copy of `request` with the propagation headers for `span`'s context added.
    */
  private def injectTracing[T](span: Span, request: Request[T, S]): Request[T, S] = {
    /*
    Sadly the Brave API supports only mutable request representations, hence we need to work our way around
    this and inject headers into the traced request with the help of a mutable variable. Later a no-op injector
    is used (during the call to `handleSend`).
     */

    var tracedRequest: Request[T, S] = request

    httpTracing
      .tracing()
      .propagation()
      .injector(new Propagation.Setter[AnyRequest, String] {
        // The carrier is ignored: each header is accumulated onto the immutable `tracedRequest` copy instead.
        override def put(carrier: AnyRequest, key: String, value: String): Unit = {
          tracedRequest = tracedRequest.header(key, value)
        }
      })
      .inject(span.context(), request)

    tracedRequest
  }
}

object BraveBackend {

  /** A request with arbitrary response-body and stream type parameters. */
  type AnyRequest = Request[_, _]

  /** A response with an arbitrary body type parameter. */
  type AnyResponse = Response[_]

  /** Key of the request tag under which a parent [[TraceContext]] may be stored. */
  private val TraceContextRequestTag: String = classOf[TraceContext].getName

  /**
    * Injector that writes nothing. The backend adds the trace headers to the immutable request before handing it
    * to Brave's handler, so no injection is left to do at that point.
    */
  private val NoopInjector: TraceContext.Injector[AnyRequest] =
    new TraceContext.Injector[AnyRequest] {
      override def inject(traceContext: TraceContext, carrier: AnyRequest): Unit = ()
    }

  /** Adds `tagWithTraceContext` to requests, so that the request's span is created as a child of a given context. */
  implicit class RichRequest[T, S](request: Request[T, S]) {
    def tagWithTraceContext(traceContext: TraceContext): Request[T, S] = {
      val tagKey = TraceContextRequestTag
      request.tag(tagKey, traceContext)
    }
  }

  /** Wraps `delegate` with Brave tracing, using an `HttpTracing` built from `tracing` via `HttpTracing.create`. */
  def apply[R[_], S](delegate: SttpBackend[R, S], tracing: Tracing): SttpBackend[R, S] = {
    val httpTracing = HttpTracing.create(tracing)
    apply(delegate, httpTracing)
  }

  /** Wraps `delegate` with Brave tracing configured by `httpTracing`. */
  def apply[R[_], S](delegate: SttpBackend[R, S], httpTracing: HttpTracing): SttpBackend[R, S] = {
    val tracingBackend = new BraveBackend[R, S](delegate, httpTracing)
    // Redirect handling sits outside (in front of) the Brave layer, so it runs before tracing.
    new FollowRedirectsBackend(tracingBackend)
  }
}

/** Lets Brave's HTTP client instrumentation read sttp requests and responses. */
object SttpHttpClientAdapter extends HttpClientAdapter[AnyRequest, AnyResponse] {

  /** The request's HTTP method, as a string. */
  override def method(request: AnyRequest): String = {
    val httpMethod = request.method
    httpMethod.m
  }

  /** The request URI rendered with `toString`. */
  override def url(request: AnyRequest): String = {
    val uri = request.uri
    uri.toString
  }

  /** Value of the first header whose name matches `name` case-insensitively, or `null` when there is none. */
  override def requestHeader(request: AnyRequest, name: String): String =
    request.headers
      .collectFirst { case (headerName, headerValue) if headerName.equalsIgnoreCase(name) => headerValue }
      .orNull

  /** The response status code, boxed as a `java.lang.Integer`. */
  override def statusCode(response: AnyResponse): Integer = Integer.valueOf(response.code)

  /**
    * Records the server address on `builder` when the URI host is accepted by `builder.parseIp`; the port is
    * recorded too, if the URI has one.
    *
    * @return whether `builder.parseIp` accepted the host
    */
  override def parseServerAddress(req: AnyRequest, builder: Endpoint.Builder): Boolean = {
    val uri = req.uri
    val hostIsIp = builder.parseIp(uri.host)
    if (hostIsIp) uri.port.foreach(port => builder.port(port))
    hostIsIp
  }
}