diff options
-rw-r--r-- | akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala | 101 | ||||
-rw-r--r-- | core/src/main/scala/com/softwaremill/sttp/Multipart.scala (renamed from core/src/main/scala/com/softwaremill/sttp/MultiPart.scala) | 12 | ||||
-rw-r--r-- | core/src/main/scala/com/softwaremill/sttp/RequestT.scala | 19 | ||||
-rw-r--r-- | core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala | 21 | ||||
-rw-r--r-- | core/src/main/scala/com/softwaremill/sttp/package.scala | 87 | ||||
-rw-r--r-- | tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala | 34 |
6 files changed, 220 insertions, 54 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala index d1dcc22..b296e5c 100644 --- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala +++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala @@ -6,7 +6,7 @@ import akka.actor.ActorSystem import akka.http.scaladsl.Http import akka.http.scaladsl.coding.{Deflate, Gzip, NoCoding} import akka.http.scaladsl.model.HttpHeader.ParsingResult -import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.{Multipart => AkkaMultipart, _} import akka.http.scaladsl.model.headers.{HttpEncodings, `Content-Type`} import akka.http.scaladsl.model.ContentTypes.`application/octet-stream` import akka.stream.ActorMaterializer @@ -110,8 +110,13 @@ class AkkaHttpSttpHandler private (actorSystem: ActorSystem, private def requestToAkka(r: Request[_, S]): Try[HttpRequest] = { val ar = HttpRequest(uri = r.uri.toString, method = methodToAkka(r.method)) + headersToAkka(r.headers).map(ar.withHeaders) + } + + private def headersToAkka( + headers: Seq[(String, String)]): Try[Seq[HttpHeader]] = { val parsed = - r.headers.filterNot(isContentType).map(h => HttpHeader.parse(h._1, h._2)) + headers.filterNot(isContentType).map(h => HttpHeader.parse(h._1, h._2)) val errors = parsed.collect { case ParsingResult.Error(e) => e } @@ -120,41 +125,91 @@ class AkkaHttpSttpHandler private (actorSystem: ActorSystem, case ParsingResult.Ok(h, _) => h } - Success(ar.withHeaders(headers.toList)) + Success(headers.toList) } else { Failure(new RuntimeException(s"Cannot parse headers: $errors")) } } + private def traverseTry[T](l: Seq[Try[T]]): Try[Seq[T]] = { + // https://stackoverflow.com/questions/15495678/flatten-scala-try + val (ss: Seq[Success[T]] @unchecked, fs: Seq[Failure[T]] @unchecked) = + l.partition(_.isSuccess) + + if (fs.isEmpty) Success(ss.map(_.get)) + else Failure[Seq[T]](fs.head.exception) + } + private def setBodyOnAkka(r: Request[_, S], body: RequestBody[S], ar: HttpRequest): Try[HttpRequest] = { - getContentTypeOrOctetStream(r).map { ct => - def doSet(body: RequestBody[S]): HttpRequest = body match { - case NoBody => ar + def ctWithEncoding(ct: ContentType, encoding: String) = + HttpCharsets + .getForKey(encoding) + .map(hc => ContentType.apply(ct.mediaType, () => hc)) + .getOrElse(ct) + + def toBodyPart(mp: Multipart): Try[AkkaMultipart.FormData.BodyPart] = { + def entity(ct: ContentType) = mp.body match { case StringBody(b, encoding, _) => - val ctWithEncoding = HttpCharsets - .getForKey(encoding) - .map(hc => ContentType.apply(ct.mediaType, () => hc)) - .getOrElse(ct) - ar.withEntity(ctWithEncoding, b.getBytes(encoding)) - case ByteArrayBody(b, _) => ar.withEntity(b) - case ByteBufferBody(b, _) => ar.withEntity(ByteString(b)) - case InputStreamBody(b, _) => - ar.withEntity( - HttpEntity(ct, StreamConverters.fromInputStream(() => b))) - case PathBody(b, _) => ar.withEntity(ct, b) - case StreamBody(s) => ar.withEntity(HttpEntity(ct, s)) + HttpEntity(ctWithEncoding(ct, encoding), b.getBytes(encoding)) + case ByteArrayBody(b, _) => HttpEntity(ct, b) + case ByteBufferBody(b, _) => HttpEntity(ct, ByteString(b)) + case isb: InputStreamBody => + HttpEntity + .IndefiniteLength(ct, StreamConverters.fromInputStream(() => isb.b)) + case PathBody(b, _) => HttpEntity.fromPath(ct, b) + } + + for { + ct <- parseContentTypeOrOctetStream(mp.contentType) + headers <- headersToAkka(mp.additionalHeaders.toList) + } yield { + val dispositionParams = + mp.fileName.fold(Map.empty[String, String])(fn => + Map("filename" -> fn)) + + AkkaMultipart.FormData.BodyPart(mp.name, + entity(ct), + dispositionParams, + headers) } + } - doSet(body) + parseContentTypeOrOctetStream(r).flatMap { ct => + body match { + case NoBody => Success(ar) + case StringBody(b, encoding, _) => + Success( + ar.withEntity(ctWithEncoding(ct, encoding), b.getBytes(encoding))) + case ByteArrayBody(b, _) => Success(ar.withEntity(HttpEntity(ct, b))) + case ByteBufferBody(b, _) => + Success(ar.withEntity(HttpEntity(ct, ByteString(b)))) + case InputStreamBody(b, _) => + Success( + ar.withEntity( + HttpEntity(ct, StreamConverters.fromInputStream(() => b)))) + case PathBody(b, _) => Success(ar.withEntity(ct, b)) + case StreamBody(s) => Success(ar.withEntity(HttpEntity(ct, s))) + case MultipartBody(ps) => + traverseTry(ps.map(toBodyPart)) + .map(bodyParts => + ar.withEntity(AkkaMultipart.FormData(bodyParts: _*).toEntity())) + } } } - private def getContentTypeOrOctetStream(r: Request[_, S]): Try[ContentType] = { - r.headers - .find(isContentType) - .map(_._2) + private def parseContentTypeOrOctetStream( + r: Request[_, S]): Try[ContentType] = { + parseContentTypeOrOctetStream( + r.headers + .find(isContentType) + .map(_._2)) + } + + private def parseContentTypeOrOctetStream( + ctHeader: Option[String]): Try[ContentType] = { + ctHeader .map { ct => ContentType .parse(ct) diff --git a/core/src/main/scala/com/softwaremill/sttp/MultiPart.scala b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala index eb7badb..c54c615 100644 --- a/core/src/main/scala/com/softwaremill/sttp/MultiPart.scala +++ b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala @@ -3,17 +3,17 @@ package com.softwaremill.sttp import com.softwaremill.sttp.model.BasicRequestBody /** - * Use the factory methods `multiPart` to conveniently create instances of + * Use the factory methods `multipart` to conveniently create instances of * this class. A part can be then further customised using `fileName`, * `contentType` and `header` methods. */ -case class MultiPart(name: String, - data: BasicRequestBody, +case class Multipart(name: String, + body: BasicRequestBody, fileName: Option[String] = None, contentType: Option[String] = None, additionalHeaders: Map[String, String] = Map()) { - def fileName(v: String): MultiPart = copy(fileName = Some(v)) - def contentType(v: String): MultiPart = copy(contentType = Some(v)) - def header(k: String, v: String): MultiPart = + def fileName(v: String): Multipart = copy(fileName = Some(v)) + def contentType(v: String): Multipart = copy(contentType = Some(v)) + def header(k: String, v: String): Multipart = copy(additionalHeaders = additionalHeaders + (k -> v)) } diff --git a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala index b785e7c..1461ac6 100644 --- a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala +++ b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala @@ -1,7 +1,6 @@ package com.softwaremill.sttp import java.io.{File, InputStream} -import java.net.URLEncoder import java.nio.ByteBuffer import java.nio.file.Path import java.util.Base64 @@ -200,6 +199,12 @@ case class RequestT[U[_], T, +S]( def body[B: BodySerializer](b: B): RequestT[U, T, S] = withBasicBody(implicitly[BodySerializer[B]].apply(b)) + def multipartBody(ps: Seq[Multipart]): RequestT[U, T, S] = + this.copy(body = MultipartBody(ps)) + + def multipartBody(p1: Multipart, ps: Multipart*): RequestT[U, T, S] = + this.copy(body = MultipartBody(p1 :: ps.toList)) + def streamBody[S2 >: S](b: S2): RequestT[U, T, S2] = copy[U, T, S2](body = StreamBody(b)) @@ -248,16 +253,10 @@ case class RequestT[U[_], T, +S]( private def formDataBody(fs: Seq[(String, String)], encoding: String): RequestT[U, T, S] = { - val b = fs - .map { - case (key, value) => - URLEncoder.encode(key, encoding) + "=" + - URLEncoder.encode(value, encoding) - } - .mkString("&") + val b = RequestBody.paramsToStringBody(fs, encoding) setContentTypeIfMissing(ApplicationFormContentType) - .setContentLengthIfMissing(b.getBytes(encoding).length) - .copy(body = StringBody(b, encoding)) + .setContentLengthIfMissing(b.s.getBytes(encoding).length) + .copy(body = b) } } diff --git a/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala b/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala index 7499048..5d43298 100644 --- a/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala +++ b/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala @@ -1,11 +1,14 @@ package com.softwaremill.sttp.model import java.io.InputStream +import java.net.URLEncoder import java.nio.ByteBuffer import java.nio.file.Path import com.softwaremill.sttp._ +import scala.collection.immutable.Seq + sealed trait RequestBody[+S] case object NoBody extends RequestBody[Nothing] @@ -40,3 +43,21 @@ case class PathBody( ) extends BasicRequestBody case class StreamBody[S](s: S) extends RequestBody[S] + +case class MultipartBody(parts: Seq[Multipart]) extends RequestBody[Nothing] + +object RequestBody { + private[sttp] def paramsToStringBody(fs: Seq[(String, String)], + encoding: String): StringBody = { + + val b = fs + .map { + case (key, value) => + URLEncoder.encode(key, encoding) + "=" + + URLEncoder.encode(value, encoding) + } + .mkString("&") + + StringBody(b, encoding) + } +} diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala index 884d2f9..bbd777e 100644 --- a/core/src/main/scala/com/softwaremill/sttp/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/package.scala @@ -96,14 +96,14 @@ package object sttp { overwrite: Boolean = false): ResponseAs[Path, Nothing] = ResponseAsFile(path.toFile, overwrite).map(_.toPath) - // multi part factory methods + // multipart factory methods /** * Content type will be set to `text/plain` with `utf-8` encoding, can be * overridden later using the `contentType` method. */ - def multiPart(name: String, data: String): MultiPart = - MultiPart(name, + def multipart(name: String, data: String): Multipart = + Multipart(name, StringBody(data, Utf8), contentType = Some(contentTypeWithEncoding(TextPlainContentType, Utf8))) @@ -112,8 +112,8 @@ package object sttp { * Content type will be set to `text/plain` with `utf-8` encoding, can be * overridden later using the `contentType` method. */ - def multiPart(name: String, data: String, encoding: String): MultiPart = - MultiPart(name, + def multipart(name: String, data: String, encoding: String): Multipart = + Multipart(name, StringBody(data, encoding), contentType = Some(contentTypeWithEncoding(TextPlainContentType, Utf8))) @@ -122,8 +122,8 @@ package object sttp { * Content type will be set to `application/octet-stream`, can be overridden * later using the `contentType` method. */ - def multiPart(name: String, data: Array[Byte]): MultiPart = - MultiPart(name, + def multipart(name: String, data: Array[Byte]): Multipart = + Multipart(name, ByteArrayBody(data), contentType = Some(ApplicationOctetStreamContentType)) @@ -131,8 +131,8 @@ package object sttp { * Content type will be set to `application/octet-stream`, can be overridden * later using the `contentType` method. */ - def multiPart(name: String, data: ByteBuffer): MultiPart = - MultiPart(name, + def multipart(name: String, data: ByteBuffer): Multipart = + Multipart(name, ByteBufferBody(data), contentType = Some(ApplicationOctetStreamContentType)) @@ -140,8 +140,8 @@ package object sttp { * Content type will be set to `application/octet-stream`, can be overridden * later using the `contentType` method. */ - def multiPart(name: String, data: InputStream): MultiPart = - MultiPart(name, + def multipart(name: String, data: InputStream): Multipart = + Multipart(name, InputStreamBody(data), contentType = Some(ApplicationOctetStreamContentType)) @@ -149,19 +149,76 @@ package object sttp { * Content type will be set to `application/octet-stream`, can be overridden * later using the `contentType` method. */ - def multiPart(name: String, data: File): MultiPart = - multiPart(name, data.toPath) + def multipart(name: String, data: File): Multipart = + multipart(name, data.toPath) /** * Content type will be set to `application/octet-stream`, can be overridden * later using the `contentType` method. */ - def multiPart(name: String, data: Path): MultiPart = - MultiPart(name, + def multipart(name: String, data: Path): Multipart = + Multipart(name, PathBody(data), fileName = Some(data.getFileName.toString), contentType = Some(ApplicationOctetStreamContentType)) + /** + * Encodes the given parameters as form data using `utf-8`. + * + * Content type will be set to `application/x-www-form-urlencoded`, can be + * overridden later using the `contentType` method. + */ + def multipart(name: String, fs: Map[String, String]): Multipart = + Multipart(name, + RequestBody.paramsToStringBody(fs.toList, Utf8), + contentType = Some(ApplicationFormContentType)) + + /** + * Encodes the given parameters as form data. + * + * Content type will be set to `application/x-www-form-urlencoded`, can be + * overridden later using the `contentType` method. + */ + def multipart(name: String, + fs: Map[String, String], + encoding: String): Multipart = + Multipart(name, + RequestBody.paramsToStringBody(fs.toList, encoding), + contentType = Some(ApplicationFormContentType)) + + /** + * Encodes the given parameters as form data using `utf-8`. + * + * Content type will be set to `application/x-www-form-urlencoded`, can be + * overridden later using the `contentType` method. + */ + def multipart(name: String, fs: Seq[(String, String)]): Multipart = + Multipart(name, + RequestBody.paramsToStringBody(fs, Utf8), + contentType = Some(ApplicationFormContentType)) + + /** + * Encodes the given parameters as form data. + * + * Content type will be set to `application/x-www-form-urlencoded`, can be + * overridden later using the `contentType` method. + */ + def multipart(name: String, + fs: Seq[(String, String)], + encoding: String): Multipart = + Multipart(name, + RequestBody.paramsToStringBody(fs, encoding), + contentType = Some(ApplicationFormContentType)) + + /** + * Content type will be set to `application/octet-stream`, can be + * overridden later using the `contentType` method. + */ + def multipart[B: BodySerializer](name: String, b: B): Multipart = + Multipart(name, + implicitly[BodySerializer[B]].apply(b), + contentType = Some(ApplicationOctetStreamContentType)) + // util private[sttp] def contentTypeWithEncoding(ct: String, enc: String) = diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index 9b64271..b2918fa 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -12,6 +12,7 @@ import akka.http.scaladsl.model.headers.CacheDirectives._ import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server.Route import akka.http.scaladsl.server.directives.Credentials +import akka.util.ByteString import com.softwaremill.sttp.akkahttp.AkkaHttpSttpHandler import com.typesafe.scalalogging.StrictLogging import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures} @@ -26,6 +27,7 @@ import com.softwaremill.sttp.okhttp.{ OkHttpFutureClientHandler, OkHttpSyncClientHandler } +import scala.concurrent.ExecutionContext.Implicits.global import scala.language.higherKinds @@ -136,6 +138,20 @@ class BasicTests } ~ path("text") { getFromFile(textFile) } + } ~ pathPrefix("multipart") { + entity(as[akka.http.scaladsl.model.Multipart.FormData]) { fd => + complete { + fd.parts + .mapAsync(1) { p => + val fv = p.entity.dataBytes.runFold(ByteString())(_ ++ _) + fv.map(_.utf8String) + .map(v => + p.name + "=" + v + p.filename.fold("")(fn => s" ($fn)")) + } + .runFold(Vector.empty[String])(_ :+ _) + .map(v => v.mkString(", ")) + } + } } override def port = 51823 @@ -184,6 +200,7 @@ class BasicTests authTests() compressionTests() downloadFileTests() +// multipartTests() def parseResponseTests(): Unit = { name should "parse response as string" in { @@ -515,6 +532,23 @@ class BasicTests path.toFile should haveSameContentAs(textFile) } } + + def multipartTests(): Unit = { + val mp = sttp.post(uri"$endpoint/multipart") + + name should "send a multipart message" in { + val req = mp.multipartBody(multipart("p1", "v1"), multipart("p2", "v2")) + val resp = req.send().force() + resp.body should be("p1=v1, p2=v2") + } + + name should "send a multipart message with filenames" in { + val req = mp.multipartBody(multipart("p1", "v1").fileName("f1"), + multipart("p2", "v2").fileName("f2")) + val resp = req.send().force() + resp.body should be("p1=v1 (f1), p2=v2 (f2)") + } + } } override protected def afterAll(): Unit = { |