aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala54
-rw-r--r--build.sbt3
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala54
3 files changed, 86 insertions, 25 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
index 78d6637..df40fc9 100644
--- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
+++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
@@ -7,12 +7,13 @@ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.`Content-Type`
import akka.http.scaladsl.model.ContentTypes.`application/octet-stream`
import akka.stream.ActorMaterializer
-import akka.stream.scaladsl.Source
+import akka.stream.scaladsl.{Source, StreamConverters}
import akka.util.ByteString
import com.softwaremill.sttp._
import com.softwaremill.sttp.model._
import scala.concurrent.Future
+import scala.util.{Failure, Success, Try}
class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future, Source[ByteString, Any]] {
@@ -23,7 +24,7 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future,
import as.dispatcher
override def send[T](r: Request, responseAs: ResponseAs[T, Source[ByteString, Any]]): Future[Response[T]] = {
- requestToAkka(r).flatMap(setBodyOnAkka(r, r.body, _)).flatMap(Http().singleRequest(_)).flatMap { hr =>
+ requestToAkka(r).map(setBodyOnAkka(r, r.body, _).get).flatMap(Http().singleRequest(_)).flatMap { hr =>
val code = hr.status.intValue()
bodyFromAkka(responseAs, hr).map(Response(code, _))
}
@@ -65,7 +66,7 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future,
private def requestToAkka(r: Request): Future[HttpRequest] = {
val ar = HttpRequest(uri = r.uri.toString, method = methodToAkka(r.method))
- val parsed = r.headers.map(h => HttpHeader.parse(h._1, h._2))
+ val parsed = r.headers.filterNot(isContentType).map(h => HttpHeader.parse(h._1, h._2))
val errors = parsed.collect {
case ParsingResult.Error(e) => e
}
@@ -80,36 +81,45 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future,
}
}
- private def setBodyOnAkka(r: Request, body: RequestBody, ar: HttpRequest): Future[HttpRequest] = body match {
- case NoBody => Future.successful(ar)
- case StringBody(b, encoding) => Future.successful(ar.withEntity(b)) // TODO
- case ByteArrayBody(b) => Future.successful(ar.withEntity(b))
- case ByteBufferBody(b) => Future.successful(ar.withEntity(ByteString(b)))
- case InputStreamBody(b) => Future.successful(ar) //TODO
- case FileBody(b) => Future.successful(ar)//TODO
- case PathBody(b) => Future.successful(ar) //TODO
- case sb@SerializableBody(_, _) => setSerializableBodyOnAkka(r, sb, ar)
- }
+ private def setBodyOnAkka(r: Request, body: RequestBody, ar: HttpRequest): Try[HttpRequest] = {
+ getContentTypeOrOctetStream(r).map { ct =>
+
+ def doSet(body: RequestBody): HttpRequest = body match {
+ case NoBody => ar
+ case StringBody(b, encoding) =>
+ val ctWithEncoding = HttpCharsets.getForKey(encoding).map(hc => ContentType.apply(ct.mediaType, () => hc)).getOrElse(ct)
+ ar.withEntity(ctWithEncoding, b.getBytes(encoding))
+ case ByteArrayBody(b) => ar.withEntity(b)
+ case ByteBufferBody(b) => ar.withEntity(ByteString(b))
+ case InputStreamBody(b) => ar.withEntity(HttpEntity(ct, StreamConverters.fromInputStream(() => b)))
+ case FileBody(b) => ar.withEntity(ct, b.toPath)
+ case PathBody(b) => ar.withEntity(ct, b)
+ case s@SerializableBody(_, _) => doSetSerializable(s)
+ }
- private def setSerializableBodyOnAkka[T](r: Request, body: SerializableBody[T], ar: HttpRequest): Future[HttpRequest] = body match {
- case SerializableBody(SourceBodySerializer, t) =>
- getContentTypeOrOctetStream(r).map(ct => ar.withEntity(HttpEntity(ct, t)))
+ def doSetSerializable[T](body: SerializableBody[T]): HttpRequest = body match {
+ case SerializableBody(SourceBodySerializer, t) => ar.withEntity(HttpEntity(ct, t))
+ case SerializableBody(f, t) => doSet(f(t))
+ }
- case SerializableBody(f, t) => setBodyOnAkka(r, f(t), ar)
+ doSet(body)
+ }
}
- private def getContentTypeOrOctetStream(r: Request): Future[ContentType] = {
+ private def getContentTypeOrOctetStream(r: Request): Try[ContentType] = {
r.headers
- .find(_._1.toLowerCase.contains(`Content-Type`.lowercaseName))
+ .find(isContentType)
.map(_._2)
.map { ct =>
ContentType.parse(ct).fold(
- errors => Future.failed(new RuntimeException(s"Cannot parse content type: $errors")),
- Future.successful)
+ errors => Failure(new RuntimeException(s"Cannot parse content type: $errors")),
+ Success(_))
}
- .getOrElse(Future.successful(`application/octet-stream`))
+ .getOrElse(Success(`application/octet-stream`))
}
+ private def isContentType(header: (String, String)) = header._1.toLowerCase.contains(`Content-Type`.lowercaseName)
+
def close(): Future[Terminated] = {
actorSystem.terminate()
}
diff --git a/build.sbt b/build.sbt
index 454cac6..09f50d0 100644
--- a/build.sbt
+++ b/build.sbt
@@ -76,6 +76,7 @@ lazy val tests: Project = (project in file("tests"))
libraryDependencies ++= Seq(
akkaHttp,
scalaTest,
- "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0" % "test"
+ "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0" % "test",
+ "com.github.pathikrit" %% "better-files" % "3.0.0"
)
) dependsOn(core, akkaHttpHandler) \ No newline at end of file
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index 0f0f076..e4ee030 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -1,6 +1,8 @@
package com.softwaremill.sttp
+import java.io.ByteArrayInputStream
import java.net.URI
+import java.nio.ByteBuffer
import akka.stream.ActorMaterializer
import akka.actor.ActorSystem
@@ -10,6 +12,7 @@ import com.softwaremill.sttp.akkahttp.AkkaHttpSttpHandler
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
+import better.files._
import scala.concurrent.Future
import scala.language.higherKinds
@@ -21,13 +24,13 @@ class BasicTests extends FlatSpec with Matchers with BeforeAndAfterAll with Scal
path("echo") {
get {
parameterMap { params =>
- complete(s"GET /echo ${paramsToString(params)}")
+ complete(List("GET", "/echo", paramsToString(params)).filter(_.nonEmpty).mkString(" "))
}
} ~
post {
parameterMap { params =>
entity(as[String]) { body: String =>
- complete(s"POST /echo ${paramsToString(params)} $body")
+ complete(List("POST", "/echo", paramsToString(params), body).filter(_.nonEmpty).mkString(" "))
}
}
}
@@ -69,5 +72,52 @@ class BasicTests extends FlatSpec with Matchers with BeforeAndAfterAll with Scal
val fc = forceResponse.force(response).body
fc should be ("GET /echo p1=v1 p2=v2")
}
+
+ val postEcho = sttp.post(new URI(endpoint + "/echo"))
+ val testBody = "this is the body"
+ val testBodyBytes = testBody.getBytes("UTF-8")
+ val expectedPostEchoResponse = "POST /echo this is the body"
+
+ name should "post a string" in {
+ val response = postEcho.data(testBody).send(responseAsString)
+ val fc = forceResponse.force(response).body
+ fc should be (expectedPostEchoResponse)
+ }
+
+ name should "post a byte array" in {
+ val response = postEcho.data(testBodyBytes).send(responseAsString)
+ val fc = forceResponse.force(response).body
+ fc should be (expectedPostEchoResponse)
+ }
+
+ name should "post an input stream" in {
+ val response = postEcho.data(new ByteArrayInputStream(testBodyBytes)).send(responseAsString)
+ val fc = forceResponse.force(response).body
+ fc should be (expectedPostEchoResponse)
+ }
+
+ name should "post a byte buffer" in {
+ val response = postEcho.data(ByteBuffer.wrap(testBodyBytes)).send(responseAsString)
+ val fc = forceResponse.force(response).body
+ fc should be (expectedPostEchoResponse)
+ }
+
+ name should "post a file" in {
+ val f = File.newTemporaryFile().write(testBody)
+ try {
+ val response = postEcho.data(f.toJava).send(responseAsString)
+ val fc = forceResponse.force(response).body
+ fc should be(expectedPostEchoResponse)
+ } finally f.delete()
+ }
+
+ name should "post a path" in {
+ val f = File.newTemporaryFile().write(testBody)
+ try {
+ val response = postEcho.data(f.toJava.toPath).send(responseAsString)
+ val fc = forceResponse.force(response).body
+ fc should be(expectedPostEchoResponse)
+ } finally f.delete()
+ }
}
}