aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoradamw <adam@warski.org>2017-07-15 10:58:24 +0200
committeradamw <adam@warski.org>2017-07-15 10:58:24 +0200
commitbc685df2cd50814b45e669f4f602732887c2879c (patch)
tree4df984768865b86b48fbaae7c21f22fd0c3ea079
parentfdc9b3f9420165cc65c8dd9fe20057a4a12e69c6 (diff)
downloadsttp-bc685df2cd50814b45e669f4f602732887c2879c.tar.gz
sttp-bc685df2cd50814b45e669f4f602732887c2879c.tar.bz2
sttp-bc685df2cd50814b45e669f4f602732887c2879c.zip
Headers & errors support
-rw-r--r--akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala12
-rw-r--r--build.sbt11
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala32
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/Response.scala20
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/package.scala6
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala48
6 files changed, 102 insertions, 27 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
index fc2d632..9125ca3 100644
--- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
+++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
@@ -14,6 +14,7 @@ import com.softwaremill.sttp.model._
import scala.concurrent.Future
import scala.util.{Failure, Success, Try}
+import scala.collection.immutable.Seq
class AkkaHttpSttpHandler(actorSystem: ActorSystem)
extends SttpHandler[Future, Source[ByteString, Any]] {
@@ -32,7 +33,8 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem)
.flatMap(Http().singleRequest(_))
.flatMap { hr =>
val code = hr.status.intValue()
- bodyFromAkka(responseAs, hr).map(Response(code, _))
+ bodyFromAkka(responseAs, hr).map(
+ Response(_, code, headersFromAkka(hr)))
}
}
@@ -72,6 +74,14 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem)
}
}
+ private def headersFromAkka(hr: HttpResponse): Seq[(String, String)] = {
+ val ch = ContentTypeHeader -> hr.entity.contentType.toString()
+ val cl =
+ hr.entity.contentLengthOption.map(ContentLengthHeader -> _.toString)
+ val other = hr.headers.map(h => (h.name, h.value))
+ ch :: (cl.toList ++ other)
+ }
+
private def requestToAkka(r: Request): Future[HttpRequest] = {
val ar = HttpRequest(uri = r.uri.toString, method = methodToAkka(r.method))
val parsed =
diff --git a/build.sbt b/build.sbt
index 3aea1fe..d96c7d8 100644
--- a/build.sbt
+++ b/build.sbt
@@ -43,7 +43,7 @@ lazy val commonSettings = Seq(
val akkaHttpVersion = "10.0.9"
val akkaHttp = "com.typesafe.akka" %% "akka-http" % akkaHttpVersion
-val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3" % "test"
+val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3"
lazy val rootProject = (project in file("."))
.settings(commonSettings: _*)
@@ -56,7 +56,7 @@ lazy val core: Project = (project in file("core"))
name := "core",
libraryDependencies ++= Seq(
"org.scalacheck" %% "scalacheck" % "1.13.5" % "test",
- scalaTest
+ scalaTest % "test"
)
)
@@ -76,7 +76,8 @@ lazy val tests: Project = (project in file("tests"))
libraryDependencies ++= Seq(
akkaHttp,
scalaTest,
- "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0" % "test",
- "com.github.pathikrit" %% "better-files" % "3.0.0"
- )
+ "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",
+ "com.github.pathikrit" %% "better-files" % "3.0.0",
+ "ch.qos.logback" % "logback-core" % "1.2.3"
+ ).map(_ % "test")
) dependsOn (core, akkaHttpHandler)
diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala
index fdbb322..07025d5 100644
--- a/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala
@@ -1,11 +1,6 @@
package com.softwaremill.sttp
-import java.io.{
- ByteArrayOutputStream,
- InputStream,
- OutputStream,
- OutputStreamWriter
-}
+import java.io._
import java.net.HttpURLConnection
import java.nio.channels.Channels
import java.nio.file.Files
@@ -14,6 +9,7 @@ import com.softwaremill.sttp.model._
import scala.annotation.tailrec
import scala.io.Source
+import scala.collection.JavaConverters._
object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] {
override def send[T](r: Request,
@@ -24,8 +20,13 @@ object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] {
c.setDoInput(true)
setBody(r.body, c)
- val status = c.getResponseCode
- Response(status, readResponse(c.getInputStream, responseAs))
+ try {
+ val is = c.getInputStream
+ readResponse(c, is, responseAs)
+ } catch {
+ case _: IOException if c.getResponseCode != -1 =>
+ readResponse(c, c.getErrorStream, responseAs)
+ }
}
private def setBody(body: RequestBody, c: HttpURLConnection): Unit = {
@@ -76,8 +77,19 @@ object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] {
}
}
- private def readResponse[T](is: InputStream,
- responseAs: ResponseAs[T, Nothing]): T =
+ private def readResponse[T](
+ c: HttpURLConnection,
+ is: InputStream,
+ responseAs: ResponseAs[T, Nothing]): Response[T] = {
+
+ val headers = c.getHeaderFields.asScala.toVector
+ .filter(_._1 != null)
+ .flatMap { case (k, vv) => vv.asScala.map((k, _)) }
+ Response(readResponseBody(is, responseAs), c.getResponseCode, headers)
+ }
+
+ private def readResponseBody[T](is: InputStream,
+ responseAs: ResponseAs[T, Nothing]): T =
responseAs match {
case IgnoreResponse =>
@tailrec def consume(): Unit = if (is.read() != -1) consume()
diff --git a/core/src/main/scala/com/softwaremill/sttp/Response.scala b/core/src/main/scala/com/softwaremill/sttp/Response.scala
index 4b90259..dbaf71f 100644
--- a/core/src/main/scala/com/softwaremill/sttp/Response.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/Response.scala
@@ -1,3 +1,21 @@
package com.softwaremill.sttp
-case class Response[T](status: Int, body: T)
+import scala.collection.immutable.Seq
+import scala.util.Try
+
+case class Response[T](body: T, code: Int, headers: Seq[(String, String)]) {
+ def is200: Boolean = code == 200
+ def isSuccess: Boolean = code >= 200 && code < 300
+ def isRedirect: Boolean = code >= 300 && code < 400
+ def isClientError: Boolean = code >= 400 && code < 500
+ def isServerError: Boolean = code >= 500 && code < 600
+
+ def header(h: String): Option[String] =
+ headers.find(_._1.equalsIgnoreCase(h)).map(_._2)
+ def headers(h: String): Seq[String] =
+ headers.filter(_._1.equalsIgnoreCase(h)).map(_._2)
+
+ def contentType: Option[String] = header(ContentTypeHeader)
+ def contentLength: Option[Long] =
+ header(ContentLengthHeader).flatMap(cl => Try(cl.toLong).toOption)
+}
diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala
index 79d9a17..c39f256 100644
--- a/core/src/main/scala/com/softwaremill/sttp/package.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/package.scala
@@ -130,10 +130,9 @@ package object sttp {
def header(k: String,
v: String,
replaceExisting: Boolean = false): RequestTemplate[U] = {
- val kLower = k.toLowerCase
val current =
if (replaceExisting)
- headers.filterNot(_._1.toLowerCase.contains(kLower))
+ headers.filterNot(_._1.equalsIgnoreCase(k))
else headers
this.copy(headers = current :+ (k -> v))
}
@@ -232,7 +231,8 @@ package object sttp {
val sttp: RequestTemplate[Empty] = RequestTemplate.empty
- private val ContentTypeHeader = "content-type"
+ private[sttp] val ContentTypeHeader = "content-type"
+ private[sttp] val ContentLengthHeader = "content-length"
private val Utf8 = "utf-8"
private val ApplicationOctetStreamContentType = "application/octet-stream"
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index 39a4b25..86d28a9 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -2,11 +2,12 @@ package com.softwaremill.sttp
import java.io.ByteArrayInputStream
import java.net.URI
-import java.nio.ByteBuffer
import akka.stream.ActorMaterializer
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
+import akka.http.scaladsl.model.headers._
+import akka.http.scaladsl.model.headers.CacheDirectives._
import akka.http.scaladsl.server.Directives._
import com.softwaremill.sttp.akkahttp.AkkaHttpSttpHandler
import com.typesafe.scalalogging.StrictLogging
@@ -47,6 +48,14 @@ class BasicTests
}
}
}
+ } ~ path("set_headers") {
+ get {
+ respondWithHeader(`Cache-Control`(`max-age`(1000L))) {
+ respondWithHeader(`Cache-Control`(`no-cache`)) {
+ complete("ok")
+ }
+ }
+ }
}
private implicit val actorSystem: ActorSystem = ActorSystem("sttp-test")
@@ -91,6 +100,8 @@ class BasicTests
parseResponseTests()
parameterTests()
bodyTests()
+ headerTests()
+ errorsTests()
def parseResponseTests(): Unit = {
name should "parse response as string" in {
@@ -138,12 +149,7 @@ class BasicTests
fc should be(expectedPostEchoResponse)
}
- name should "post a byte buffer" in {
- val response =
- postEcho.body(ByteBuffer.wrap(testBodyBytes)).send(responseAsString)
- val fc = forceResponse.force(response).body
- fc should be(expectedPostEchoResponse)
- }
+ name should "post a byte buffer" in {}
name should "post a file" in {
val f = File.newTemporaryFile().write(testBody)
@@ -163,5 +169,33 @@ class BasicTests
} finally f.delete()
}
}
+
+ def headerTests(): Unit = {
+ val getHeaders = sttp.get(new URI(endpoint + "/set_headers"))
+
+ name should "read response headers" in {
+ val wrappedResponse = getHeaders.send(ignoreResponse)
+ val response = forceResponse.force(wrappedResponse)
+ response.headers should have length (6)
+ response.headers("Cache-Control").toSet should be(
+ Set("no-cache", "max-age=1000"))
+ response.header("Server") should be('defined)
+ response.header("server") should be('defined)
+ response.header("Server").get should startWith("akka-http")
+ response.contentType should be(Some("text/plain; charset=UTF-8"))
+ response.contentLength should be(Some(2L))
+ }
+ }
+
+ def errorsTests(): Unit = {
+ val getHeaders = sttp.post(new URI(endpoint + "/set_headers"))
+
+ name should "return 405 when method not allowed" in {
+ val response = getHeaders.send(ignoreResponse)
+ val resp = forceResponse.force(response)
+ resp.code should be(405)
+ resp.isClientError should be(true)
+ }
+ }
}
}