From e60dab5c77bca2c3b7a76732e98b406a9d7a095e Mon Sep 17 00:00:00 2001 From: Paweł Stawicki Date: Wed, 24 Jan 2018 17:33:01 +0100 Subject: Allow SttpBackendStub to accept monad with response, not to hold thread when using e.g. Futures --- .../sttp/testing/SttpBackendStub.scala | 41 ++++++++++++++-------- .../sttp/testing/SttpBackendStubTests.scala | 28 +++++++++++++-- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/com/softwaremill/sttp/testing/SttpBackendStub.scala b/core/src/main/scala/com/softwaremill/sttp/testing/SttpBackendStub.scala index 9b4e93e..9e52505 100644 --- a/core/src/main/scala/com/softwaremill/sttp/testing/SttpBackendStub.scala +++ b/core/src/main/scala/com/softwaremill/sttp/testing/SttpBackendStub.scala @@ -28,7 +28,7 @@ import scala.util.{Failure, Success, Try} */ class SttpBackendStub[R[_], S] private ( rm: MonadError[R], - matchers: PartialFunction[Request[_, _], Response[_]], + matchers: PartialFunction[Request[_, _], R[Response[_]]], fallback: Option[SttpBackend[R, S]]) extends SttpBackend[R, S] { @@ -60,13 +60,14 @@ class SttpBackendStub[R[_], S] private ( def whenRequestMatchesPartial( partial: PartialFunction[Request[_, _], Response[_]]) : SttpBackendStub[R, S] = { - new SttpBackendStub(rm, matchers.orElse(partial), fallback) + val wrappedPartial = partial.andThen(rm.unit) + new SttpBackendStub(rm, matchers.orElse(wrappedPartial), fallback) } override def send[T](request: Request[T, S]): R[Response[T]] = { Try(matchers.lift(request)) match { - case Success(Some(response)) => - wrapResponse(tryAdjustResponseType(request.response, response)) + case Success(Some(responseMonad)) => + tryAdjustResponseType(rm, request.response, wrapResponse(responseMonad)) case Success(None) => fallback match { case None => @@ -85,6 +86,9 @@ class SttpBackendStub[R[_], S] private ( private def wrapResponse[T](r: Response[_]): R[Response[T]] = rm.unit(r.asInstanceOf[Response[T]]) + private def wrapResponse[T](r: R[Response[_]]): R[Response[T]] = + rm.map(r)(_.asInstanceOf[Response[T]]) + override def close(): Unit = {} override def responseMonad: MonadError[R] = rm @@ -99,12 +103,13 @@ class SttpBackendStub[R[_], S] private ( def thenRespondWithCode(code: Int, msg: String = ""): SttpBackendStub[R, S] = { val body = if (code >= 200 && code < 300) Right(msg) else Left(msg) - thenRespond(Response(body, code, msg, Nil, Nil)) + thenRespondWithMonad(rm.unit(Response(body, code, msg, Nil, Nil))) } def thenRespond[T](body: T): SttpBackendStub[R, S] = - thenRespond(Response[T](Right(body), 200, "OK", Nil, Nil)) - def thenRespond[T](resp: => Response[T]): SttpBackendStub[R, S] = { - val m: PartialFunction[Request[_, _], Response[_]] = { + thenRespondWithMonad( + rm.unit(Response[T](Right(body), 200, "OK", Nil, Nil))) + def thenRespondWithMonad(resp: => R[Response[_]]): SttpBackendStub[R, S] = { + val m: PartialFunction[Request[_, _], R[Response[_]]] = { case r if p(r) => resp } new SttpBackendStub(rm, matchers.orElse(m), fallback) @@ -159,13 +164,19 @@ object SttpBackendStub { PartialFunction.empty, Some(fallback)) - private[sttp] def tryAdjustResponseType[T, U](ra: ResponseAs[T, _], - r: Response[U]): Response[_] = { - r.body match { - case Left(_) => r - case Right(body) => - val newBody: Any = tryAdjustResponseBody(ra, body).getOrElse(body) - r.copy(body = Right(newBody)) + private[sttp] def tryAdjustResponseType[DesiredRType, RType, M[_]]( + rm: MonadError[M], + ra: ResponseAs[DesiredRType, _], + m: M[Response[RType]]): M[Response[DesiredRType]] = { + rm.map[Response[RType], Response[DesiredRType]](m) { r => + r.body match { + case Left(_) => r.asInstanceOf[Response[DesiredRType]] + case Right(body) => + val newBody: Any = tryAdjustResponseBody(ra, body).getOrElse(body) + r.copy( + body = + Right[String, DesiredRType](newBody.asInstanceOf[DesiredRType])) + } } } diff --git a/core/src/test/scala/com/softwaremill/sttp/testing/SttpBackendStubTests.scala b/core/src/test/scala/com/softwaremill/sttp/testing/SttpBackendStubTests.scala index f3d2cd4..22ceff3 100644 --- a/core/src/test/scala/com/softwaremill/sttp/testing/SttpBackendStubTests.scala +++ b/core/src/test/scala/com/softwaremill/sttp/testing/SttpBackendStubTests.scala @@ -8,6 +8,8 @@ import com.softwaremill.sttp._ import org.scalatest.concurrent.ScalaFutures import org.scalatest.{FlatSpec, Matchers} +import scala.concurrent.Future + class SttpBackendStubTests extends FlatSpec with Matchers with ScalaFutures { private val testingStub = SttpBackendStub(HttpURLConnectionBackend()) .whenRequestMatches(_.uri.path.startsWith(List("a", "b"))) @@ -77,7 +79,7 @@ class SttpBackendStubTests extends FlatSpec with Matchers with ScalaFutures { it should "handle exceptions thrown instead of a response (synchronous)" in { implicit val s = SttpBackendStub(HttpURLConnectionBackend()) .whenRequestMatches(_ => true) - .thenRespond(throw new TimeoutException()) + .thenRespondWithMonad(throw new TimeoutException()) a[TimeoutException] should be thrownBy { sttp.get(uri"http://example.org").send() @@ -87,7 +89,7 @@ class SttpBackendStubTests extends FlatSpec with Matchers with ScalaFutures { it should "handle exceptions thrown instead of a response (asynchronous)" in { implicit val s = SttpBackendStub(new FutureMonad()) .whenRequestMatches(_ => true) - .thenRespond(throw new TimeoutException()) + .thenRespondWithMonad(throw new TimeoutException()) val result = sttp.get(uri"http://example.org").send() result.failed.futureValue shouldBe a[TimeoutException] @@ -155,6 +157,28 @@ class SttpBackendStubTests extends FlatSpec with Matchers with ScalaFutures { } + it should "not hold the calling thread when passed a future monad" in { + val LongTimeMillis = 10000L + + val fm = new FutureMonad() + val f = Future { + Thread.sleep(LongTimeMillis) + Response(Right("OK"), 200, "", Nil, Nil) + } + + val before = System.currentTimeMillis() + implicit val s = SttpBackendStub(fm).whenAnyRequest + .thenRespondWithMonad(f) + + val result = sttp + .get(uri"http://example.org") + .send() + + val after = System.currentTimeMillis() + + (after - before) should be < LongTimeMillis + } + private val testingStubWithFallback = SttpBackendStub .withFallback(testingStub) .whenRequestMatches(_.uri.path.startsWith(List("c"))) -- cgit v1.2.3