From 9406043b822a3aa9a7ef27630036a91cc3e7ed3a Mon Sep 17 00:00:00 2001 From: n4to4 Date: Mon, 23 Apr 2018 22:11:39 +0900 Subject: Make max redirects configurable --- .../scala/com/softwaremill/sttp/FollowRedirectsBackend.scala | 2 +- core/src/main/scala/com/softwaremill/sttp/RequestT.scala | 9 ++++++++- tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala | 7 +++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/com/softwaremill/sttp/FollowRedirectsBackend.scala b/core/src/main/scala/com/softwaremill/sttp/FollowRedirectsBackend.scala index 96c0213..7004631 100644 --- a/core/src/main/scala/com/softwaremill/sttp/FollowRedirectsBackend.scala +++ b/core/src/main/scala/com/softwaremill/sttp/FollowRedirectsBackend.scala @@ -29,7 +29,7 @@ class FollowRedirectsBackend[R[_], S](delegate: SttpBackend[R, S]) extends SttpB private def followRedirect[T](request: Request[T, S], response: Response[T], redirects: Int): R[Response[T]] = { response.header(LocationHeader).fold(responseMonad.unit(response)) { loc => - if (redirects >= FollowRedirectsBackend.MaxRedirects) { + if (redirects >= request.options.maxRedirects) { responseMonad.unit(Response(Left("Too many redirects"), 0, "", Nil, Nil)) } else { followRedirect(request, response, redirects, loc) diff --git a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala index 91635e1..7b3eab9 100644 --- a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala +++ b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala @@ -229,6 +229,9 @@ case class RequestT[U[_], T, +S]( def followRedirects(fr: Boolean): RequestT[U, T, S] = this.copy(options = options.copy(followRedirects = fr)) + def maxRedirects(n: Int): RequestT[U, T, S] = + this.copy(options = options.copy(maxRedirects = n)) + def tag(k: String, v: Any): RequestT[U, T, S] = this.copy(tags = tags + (k -> v)) @@ -281,4 +284,8 @@ class SpecifyAuthScheme[U[_], T, +S](hn: String, rt: RequestT[U, T, S]) { rt.header(hn, s"Bearer $token") } -case class RequestOptions(followRedirects: Boolean, readTimeout: Duration) +case class RequestOptions( + followRedirects: Boolean, + readTimeout: Duration, + maxRedirects: Int = FollowRedirectsBackend.MaxRedirects +) diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index 1d03154..4138fe2 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -631,6 +631,13 @@ class BasicTests resp.code should be(0) resp.history should have size (FollowRedirectsBackend.MaxRedirects) } + + name should "break redirect loops after user-specified count" in { + val maxRedirects = 10 + val resp = loop.maxRedirects(maxRedirects).send().force() + resp.code should be(0) + resp.history should have size (maxRedirects) + } } def timeoutTests(): Unit = { -- cgit v1.2.3