diff options
Diffstat (limited to 'metrics')
-rw-r--r-- | metrics/brave/src/main/scala/com/softwaremill/sttp/brave/BraveBackend.scala | 131 | ||||
-rw-r--r-- | metrics/brave/src/test/scala/com/softwaremill/sttp/brave/BraveBackendTest.scala | 144 |
2 files changed, 275 insertions, 0 deletions
diff --git a/metrics/brave/src/main/scala/com/softwaremill/sttp/brave/BraveBackend.scala b/metrics/brave/src/main/scala/com/softwaremill/sttp/brave/BraveBackend.scala new file mode 100644 index 0000000..4549a54 --- /dev/null +++ b/metrics/brave/src/main/scala/com/softwaremill/sttp/brave/BraveBackend.scala @@ -0,0 +1,131 @@ +package com.softwaremill.sttp.brave + +import brave.Span +import brave.http.{HttpClientAdapter, HttpClientHandler, HttpTracing} +import brave.propagation.TraceContext +import com.softwaremill.sttp.brave.BraveBackend._ +import com.softwaremill.sttp.{FollowRedirectsBackend, MonadError, Request, Response, SttpBackend} +import zipkin2.Endpoint + +import scala.language.higherKinds + +class BraveBackend[R[_], S] private (delegate: SttpBackend[R, S], + httpTracing: HttpTracing) + extends SttpBackend[R, S] { + + // .asInstanceOf as the create method lacks generics in its return type + private val handler = HttpClientHandler + .create(httpTracing, SttpHttpClientAdapter) + .asInstanceOf[HttpClientHandler[AnyRequest, AnyResponse]] + + private val tracer = httpTracing.tracing().tracer() + + override def send[T](request: Request[T, S]): R[Response[T]] = { + val span = createSpan(request) + val tracedRequest = injectTracing(span, request) + val startedSpan = + handler.handleSend(NoopInjector, tracedRequest, tracedRequest, span) + + sendAndHandleReceive(startedSpan, tracedRequest) + } + + override def close(): Unit = delegate.close() + + override def responseMonad: MonadError[R] = delegate.responseMonad + + private def createSpan(request: AnyRequest): Span = { + request + .tag(TraceContextRequestTag) + .map(_.asInstanceOf[TraceContext]) match { + case None => handler.nextSpan(request) + case Some(traceContext) => tracer.newChild(traceContext) + } + } + + private def sendAndHandleReceive[T]( + span: Span, + request: Request[T, S]): R[Response[T]] = { + val spanInScope = tracer.withSpanInScope(span) + + responseMonad.handleError( + responseMonad.map(delegate.send(request)) { response => + spanInScope.close() + handler.handleReceive(response, null, span) + response + } + ) { + case e: Exception => + spanInScope.close() + handler.handleReceive(null, e, span) + responseMonad.error(e) + } + } + + private def injectTracing[T](span: Span, + request: Request[T, S]): Request[T, S] = { + /* + Sadly the Brave API supports only mutable request representations, hence we need to work our way around + this and inject headers into the traced request with the help of a mutable variable. Later a no-op injector + is used (during the call to `handleSend`). + */ + + var tracedRequest: Request[T, S] = request + + httpTracing + .tracing() + .propagation() + .injector((_: Request[_, _], key: String, value: String) => { + tracedRequest = tracedRequest.header(key, value) + }) + .inject(span.context(), request) + + tracedRequest + } +} + +object BraveBackend { + private val NoopInjector = new TraceContext.Injector[Request[_, _]] { + override def inject(traceContext: TraceContext, + carrier: Request[_, _]): Unit = {} + } + + private val TraceContextRequestTag = classOf[TraceContext].getName + + implicit class RichRequest[T, S](request: Request[T, S]) { + def tagWithTraceContext(traceContext: TraceContext): Request[T, S] = + request.tag(TraceContextRequestTag, traceContext) + } + + type AnyRequest = Request[_, _] + type AnyResponse = Response[_] + + def apply[R[_], S](delegate: SttpBackend[R, S], + httpTracing: HttpTracing): SttpBackend[R, S] = { + // redirects should be handled before brave tracing, hence adding the follow-redirects backend on top + new FollowRedirectsBackend(new BraveBackend(delegate, httpTracing)) + } +} + +object SttpHttpClientAdapter + extends HttpClientAdapter[AnyRequest, AnyResponse] { + + override def method(request: AnyRequest): String = request.method.m + + override def url(request: AnyRequest): String = request.uri.toString + + override def requestHeader(request: AnyRequest, name: String): String = + request.headers.find(_._1.equalsIgnoreCase(name)).map(_._2).orNull + + override def statusCode(response: AnyResponse): Integer = response.code + + override def parseServerAddress(req: AnyRequest, + builder: Endpoint.Builder): Boolean = { + + if (builder.parseIp(req.uri.host)) { + req.uri.port.foreach(builder.port(_)) + true + } else { + false + } + } +} diff --git a/metrics/brave/src/test/scala/com/softwaremill/sttp/brave/BraveBackendTest.scala b/metrics/brave/src/test/scala/com/softwaremill/sttp/brave/BraveBackendTest.scala new file mode 100644 index 0000000..f046210 --- /dev/null +++ b/metrics/brave/src/test/scala/com/softwaremill/sttp/brave/BraveBackendTest.scala @@ -0,0 +1,144 @@ +package com.softwaremill.sttp.brave + +import brave.http.{HttpTracing, ITHttpClient} +import brave.internal.HexCodec +import com.softwaremill.sttp._ +import okhttp3.mockwebserver.MockResponse +import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} +import zipkin2.Span + +class BraveBackendTest extends FlatSpec with Matchers with BeforeAndAfter { + + // test proxy - contains the brave instrumentation tests + private var t: ITHttpClient[SttpBackend[Id, Nothing]] = null + + // we need to extract these protected ITHttpClient members to use in the custom test + private var _backend: SttpBackend[Id, Nothing] = null + private var _httpTracing: HttpTracing = null + private var _takeSpan: () => Span = null + + def newT(): Unit = { + t = new ITHttpClient[SttpBackend[Id, Nothing]]() { + override def post(client: SttpBackend[Id, Nothing], + pathIncludingQuery: String, + body: String): Unit = { + client.send(sttp.post(uri"${url(pathIncludingQuery)}").body(body)) + } + + override def get(client: SttpBackend[Id, Nothing], + pathIncludingQuery: String): Unit = { + client.send(sttp.get(uri"${url(pathIncludingQuery)}")) + } + + override def closeClient(client: SttpBackend[Id, Nothing]): Unit = + client.close() + + override def newClient(port: Int): SttpBackend[Id, Nothing] = { + _backend = + BraveBackend[Id, Nothing](HttpURLConnectionBackend(), httpTracing) + _httpTracing = httpTracing + _takeSpan = () => takeSpan() + + _backend + } + } + } + + before { + newT() + t.setup() + } + + after { + t.close() + t.server.shutdown() + } + + it should "propagatesSpan" in { + t.propagatesSpan() + } + + it should "makesChildOfCurrentSpan" in { + t.makesChildOfCurrentSpan() + } + + it should "propagatesExtra_newTrace" in { + t.propagatesExtra_newTrace() + } + + it should "propagatesExtra_unsampledTrace" in { + t.propagatesExtra_unsampledTrace() + } + + it should "propagates_sampledFalse" in { + t.propagates_sampledFalse() + } + + it should "customSampler" in { + t.customSampler() + } + + it should "reportsClientKindToZipkin" in { + t.reportsClientKindToZipkin() + } + + it should "reportsServerAddress" in { + t.reportsServerAddress() + } + + it should "defaultSpanNameIsMethodName" in { + t.defaultSpanNameIsMethodName() + } + + it should "supportsPortableCustomization" in { + t.supportsPortableCustomization() + } + + it should "addsStatusCodeWhenNotOk" in { + t.addsStatusCodeWhenNotOk() + } + + it should "redirect" in { + t.redirect() + } + + it should "post" in { + t.post() + } + +// these tests take a very long time to complete, but pass last time I checked + +// it should "reportsSpanOnTransportException" in { +// t.reportsSpanOnTransportException() +// } + +// it should "addsErrorTagOnTransportException" in { +// t.addsErrorTagOnTransportException() +// } + + it should "httpPathTagExcludesQueryParams" in { + t.httpPathTagExcludesQueryParams() + } + + it should "use the tracing context from tags if available" in { + val tracer = _httpTracing.tracing.tracer + t.server.enqueue(new MockResponse) + + val parent = tracer.newTrace.name("test").start + try { + import com.softwaremill.sttp.brave.BraveBackend._ + _backend.send( + sttp + .get(uri"http://127.0.0.1:${t.server.getPort}/foo") + .tagWithTraceContext(parent.context())) + } finally parent.finish() + + val request = t.server.takeRequest + request.getHeader("x-b3-traceId") should be(parent.context.traceIdString) + request.getHeader("x-b3-parentspanid") should be( + HexCodec.toLowerHex(parent.context.spanId)) + + Set(_takeSpan(), _takeSpan()).map(_.kind) should be( + Set(null, Span.Kind.CLIENT)) + } +} |