aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala2
-rw-r--r--async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala2
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala5
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/RequestT.scala8
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala36
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/package.scala8
-rw-r--r--okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala4
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala56
8 files changed, 104 insertions, 17 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala
index 2f7184a..5691d3c 100644
--- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala
+++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpHandler.scala
@@ -33,7 +33,7 @@ class AkkaHttpHandler private (actorSystem: ActorSystem,
private implicit val as = actorSystem
private implicit val materializer = ActorMaterializer()
- override def send[T](r: Request[T, S]): Future[Response[T]] = {
+ override protected def doSend[T](r: Request[T, S]): Future[Response[T]] = {
implicit val ec = this.ec
requestToAkka(r)
.flatMap(setBodyOnAkka(r, r.body, _))
diff --git a/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala b/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala
index 9c4704b..2e5d16b 100644
--- a/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala
+++ b/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala
@@ -36,7 +36,7 @@ abstract class AsyncHttpClientHandler[R[_], S](asyncHttpClient: AsyncHttpClient,
closeClient: Boolean)
extends SttpHandler[R, S] {
- override def send[T](r: Request[T, S]): R[Response[T]] = {
+ override protected def doSend[T](r: Request[T, S]): R[Response[T]] = {
val preparedRequest = asyncHttpClient
.prepareRequest(requestToAsync(r))
diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala
index 76d897b..548dd9b 100644
--- a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionHandler.scala
@@ -13,13 +13,16 @@ import scala.io.Source
import scala.collection.JavaConverters._
object HttpURLConnectionHandler extends SttpHandler[Id, Nothing] {
- override def send[T](r: Request[T, Nothing]): Response[T] = {
+ override protected def doSend[T](r: Request[T, Nothing]): Response[T] = {
val c =
new URL(r.uri.toString).openConnection().asInstanceOf[HttpURLConnection]
c.setRequestMethod(r.method.m)
r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) }
c.setDoInput(true)
+ // redirects are handled in SttpHandler
+ c.setInstanceFollowRedirects(false)
+
if (r.body != NoBody) {
c.setDoOutput(true)
// we need to take care to:
diff --git a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala
index 5e4d33e..95dc56b 100644
--- a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala
@@ -31,7 +31,8 @@ case class RequestT[U[_], T, +S](
uri: U[Uri],
body: RequestBody[S],
headers: Seq[(String, String)],
- response: ResponseAs[T, S]
+ response: ResponseAs[T, S],
+ options: RequestOptions
) {
def get(uri: Uri): Request[T, S] =
this.copy[Id, T, S](uri = uri, method = Method.GET)
@@ -218,6 +219,9 @@ case class RequestT[U[_], T, +S](
def mapResponse[T2](f: T => T2): RequestT[U, T2, S] =
this.copy(response = response.map(f))
+ def followRedirects(fr: Boolean): RequestT[U, T, S] =
+ this.copy(options = options.copy(followRedirects = fr))
+
def send[R[_]]()(implicit handler: SttpHandler[R, S],
isIdInRequest: IsIdInRequest[U]): R[Response[T]] = {
// we could avoid the asInstanceOf by creating an artificial copy
@@ -268,3 +272,5 @@ class SpecifyAuthScheme[U[_], T, +S](hn: String, rt: RequestT[U, T, S]) {
def bearer(token: String): RequestT[U, T, S] =
rt.header(hn, s"Bearer $token")
}
+
+case class RequestOptions(followRedirects: Boolean)
diff --git a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala
index c6df151..fd836bd 100644
--- a/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/SttpHandler.scala
@@ -1,5 +1,7 @@
package com.softwaremill.sttp
+import java.net.URI
+
import scala.language.higherKinds
/**
@@ -9,9 +11,41 @@ import scala.language.higherKinds
* if streaming requests/responses is not supported by this handler.
*/
trait SttpHandler[R[_], -S] {
- def send[T](request: Request[T, S]): R[Response[T]]
+ def send[T](request: Request[T, S]): R[Response[T]] = {
+ val resp = doSend(request)
+ if (request.options.followRedirects) {
+ responseMonad.flatMap(resp, { response: Response[T] =>
+ if (response.isRedirect) {
+ followRedirect(request, response)
+ } else {
+ responseMonad.unit(response)
+ }
+ })
+ } else {
+ resp
+ }
+ }
+
+ private def followRedirect[T](request: Request[T, S],
+ response: Response[T]): R[Response[T]] = {
+ def isRelative(uri: String) = !uri.contains("://")
+
+ response.header(LocationHeader).fold(responseMonad.unit(response)) { loc =>
+ val uri = if (isRelative(loc)) {
+ // using java's URI to resolve a relative URI
+ uri"${new URI(request.uri.toString).resolve(loc).toString}"
+ } else {
+ uri"$loc"
+ }
+
+ send(request.copy[Id, T, S](uri = uri))
+ }
+ }
+
def close(): Unit = {}
+ protected def doSend[T](request: Request[T, S]): R[Response[T]]
+
/**
* The monad in which the responses are wrapped. Allows writing wrapper
* handlers, which map/flatMap over the return value of [[send]].
diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala
index a0a6c57..3c4e844 100644
--- a/core/src/main/scala/com/softwaremill/sttp/package.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/package.scala
@@ -37,6 +37,7 @@ package object sttp {
private[sttp] val AcceptEncodingHeader = "Accept-Encoding"
private[sttp] val ContentEncodingHeader = "Content-Encoding"
private[sttp] val ContentDispositionHeader = "Content-Disposition"
+ private[sttp] val LocationHeader = "Location"
private[sttp] val Utf8 = "utf-8"
private[sttp] val Iso88591 = "iso-8859-1"
private[sttp] val CrLf = "\r\n"
@@ -54,7 +55,12 @@ package object sttp {
* An empty request with no headers.
*/
val emptyRequest: RequestT[Empty, String, Nothing] =
- RequestT[Empty, String, Nothing](None, None, NoBody, Vector(), asString)
+ RequestT[Empty, String, Nothing](None,
+ None,
+ NoBody,
+ Vector(),
+ asString,
+ RequestOptions(followRedirects = true))
/**
* A starting request, with the following modifications comparing to
diff --git a/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala b/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
index 9e7669f..b95bf8a 100644
--- a/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
+++ b/okhttp-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
@@ -130,7 +130,7 @@ abstract class OkHttpHandler[R[_], S](client: OkHttpClient)
class OkHttpSyncHandler private (client: OkHttpClient)
extends OkHttpHandler[Id, Nothing](client) {
- override def send[T](r: Request[T, Nothing]): Response[T] = {
+ override protected def doSend[T](r: Request[T, Nothing]): Response[T] = {
val request = convertRequest(r)
val response = client.newCall(request).execute()
readResponse(response, r.response)
@@ -148,7 +148,7 @@ object OkHttpSyncHandler {
abstract class OkHttpAsyncHandler[R[_], S](client: OkHttpClient,
rm: MonadAsyncError[R])
extends OkHttpHandler[R, S](client) {
- override def send[T](r: Request[T, S]): R[Response[T]] = {
+ override protected def doSend[T](r: Request[T, S]): R[Response[T]] = {
val request = convertRequest(r)
rm.flatten(rm.async[R[Response[T]]] { cb =>
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index 2275c5e..61b9c19 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -8,7 +8,7 @@ import java.time.{ZoneId, ZonedDateTime}
import akka.http.scaladsl.coding.{Deflate, Gzip, NoCoding}
import akka.http.scaladsl.model.headers.CacheDirectives._
import akka.http.scaladsl.model.headers._
-import akka.http.scaladsl.model.{DateTime, FormData}
+import akka.http.scaladsl.model.{DateTime, FormData, StatusCodes}
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.server.directives.Credentials
@@ -67,19 +67,17 @@ class BasicTests
}
} ~ get {
parameterMap { params =>
- complete(
- List("GET", "/echo", paramsToString(params))
- .filter(_.nonEmpty)
- .mkString(" "))
+ complete(List("GET", "/echo", paramsToString(params))
+ .filter(_.nonEmpty)
+ .mkString(" "))
}
} ~
post {
parameterMap { params =>
entity(as[String]) { body: String =>
- complete(
- List("POST", "/echo", paramsToString(params), body)
- .filter(_.nonEmpty)
- .mkString(" "))
+ complete(List("POST", "/echo", paramsToString(params), body)
+ .filter(_.nonEmpty)
+ .mkString(" "))
}
}
}
@@ -149,6 +147,16 @@ class BasicTests
.map(v => v.mkString(", "))
}
}
+ } ~ pathPrefix("redirect") {
+ path("r1") {
+ redirect("/redirect/r2", StatusCodes.TemporaryRedirect)
+ } ~
+ path("r2") {
+ redirect("/redirect/r3", StatusCodes.PermanentRedirect)
+ } ~
+ path("r3") {
+ complete("ok")
+ }
}
override def port = 51823
@@ -196,6 +204,7 @@ class BasicTests
compressionTests()
downloadFileTests()
multipartTests()
+ redirectTests()
def parseResponseTests(): Unit = {
name should "parse response as string" in {
@@ -554,6 +563,35 @@ class BasicTests
} finally f.delete()
}
}
+
+ def redirectTests(): Unit = {
+ val r1 = sttp.post(uri"$endpoint/redirect/r1")
+ val r2 = sttp.post(uri"$endpoint/redirect/r2")
+
+ name should "not redirect when redirects shouldn't be followed (temporary)" in {
+ val resp = r1.followRedirects(false).send().force()
+ resp.code should be(307)
+ resp.body should not be ("ok")
+ }
+
+ name should "not redirect when redirects shouldn't be followed (permanent)" in {
+ val resp = r2.followRedirects(false).send().force()
+ resp.code should be(308)
+ resp.body should not be ("ok")
+ }
+
+ name should "redirect when redirects should be followed" in {
+ val resp = r2.send().force()
+ resp.code should be(200)
+ resp.body should be("ok")
+ }
+
+ name should "redirect twice when redirects should be followed" in {
+ val resp = r1.send().force()
+ resp.code should be(200)
+ resp.body should be("ok")
+ }
+ }
}
override protected def afterAll(): Unit = {