diff options
8 files changed, 104 insertions, 17 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala index 2f7184a..5691d3c 100644 --- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala +++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala @@ -33,7 +33,7 @@ class AkkaHttpHandler private (actorSystem: ActorSystem, private implicit val as = actorSystem private implicit val materializer = ActorMaterializer() - override def send[T](r: Request[T, S]): Future[Response[T]] = { + override protected def doSend[T](r: Request[T, S]): Future[Response[T]] = { implicit val ec = this.ec requestToAkka(r) .flatMap(setBodyOnAkka(r, r.body, _)) diff --git a/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala b/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala index 9c4704b..2e5d16b 100644 --- a/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala +++ b/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala @@ -36,7 +36,7 @@ abstract class AsyncHttpClientHandler[R[_], S](asyncHttpClient: AsyncHttpClient, closeClient: Boolean) extends SttpHandler[R, S] { - override def send[T](r: Request[T, S]): R[Response[T]] = { + override protected def doSend[T](r: Request[T, S]): R[Response[T]] = { val preparedRequest = asyncHttpClient .prepareRequest(requestToAsync(r)) diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala index 76d897b..548dd9b 100644 --- a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala @@ -13,13 +13,16 @@ import scala.io.Source import scala.collection.JavaConverters._ object HttpURLConnectionHandler extends SttpHandler[Id, Nothing] { - override def send[T](r: Request[T, Nothing]): Response[T] = { + override protected def doSend[T](r: Request[T, Nothing]): Response[T] = { val c = new URL(r.uri.toString).openConnection().asInstanceOf[HttpURLConnection] c.setRequestMethod(r.method.m) r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) } c.setDoInput(true) + // redirects are handled in SttpHandler + c.setInstanceFollowRedirects(false) + if (r.body != NoBody) { c.setDoOutput(true) // we need to take care to: diff --git a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala index 5e4d33e..95dc56b 100644 --- a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala +++ b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala @@ -31,7 +31,8 @@ case class RequestT[U[_], T, +S]( uri: U[Uri], body: RequestBody[S], headers: Seq[(String, String)], - response: ResponseAs[T, S] + response: ResponseAs[T, S], + options: RequestOptions ) { def get(uri: Uri): Request[T, S] = this.copy[Id, T, S](uri = uri, method = Method.GET) @@ -218,6 +219,9 @@ case class RequestT[U[_], T, +S]( def mapResponse[T2](f: T => T2): RequestT[U, T2, S] = this.copy(response = response.map(f)) + def followRedirects(fr: Boolean): RequestT[U, T, S] = + this.copy(options = options.copy(followRedirects = fr)) + def send[R[_]]()(implicit handler: SttpHandler[R, S], isIdInRequest: IsIdInRequest[U]): R[Response[T]] = { // we could avoid the asInstanceOf by creating an artificial copy @@ -268,3 +272,5 @@ class SpecifyAuthScheme[U[_], T, +S](hn: String, rt: RequestT[U, T, S]) { def bearer(token: String): RequestT[U, T, S] = rt.header(hn, s"Bearer $token") } + +case class RequestOptions(followRedirects: Boolean) diff --git a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala index c6df151..fd836bd 100644 --- a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala @@ -1,5 +1,7 @@ package com.softwaremill.sttp +import java.net.URI + import scala.language.higherKinds /** @@ -9,9 +11,41 @@ import scala.language.higherKinds * if streaming requests/responses is not supported by this handler. */ trait SttpHandler[R[_], -S] { - def send[T](request: Request[T, S]): R[Response[T]] + def send[T](request: Request[T, S]): R[Response[T]] = { + val resp = doSend(request) + if (request.options.followRedirects) { + responseMonad.flatMap(resp, { response: Response[T] => + if (response.isRedirect) { + followRedirect(request, response) + } else { + responseMonad.unit(response) + } + }) + } else { + resp + } + } + + private def followRedirect[T](request: Request[T, S], + response: Response[T]): R[Response[T]] = { + def isRelative(uri: String) = !uri.contains("://") + + response.header(LocationHeader).fold(responseMonad.unit(response)) { loc => + val uri = if (isRelative(loc)) { + // using java's URI to resolve a relative URI + uri"${new URI(request.uri.toString).resolve(loc).toString}" + } else { + uri"$loc" + } + + send(request.copy[Id, T, S](uri = uri)) + } + } + def close(): Unit = {} + protected def doSend[T](request: Request[T, S]): R[Response[T]] + /** * The monad in which the responses are wrapped. Allows writing wrapper * handlers, which map/flatMap over the return value of [[send]]. diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala index a0a6c57..3c4e844 100644 --- a/core/src/main/scala/com/softwaremill/sttp/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/package.scala @@ -37,6 +37,7 @@ package object sttp { private[sttp] val AcceptEncodingHeader = "Accept-Encoding" private[sttp] val ContentEncodingHeader = "Content-Encoding" private[sttp] val ContentDispositionHeader = "Content-Disposition" + private[sttp] val LocationHeader = "Location" private[sttp] val Utf8 = "utf-8" private[sttp] val Iso88591 = "iso-8859-1" private[sttp] val CrLf = "\r\n" @@ -54,7 +55,12 @@ package object sttp { * An empty request with no headers. */ val emptyRequest: RequestT[Empty, String, Nothing] = - RequestT[Empty, String, Nothing](None, None, NoBody, Vector(), asString) + RequestT[Empty, String, Nothing](None, + None, + NoBody, + Vector(), + asString, + RequestOptions(followRedirects = true)) /** * A starting request, with the following modifications comparing to diff --git a/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala b/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala index 9e7669f..b95bf8a 100644 --- a/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala +++ b/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala @@ -130,7 +130,7 @@ abstract class OkHttpHandler[R[_], S](client: OkHttpClient) class OkHttpSyncHandler private (client: OkHttpClient) extends OkHttpHandler[Id, Nothing](client) { - override def send[T](r: Request[T, Nothing]): Response[T] = { + override protected def doSend[T](r: Request[T, Nothing]): Response[T] = { val request = convertRequest(r) val response = client.newCall(request).execute() readResponse(response, r.response) @@ -148,7 +148,7 @@ object OkHttpSyncHandler { abstract class OkHttpAsyncHandler[R[_], S](client: OkHttpClient, rm: MonadAsyncError[R]) extends OkHttpHandler[R, S](client) { - override def send[T](r: Request[T, S]): R[Response[T]] = { + override protected def doSend[T](r: Request[T, S]): R[Response[T]] = { val request = convertRequest(r) rm.flatten(rm.async[R[Response[T]]] { cb => diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index 2275c5e..61b9c19 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -8,7 +8,7 @@ import java.time.{ZoneId, ZonedDateTime} import akka.http.scaladsl.coding.{Deflate, Gzip, NoCoding} import akka.http.scaladsl.model.headers.CacheDirectives._ import akka.http.scaladsl.model.headers._ -import akka.http.scaladsl.model.{DateTime, FormData} +import akka.http.scaladsl.model.{DateTime, FormData, StatusCodes} import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server.Route import akka.http.scaladsl.server.directives.Credentials @@ -67,19 +67,17 @@ class BasicTests } } ~ get { parameterMap { params => - complete( - List("GET", "/echo", paramsToString(params)) - .filter(_.nonEmpty) - .mkString(" ")) + complete(List("GET", "/echo", paramsToString(params)) + .filter(_.nonEmpty) + .mkString(" ")) } } ~ post { parameterMap { params => entity(as[String]) { body: String => - complete( - List("POST", "/echo", paramsToString(params), body) - .filter(_.nonEmpty) - .mkString(" ")) + complete(List("POST", "/echo", paramsToString(params), body) + .filter(_.nonEmpty) + .mkString(" ")) } } } @@ -149,6 +147,16 @@ class BasicTests .map(v => v.mkString(", ")) } } + } ~ pathPrefix("redirect") { + path("r1") { + redirect("/redirect/r2", StatusCodes.TemporaryRedirect) + } ~ + path("r2") { + redirect("/redirect/r3", StatusCodes.PermanentRedirect) + } ~ + path("r3") { + complete("ok") + } } override def port = 51823 @@ -196,6 +204,7 @@ class BasicTests compressionTests() downloadFileTests() multipartTests() + redirectTests() def parseResponseTests(): Unit = { name should "parse response as string" in { @@ -554,6 +563,35 @@ class BasicTests } finally f.delete() } } + + def redirectTests(): Unit = { + val r1 = sttp.post(uri"$endpoint/redirect/r1") + val r2 = sttp.post(uri"$endpoint/redirect/r2") + + name should "not redirect when redirects shouldn't be followed (temporary)" in { + val resp = r1.followRedirects(false).send().force() + resp.code should be(307) + resp.body should not be ("ok") + } + + name should "not redirect when redirects shouldn't be followed (permanent)" in { + val resp = r2.followRedirects(false).send().force() + resp.code should be(308) + resp.body should not be ("ok") + } + + name should "redirect when redirects should be followed" in { + val resp = r2.send().force() + resp.code should be(200) + resp.body should be("ok") + } + + name should "redirect twice when redirects should be followed" in { + val resp = r1.send().force() + resp.code should be(200) + resp.body should be("ok") + } + } } override protected def afterAll(): Unit = { |