aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/com/softwaremill/sttp/FollowRedirectsBackend.scala
blob: 7004631981e4580cefc48d0978c087d35049a48e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package com.softwaremill.sttp

import java.net.URI

import scala.language.higherKinds

/**
  * A backend wrapper which follows HTTP redirects on behalf of `delegate`.
  *
  * A response is followed when `request.options.followRedirects` is set, the response `isRedirect` and it
  * carries a `Location` header. At most `request.options.maxRedirects` redirects are followed for a single
  * `send`. Every intermediate redirect response is prepended, with its body replaced by `()`, to the
  * `history` of the response that is finally returned.
  *
  * @param delegate the backend which actually sends each (individual) request
  */
class FollowRedirectsBackend[R[_], S](delegate: SttpBackend[R, S]) extends SttpBackend[R, S] {

  /**
    * Matches the beginning of an absolute location: a URI scheme (a letter followed by letters, digits, `+`,
    * `.` or `-`, as in RFC 3986) immediately followed by `://`. Always applied with `findPrefixOf`, so a `://`
    * occurring later in the location (e.g. inside a query parameter) does not count.
    */
  private val absoluteUriPrefix = "[a-zA-Z][a-zA-Z0-9+.-]*://".r

  def send[T](request: Request[T, S]): R[Response[T]] = {
    sendWithCounter(request, 0)
  }

  /**
    * Sends `request` through the delegate and, if redirects are enabled for it, follows a redirect response.
    *
    * @param redirects how many redirects have already been followed to arrive at `request`
    */
  private def sendWithCounter[T](request: Request[T, S], redirects: Int): R[Response[T]] = {
    // if there are nested follow redirect backends, disabling them and handling redirects here
    val resp = delegate.send(request.followRedirects(false))
    if (request.options.followRedirects) {
      responseMonad.flatMap(resp) { response: Response[T] =>
        if (response.isRedirect) {
          followRedirect(request, response, redirects)
        } else {
          responseMonad.unit(response)
        }
      }
    } else {
      resp
    }
  }

  /**
    * Handles a redirect `response`: returns it unchanged when it has no `Location` header, returns a synthetic
    * response (status code 0, body `Left("Too many redirects")`) once `request.options.maxRedirects` has been
    * reached, and otherwise follows the location.
    */
  private def followRedirect[T](request: Request[T, S], response: Response[T], redirects: Int): R[Response[T]] = {

    response.header(LocationHeader).fold(responseMonad.unit(response)) { loc =>
      if (redirects >= request.options.maxRedirects) {
        responseMonad.unit(Response(Left("Too many redirects"), 0, "", Nil, Nil))
      } else {
        followRedirect(request, response, redirects, loc)
      }
    }
  }

  /**
    * Re-sends `request` to `loc` and records `response` (without its body) in the history of the result.
    * Relative locations are resolved against the URI of `request`.
    *
    * NOTE(review): `java.net.URI.resolve(String)` throws `IllegalArgumentException` for a location which is
    * not a valid URI reference; that exception is not caught here.
    */
  private def followRedirect[T](request: Request[T, S],
                                response: Response[T],
                                redirects: Int,
                                loc: String): R[Response[T]] = {

    val uri = if (isRelative(loc)) {
      // using java's URI to resolve a relative URI against the URI of the request which was redirected
      uri"${new URI(request.uri.toString).resolve(loc).toString}"
    } else {
      uri"$loc"
    }

    val redirectResponse =
      sendWithCounter(request.copy[Id, T, S](uri = uri), redirects + 1)

    responseMonad.map(redirectResponse) { rr =>
      val responseNoBody =
        response.copy(body = response.body.right.map(_ => ()))
      rr.copy(history = responseNoBody :: rr.history)
    }
  }

  /**
    * Whether `location` is a relative reference which has to be resolved against the request URI.
    *
    * Only a leading `scheme://` makes a location absolute. A `://` appearing later, as in
    * `/login?next=http://example.com`, does not, so such a location is still resolved. Protocol-relative
    * locations (`//host/path`) have no scheme and are therefore also resolved.
    */
  private def isRelative(location: String): Boolean =
    absoluteUriPrefix.findPrefixOf(location.trim).isEmpty

  override def close(): Unit = delegate.close()

  override def responseMonad: MonadError[R] = delegate.responseMonad
}

/** Companion of [[FollowRedirectsBackend]]. */
object FollowRedirectsBackend {

  /**
    * A redirect limit of 32.
    *
    * NOTE(review): presumably the default for `RequestOptions.maxRedirects`; the backend itself reads the
    * limit from `request.options.maxRedirects` — confirm at the use site.
    */
  private[sttp] val MaxRedirects: Int = 32
}