diff options
author | adamw <adam@warski.org> | 2017-07-15 10:58:24 +0200 |
---|---|---|
committer | adamw <adam@warski.org> | 2017-07-15 10:58:24 +0200 |
commit | bc685df2cd50814b45e669f4f602732887c2879c (patch) | |
tree | 4df984768865b86b48fbaae7c21f22fd0c3ea079 | |
parent | fdc9b3f9420165cc65c8dd9fe20057a4a12e69c6 (diff) | |
download | sttp-bc685df2cd50814b45e669f4f602732887c2879c.tar.gz sttp-bc685df2cd50814b45e669f4f602732887c2879c.tar.bz2 sttp-bc685df2cd50814b45e669f4f602732887c2879c.zip |
Headers & errors support
6 files changed, 102 insertions, 27 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala index fc2d632..9125ca3 100644 --- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala +++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala @@ -14,6 +14,7 @@ import com.softwaremill.sttp.model._ import scala.concurrent.Future import scala.util.{Failure, Success, Try} +import scala.collection.immutable.Seq class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future, Source[ByteString, Any]] { @@ -32,7 +33,8 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) .flatMap(Http().singleRequest(_)) .flatMap { hr => val code = hr.status.intValue() - bodyFromAkka(responseAs, hr).map(Response(code, _)) + bodyFromAkka(responseAs, hr).map( + Response(_, code, headersFromAkka(hr))) } } @@ -72,6 +74,14 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) } } + private def headersFromAkka(hr: HttpResponse): Seq[(String, String)] = { + val ch = ContentTypeHeader -> hr.entity.contentType.toString() + val cl = + hr.entity.contentLengthOption.map(ContentLengthHeader -> _.toString) + val other = hr.headers.map(h => (h.name, h.value)) + ch :: (cl.toList ++ other) + } + private def requestToAkka(r: Request): Future[HttpRequest] = { val ar = HttpRequest(uri = r.uri.toString, method = methodToAkka(r.method)) val parsed = @@ -43,7 +43,7 @@ lazy val commonSettings = Seq( val akkaHttpVersion = "10.0.9" val akkaHttp = "com.typesafe.akka" %% "akka-http" % akkaHttpVersion -val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3" % "test" +val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3" lazy val rootProject = (project in file(".")) .settings(commonSettings: _*) @@ -56,7 +56,7 @@ lazy val core: Project = (project in file("core")) name := "core", libraryDependencies ++= Seq( "org.scalacheck" %% "scalacheck" % "1.13.5" % "test", - scalaTest + scalaTest % "test" ) ) @@ -76,7 +76,8 @@ lazy val tests: Project = (project in file("tests")) libraryDependencies ++= Seq( akkaHttp, scalaTest, - "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0" % "test", - "com.github.pathikrit" %% "better-files" % "3.0.0" - ) + "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0", + "com.github.pathikrit" %% "better-files" % "3.0.0", + "ch.qos.logback" % "logback-core" % "1.2.3" + ).map(_ % "test") ) dependsOn (core, akkaHttpHandler) diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala index fdbb322..07025d5 100644 --- a/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala @@ -1,11 +1,6 @@ package com.softwaremill.sttp -import java.io.{ - ByteArrayOutputStream, - InputStream, - OutputStream, - OutputStreamWriter -} +import java.io._ import java.net.HttpURLConnection import java.nio.channels.Channels import java.nio.file.Files @@ -14,6 +9,7 @@ import com.softwaremill.sttp.model._ import scala.annotation.tailrec import scala.io.Source +import scala.collection.JavaConverters._ object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] { override def send[T](r: Request, @@ -24,8 +20,13 @@ object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] { c.setDoInput(true) setBody(r.body, c) - val status = c.getResponseCode - Response(status, readResponse(c.getInputStream, responseAs)) + try { + val is = c.getInputStream + readResponse(c, is, responseAs) + } catch { + case _: IOException if c.getResponseCode != -1 => + readResponse(c, c.getErrorStream, responseAs) + } } private def setBody(body: RequestBody, c: HttpURLConnection): Unit = { @@ -76,8 +77,19 @@ object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] { } } - private def readResponse[T](is: InputStream, - responseAs: ResponseAs[T, Nothing]): T = + private def readResponse[T]( + c: HttpURLConnection, + is: InputStream, + responseAs: ResponseAs[T, Nothing]): Response[T] = { + + val headers = c.getHeaderFields.asScala.toVector + .filter(_._1 != null) + .flatMap { case (k, vv) => vv.asScala.map((k, _)) } + Response(readResponseBody(is, responseAs), c.getResponseCode, headers) + } + + private def readResponseBody[T](is: InputStream, + responseAs: ResponseAs[T, Nothing]): T = responseAs match { case IgnoreResponse => @tailrec def consume(): Unit = if (is.read() != -1) consume() diff --git a/core/src/main/scala/com/softwaremill/sttp/Response.scala b/core/src/main/scala/com/softwaremill/sttp/Response.scala index 4b90259..dbaf71f 100644 --- a/core/src/main/scala/com/softwaremill/sttp/Response.scala +++ b/core/src/main/scala/com/softwaremill/sttp/Response.scala @@ -1,3 +1,21 @@ package com.softwaremill.sttp -case class Response[T](status: Int, body: T) +import scala.collection.immutable.Seq +import scala.util.Try + +case class Response[T](body: T, code: Int, headers: Seq[(String, String)]) { + def is200: Boolean = code == 200 + def isSuccess: Boolean = code >= 200 && code < 300 + def isRedirect: Boolean = code >= 300 && code < 400 + def isClientError: Boolean = code >= 400 && code < 500 + def isServerError: Boolean = code >= 500 && code < 600 + + def header(h: String): Option[String] = + headers.find(_._1.equalsIgnoreCase(h)).map(_._2) + def headers(h: String): Seq[String] = + headers.filter(_._1.equalsIgnoreCase(h)).map(_._2) + + def contentType: Option[String] = header(ContentTypeHeader) + def contentLength: Option[Long] = + header(ContentLengthHeader).flatMap(cl => Try(cl.toLong).toOption) +} diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala index 79d9a17..c39f256 100644 --- a/core/src/main/scala/com/softwaremill/sttp/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/package.scala @@ -130,10 +130,9 @@ package object sttp { def header(k: String, v: String, replaceExisting: Boolean = false): RequestTemplate[U] = { - val kLower = k.toLowerCase val current = if (replaceExisting) - headers.filterNot(_._1.toLowerCase.contains(kLower)) + headers.filterNot(_._1.equalsIgnoreCase(k)) else headers this.copy(headers = current :+ (k -> v)) } @@ -232,7 +231,8 @@ package object sttp { val sttp: RequestTemplate[Empty] = RequestTemplate.empty - private val ContentTypeHeader = "content-type" + private[sttp] val ContentTypeHeader = "content-type" + private[sttp] val ContentLengthHeader = "content-length" private val Utf8 = "utf-8" private val ApplicationOctetStreamContentType = "application/octet-stream" diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index 39a4b25..86d28a9 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -2,11 +2,12 @@ package com.softwaremill.sttp import java.io.ByteArrayInputStream import java.net.URI -import java.nio.ByteBuffer import akka.stream.ActorMaterializer import akka.actor.ActorSystem import akka.http.scaladsl.Http +import akka.http.scaladsl.model.headers._ +import akka.http.scaladsl.model.headers.CacheDirectives._ import akka.http.scaladsl.server.Directives._ import com.softwaremill.sttp.akkahttp.AkkaHttpSttpHandler import com.typesafe.scalalogging.StrictLogging @@ -47,6 +48,14 @@ class BasicTests } } } + } ~ path("set_headers") { + get { + respondWithHeader(`Cache-Control`(`max-age`(1000L))) { + respondWithHeader(`Cache-Control`(`no-cache`)) { + complete("ok") + } + } + } } private implicit val actorSystem: ActorSystem = ActorSystem("sttp-test") @@ -91,6 +100,8 @@ class BasicTests parseResponseTests() parameterTests() bodyTests() + headerTests() + errorsTests() def parseResponseTests(): Unit = { name should "parse response as string" in { @@ -138,12 +149,7 @@ class BasicTests fc should be(expectedPostEchoResponse) } - name should "post a byte buffer" in { - val response = - postEcho.body(ByteBuffer.wrap(testBodyBytes)).send(responseAsString) - val fc = forceResponse.force(response).body - fc should be(expectedPostEchoResponse) - } + name should "post a byte buffer" in {} name should "post a file" in { val f = File.newTemporaryFile().write(testBody) @@ -163,5 +169,33 @@ class BasicTests } finally f.delete() } } + + def headerTests(): Unit = { + val getHeaders = sttp.get(new URI(endpoint + "/set_headers")) + + name should "read response headers" in { + val wrappedResponse = getHeaders.send(ignoreResponse) + val response = forceResponse.force(wrappedResponse) + response.headers should have length (6) + response.headers("Cache-Control").toSet should be( + Set("no-cache", "max-age=1000")) + response.header("Server") should be('defined) + response.header("server") should be('defined) + response.header("Server").get should startWith("akka-http") + response.contentType should be(Some("text/plain; charset=UTF-8")) + response.contentLength should be(Some(2L)) + } + } + + def errorsTests(): Unit = { + val getHeaders = sttp.post(new URI(endpoint + "/set_headers")) + + name should "return 405 when method not allowed" in { + val response = getHeaders.send(ignoreResponse) + val resp = forceResponse.force(response) + resp.code should be(405) + resp.isClientError should be(true) + } + } } } |