diff options
-rw-r--r-- | akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala | 54 | ||||
-rw-r--r-- | build.sbt | 3 | ||||
-rw-r--r-- | tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala | 54 |
3 files changed, 86 insertions, 25 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala index 78d6637..df40fc9 100644 --- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala +++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala @@ -7,12 +7,13 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.model.headers.`Content-Type` import akka.http.scaladsl.model.ContentTypes.`application/octet-stream` import akka.stream.ActorMaterializer -import akka.stream.scaladsl.Source +import akka.stream.scaladsl.{Source, StreamConverters} import akka.util.ByteString import com.softwaremill.sttp._ import com.softwaremill.sttp.model._ import scala.concurrent.Future +import scala.util.{Failure, Success, Try} class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future, Source[ByteString, Any]] { @@ -23,7 +24,7 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future, import as.dispatcher override def send[T](r: Request, responseAs: ResponseAs[T, Source[ByteString, Any]]): Future[Response[T]] = { - requestToAkka(r).flatMap(setBodyOnAkka(r, r.body, _)).flatMap(Http().singleRequest(_)).flatMap { hr => + requestToAkka(r).map(setBodyOnAkka(r, r.body, _).get).flatMap(Http().singleRequest(_)).flatMap { hr => val code = hr.status.intValue() bodyFromAkka(responseAs, hr).map(Response(code, _)) } @@ -65,7 +66,7 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future, private def requestToAkka(r: Request): Future[HttpRequest] = { val ar = HttpRequest(uri = r.uri.toString, method = methodToAkka(r.method)) - val parsed = r.headers.map(h => HttpHeader.parse(h._1, h._2)) + val parsed = r.headers.filterNot(isContentType).map(h => HttpHeader.parse(h._1, h._2)) val errors = parsed.collect { case ParsingResult.Error(e) => e } @@ -80,36 +81,45 @@ class AkkaHttpSttpHandler(actorSystem: ActorSystem) extends SttpHandler[Future, } } - private def setBodyOnAkka(r: Request, body: RequestBody, ar: HttpRequest): Future[HttpRequest] = body match { - case NoBody => Future.successful(ar) - case StringBody(b, encoding) => Future.successful(ar.withEntity(b)) // TODO - case ByteArrayBody(b) => Future.successful(ar.withEntity(b)) - case ByteBufferBody(b) => Future.successful(ar.withEntity(ByteString(b))) - case InputStreamBody(b) => Future.successful(ar) //TODO - case FileBody(b) => Future.successful(ar)//TODO - case PathBody(b) => Future.successful(ar) //TODO - case sb@SerializableBody(_, _) => setSerializableBodyOnAkka(r, sb, ar) - } + private def setBodyOnAkka(r: Request, body: RequestBody, ar: HttpRequest): Try[HttpRequest] = { + getContentTypeOrOctetStream(r).map { ct => + + def doSet(body: RequestBody): HttpRequest = body match { + case NoBody => ar + case StringBody(b, encoding) => + val ctWithEncoding = HttpCharsets.getForKey(encoding).map(hc => ContentType.apply(ct.mediaType, () => hc)).getOrElse(ct) + ar.withEntity(ctWithEncoding, b.getBytes(encoding)) + case ByteArrayBody(b) => ar.withEntity(b) + case ByteBufferBody(b) => ar.withEntity(ByteString(b)) + case InputStreamBody(b) => ar.withEntity(HttpEntity(ct, StreamConverters.fromInputStream(() => b))) + case FileBody(b) => ar.withEntity(ct, b.toPath) + case PathBody(b) => ar.withEntity(ct, b) + case s@SerializableBody(_, _) => doSetSerializable(s) + } - private def setSerializableBodyOnAkka[T](r: Request, body: SerializableBody[T], ar: HttpRequest): Future[HttpRequest] = body match { - case SerializableBody(SourceBodySerializer, t) => - getContentTypeOrOctetStream(r).map(ct => ar.withEntity(HttpEntity(ct, t))) + def doSetSerializable[T](body: SerializableBody[T]): HttpRequest = body match { + case SerializableBody(SourceBodySerializer, t) => ar.withEntity(HttpEntity(ct, t)) + case SerializableBody(f, t) => doSet(f(t)) + } - case SerializableBody(f, t) => setBodyOnAkka(r, f(t), ar) + doSet(body) + } } - private def getContentTypeOrOctetStream(r: Request): Future[ContentType] = { + private def getContentTypeOrOctetStream(r: Request): Try[ContentType] = { r.headers - .find(_._1.toLowerCase.contains(`Content-Type`.lowercaseName)) + .find(isContentType) .map(_._2) .map { ct => ContentType.parse(ct).fold( - errors => Future.failed(new RuntimeException(s"Cannot parse content type: $errors")), - Future.successful) + errors => Failure(new RuntimeException(s"Cannot parse content type: $errors")), + Success(_)) } - .getOrElse(Future.successful(`application/octet-stream`)) + .getOrElse(Success(`application/octet-stream`)) } + private def isContentType(header: (String, String)) = header._1.toLowerCase.contains(`Content-Type`.lowercaseName) + def close(): Future[Terminated] = { actorSystem.terminate() } @@ -76,6 +76,7 @@ lazy val tests: Project = (project in file("tests")) libraryDependencies ++= Seq( akkaHttp, scalaTest, - "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0" % "test" + "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0" % "test", + "com.github.pathikrit" %% "better-files" % "3.0.0" ) ) dependsOn(core, akkaHttpHandler)
\ No newline at end of file diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index 0f0f076..e4ee030 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -1,6 +1,8 @@ package com.softwaremill.sttp +import java.io.ByteArrayInputStream import java.net.URI +import java.nio.ByteBuffer import akka.stream.ActorMaterializer import akka.actor.ActorSystem @@ -10,6 +12,7 @@ import com.softwaremill.sttp.akkahttp.AkkaHttpSttpHandler import com.typesafe.scalalogging.StrictLogging import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures} import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers} +import better.files._ import scala.concurrent.Future import scala.language.higherKinds @@ -21,13 +24,13 @@ class BasicTests extends FlatSpec with Matchers with BeforeAndAfterAll with Scal path("echo") { get { parameterMap { params => - complete(s"GET /echo ${paramsToString(params)}") + complete(List("GET", "/echo", paramsToString(params)).filter(_.nonEmpty).mkString(" ")) } } ~ post { parameterMap { params => entity(as[String]) { body: String => - complete(s"POST /echo ${paramsToString(params)} $body") + complete(List("POST", "/echo", paramsToString(params), body).filter(_.nonEmpty).mkString(" ")) } } } @@ -69,5 +72,52 @@ class BasicTests extends FlatSpec with Matchers with BeforeAndAfterAll with Scal val fc = forceResponse.force(response).body fc should be ("GET /echo p1=v1 p2=v2") } + + val postEcho = sttp.post(new URI(endpoint + "/echo")) + val testBody = "this is the body" + val testBodyBytes = testBody.getBytes("UTF-8") + val expectedPostEchoResponse = "POST /echo this is the body" + + name should "post a string" in { + val response = postEcho.data(testBody).send(responseAsString) + val fc = forceResponse.force(response).body + fc should be (expectedPostEchoResponse) + } + + name should "post a byte array" in { + val response = postEcho.data(testBodyBytes).send(responseAsString) + val fc = forceResponse.force(response).body + fc should be (expectedPostEchoResponse) + } + + name should "post an input stream" in { + val response = postEcho.data(new ByteArrayInputStream(testBodyBytes)).send(responseAsString) + val fc = forceResponse.force(response).body + fc should be (expectedPostEchoResponse) + } + + name should "post a byte buffer" in { + val response = postEcho.data(ByteBuffer.wrap(testBodyBytes)).send(responseAsString) + val fc = forceResponse.force(response).body + fc should be (expectedPostEchoResponse) + } + + name should "post a file" in { + val f = File.newTemporaryFile().write(testBody) + try { + val response = postEcho.data(f.toJava).send(responseAsString) + val fc = forceResponse.force(response).body + fc should be(expectedPostEchoResponse) + } finally f.delete() + } + + name should "post a path" in { + val f = File.newTemporaryFile().write(testBody) + try { + val response = postEcho.data(f.toJava.toPath).send(responseAsString) + val fc = forceResponse.force(response).body + fc should be(expectedPostEchoResponse) + } finally f.delete() + } } } |