aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala50
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala11
2 files changed, 48 insertions, 13 deletions
diff --git a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala
index 2f44840..f52eef6 100644
--- a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala
@@ -12,11 +12,16 @@ import scala.language.higherKinds
*/
trait SttpHandler[R[_], -S] {
def send[T](request: Request[T, S]): R[Response[T]] = {
+ sendWithCounter(request, 0)
+ }
+
+ private def sendWithCounter[T](request: Request[T, S],
+ redirects: Int): R[Response[T]] = {
val resp = doSend(request)
if (request.options.followRedirects) {
responseMonad.flatMap(resp) { response: Response[T] =>
if (response.isRedirect) {
- followRedirect(request, response)
+ followRedirect(request, response, redirects)
} else {
responseMonad.unit(response)
}
@@ -27,24 +32,39 @@ trait SttpHandler[R[_], -S] {
}
private def followRedirect[T](request: Request[T, S],
- response: Response[T]): R[Response[T]] = {
- def isRelative(uri: String) = !uri.contains("://")
+ response: Response[T],
+ redirects: Int): R[Response[T]] = {
response.header(LocationHeader).fold(responseMonad.unit(response)) { loc =>
- val uri = if (isRelative(loc)) {
- // using java's URI to resolve a relative URI
- uri"${new URI(request.uri.toString).resolve(loc).toString}"
+ if (redirects >= SttpHandler.MaxRedirects) {
+ responseMonad.unit(Response(Left("Too many redirects"), 0, Nil, Nil))
} else {
- uri"$loc"
+ followRedirect(request, response, redirects, loc)
}
+ }
+ }
- val redirectResponse = send(request.copy[Id, T, S](uri = uri))
+ private def followRedirect[T](request: Request[T, S],
+ response: Response[T],
+ redirects: Int,
+ loc: String): R[Response[T]] = {
- responseMonad.map(redirectResponse) { rr =>
- val responseNoBody =
- response.copy(body = response.body.right.map(_ => ()))
- rr.copy(history = responseNoBody :: rr.history)
- }
+ def isRelative(uri: String) = !uri.contains("://")
+
+ val uri = if (isRelative(loc)) {
+ // using java's URI to resolve a relative URI
+ uri"${new URI(request.uri.toString).resolve(loc).toString}"
+ } else {
+ uri"$loc"
+ }
+
+ val redirectResponse =
+ sendWithCounter(request.copy[Id, T, S](uri = uri), redirects + 1)
+
+ responseMonad.map(redirectResponse) { rr =>
+ val responseNoBody =
+ response.copy(body = response.body.right.map(_ => ()))
+ rr.copy(history = responseNoBody :: rr.history)
}
}
@@ -58,3 +78,7 @@ trait SttpHandler[R[_], -S] {
*/
def responseMonad: MonadError[R]
}
+
+object SttpHandler {
+ private[sttp] val MaxRedirects = 32
+}
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index f2639b5..d43c43f 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -159,6 +159,9 @@ class BasicTests
} ~
path("r4") {
complete("819")
+ } ~
+ path("loop") {
+ redirect("/redirect/loop", StatusCodes.Found)
}
}
@@ -573,11 +576,13 @@ class BasicTests
val r2 = sttp.post(uri"$endpoint/redirect/r2")
val r3 = sttp.post(uri"$endpoint/redirect/r3")
val r4response = "819"
+ val loop = sttp.post(uri"$endpoint/redirect/loop")
name should "not redirect when redirects shouldn't be followed (temporary)" in {
val resp = r1.followRedirects(false).send().force()
resp.code should be(307)
resp.body should be('left)
+ resp.history should be('empty)
}
name should "not redirect when redirects shouldn't be followed (permanent)" in {
@@ -621,6 +626,12 @@ class BasicTests
resp.history(1).code should be(308)
resp.history(2).code should be(302)
}
+
+ name should "break redirect loops" in {
+ val resp = loop.send().force()
+ resp.code should be(0)
+ resp.history should have size (SttpHandler.MaxRedirects)
+ }
}
}