aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala101
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/Multipart.scala (renamed from core/src/main/scala/com/softwaremill/sttp/MultiPart.scala)12
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/RequestT.scala19
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala21
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/package.scala87
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala34
6 files changed, 220 insertions, 54 deletions
diff --git a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
index d1dcc22..b296e5c 100644
--- a/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
+++ b/akka-http-handler/src/main/scala/com/softwaremill/sttp/akkahttp/AkkaHttpSttpHandler.scala
@@ -6,7 +6,7 @@ import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.coding.{Deflate, Gzip, NoCoding}
import akka.http.scaladsl.model.HttpHeader.ParsingResult
-import akka.http.scaladsl.model._
+import akka.http.scaladsl.model.{Multipart => AkkaMultipart, _}
import akka.http.scaladsl.model.headers.{HttpEncodings, `Content-Type`}
import akka.http.scaladsl.model.ContentTypes.`application/octet-stream`
import akka.stream.ActorMaterializer
@@ -110,8 +110,13 @@ class AkkaHttpSttpHandler private (actorSystem: ActorSystem,
private def requestToAkka(r: Request[_, S]): Try[HttpRequest] = {
val ar = HttpRequest(uri = r.uri.toString, method = methodToAkka(r.method))
+ headersToAkka(r.headers).map(ar.withHeaders)
+ }
+
+ private def headersToAkka(
+ headers: Seq[(String, String)]): Try[Seq[HttpHeader]] = {
val parsed =
- r.headers.filterNot(isContentType).map(h => HttpHeader.parse(h._1, h._2))
+ headers.filterNot(isContentType).map(h => HttpHeader.parse(h._1, h._2))
val errors = parsed.collect {
case ParsingResult.Error(e) => e
}
@@ -120,41 +125,91 @@ class AkkaHttpSttpHandler private (actorSystem: ActorSystem,
case ParsingResult.Ok(h, _) => h
}
- Success(ar.withHeaders(headers.toList))
+ Success(headers.toList)
} else {
Failure(new RuntimeException(s"Cannot parse headers: $errors"))
}
}
+ private def traverseTry[T](l: Seq[Try[T]]): Try[Seq[T]] = {
+ // https://stackoverflow.com/questions/15495678/flatten-scala-try
+ val (ss: Seq[Success[T]] @unchecked, fs: Seq[Failure[T]] @unchecked) =
+ l.partition(_.isSuccess)
+
+ if (fs.isEmpty) Success(ss.map(_.get))
+ else Failure[Seq[T]](fs.head.exception)
+ }
+
private def setBodyOnAkka(r: Request[_, S],
body: RequestBody[S],
ar: HttpRequest): Try[HttpRequest] = {
- getContentTypeOrOctetStream(r).map { ct =>
- def doSet(body: RequestBody[S]): HttpRequest = body match {
- case NoBody => ar
+ def ctWithEncoding(ct: ContentType, encoding: String) =
+ HttpCharsets
+ .getForKey(encoding)
+ .map(hc => ContentType.apply(ct.mediaType, () => hc))
+ .getOrElse(ct)
+
+ def toBodyPart(mp: Multipart): Try[AkkaMultipart.FormData.BodyPart] = {
+ def entity(ct: ContentType) = mp.body match {
case StringBody(b, encoding, _) =>
- val ctWithEncoding = HttpCharsets
- .getForKey(encoding)
- .map(hc => ContentType.apply(ct.mediaType, () => hc))
- .getOrElse(ct)
- ar.withEntity(ctWithEncoding, b.getBytes(encoding))
- case ByteArrayBody(b, _) => ar.withEntity(b)
- case ByteBufferBody(b, _) => ar.withEntity(ByteString(b))
- case InputStreamBody(b, _) =>
- ar.withEntity(
- HttpEntity(ct, StreamConverters.fromInputStream(() => b)))
- case PathBody(b, _) => ar.withEntity(ct, b)
- case StreamBody(s) => ar.withEntity(HttpEntity(ct, s))
+ HttpEntity(ctWithEncoding(ct, encoding), b.getBytes(encoding))
+ case ByteArrayBody(b, _) => HttpEntity(ct, b)
+ case ByteBufferBody(b, _) => HttpEntity(ct, ByteString(b))
+ case isb: InputStreamBody =>
+ HttpEntity
+ .IndefiniteLength(ct, StreamConverters.fromInputStream(() => isb.b))
+ case PathBody(b, _) => HttpEntity.fromPath(ct, b)
+ }
+
+ for {
+ ct <- parseContentTypeOrOctetStream(mp.contentType)
+ headers <- headersToAkka(mp.additionalHeaders.toList)
+ } yield {
+ val dispositionParams =
+ mp.fileName.fold(Map.empty[String, String])(fn =>
+ Map("filename" -> fn))
+
+ AkkaMultipart.FormData.BodyPart(mp.name,
+ entity(ct),
+ dispositionParams,
+ headers)
}
+ }
- doSet(body)
+ parseContentTypeOrOctetStream(r).flatMap { ct =>
+ body match {
+ case NoBody => Success(ar)
+ case StringBody(b, encoding, _) =>
+ Success(
+ ar.withEntity(ctWithEncoding(ct, encoding), b.getBytes(encoding)))
+ case ByteArrayBody(b, _) => Success(ar.withEntity(HttpEntity(ct, b)))
+ case ByteBufferBody(b, _) =>
+ Success(ar.withEntity(HttpEntity(ct, ByteString(b))))
+ case InputStreamBody(b, _) =>
+ Success(
+ ar.withEntity(
+ HttpEntity(ct, StreamConverters.fromInputStream(() => b))))
+ case PathBody(b, _) => Success(ar.withEntity(ct, b))
+ case StreamBody(s) => Success(ar.withEntity(HttpEntity(ct, s)))
+ case MultipartBody(ps) =>
+ traverseTry(ps.map(toBodyPart))
+ .map(bodyParts =>
+ ar.withEntity(AkkaMultipart.FormData(bodyParts: _*).toEntity()))
+ }
}
}
- private def getContentTypeOrOctetStream(r: Request[_, S]): Try[ContentType] = {
- r.headers
- .find(isContentType)
- .map(_._2)
+ private def parseContentTypeOrOctetStream(
+ r: Request[_, S]): Try[ContentType] = {
+ parseContentTypeOrOctetStream(
+ r.headers
+ .find(isContentType)
+ .map(_._2))
+ }
+
+ private def parseContentTypeOrOctetStream(
+ ctHeader: Option[String]): Try[ContentType] = {
+ ctHeader
.map { ct =>
ContentType
.parse(ct)
diff --git a/core/src/main/scala/com/softwaremill/sttp/MultiPart.scala b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala
index eb7badb..c54c615 100644
--- a/core/src/main/scala/com/softwaremill/sttp/MultiPart.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala
@@ -3,17 +3,17 @@ package com.softwaremill.sttp
import com.softwaremill.sttp.model.BasicRequestBody
/**
- * Use the factory methods `multiPart` to conveniently create instances of
+ * Use the factory methods `multipart` to conveniently create instances of
* this class. A part can be then further customised using `fileName`,
* `contentType` and `header` methods.
*/
-case class MultiPart(name: String,
- data: BasicRequestBody,
+case class Multipart(name: String,
+ body: BasicRequestBody,
fileName: Option[String] = None,
contentType: Option[String] = None,
additionalHeaders: Map[String, String] = Map()) {
- def fileName(v: String): MultiPart = copy(fileName = Some(v))
- def contentType(v: String): MultiPart = copy(contentType = Some(v))
- def header(k: String, v: String): MultiPart =
+ def fileName(v: String): Multipart = copy(fileName = Some(v))
+ def contentType(v: String): Multipart = copy(contentType = Some(v))
+ def header(k: String, v: String): Multipart =
copy(additionalHeaders = additionalHeaders + (k -> v))
}
diff --git a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala
index b785e7c..1461ac6 100644
--- a/core/src/main/scala/com/softwaremill/sttp/RequestT.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/RequestT.scala
@@ -1,7 +1,6 @@
package com.softwaremill.sttp
import java.io.{File, InputStream}
-import java.net.URLEncoder
import java.nio.ByteBuffer
import java.nio.file.Path
import java.util.Base64
@@ -200,6 +199,12 @@ case class RequestT[U[_], T, +S](
def body[B: BodySerializer](b: B): RequestT[U, T, S] =
withBasicBody(implicitly[BodySerializer[B]].apply(b))
+ def multipartBody(ps: Seq[Multipart]): RequestT[U, T, S] =
+ this.copy(body = MultipartBody(ps))
+
+ def multipartBody(p1: Multipart, ps: Multipart*): RequestT[U, T, S] =
+ this.copy(body = MultipartBody(p1 :: ps.toList))
+
def streamBody[S2 >: S](b: S2): RequestT[U, T, S2] =
copy[U, T, S2](body = StreamBody(b))
@@ -248,16 +253,10 @@ case class RequestT[U[_], T, +S](
private def formDataBody(fs: Seq[(String, String)],
encoding: String): RequestT[U, T, S] = {
- val b = fs
- .map {
- case (key, value) =>
- URLEncoder.encode(key, encoding) + "=" +
- URLEncoder.encode(value, encoding)
- }
- .mkString("&")
+ val b = RequestBody.paramsToStringBody(fs, encoding)
setContentTypeIfMissing(ApplicationFormContentType)
- .setContentLengthIfMissing(b.getBytes(encoding).length)
- .copy(body = StringBody(b, encoding))
+ .setContentLengthIfMissing(b.s.getBytes(encoding).length)
+ .copy(body = b)
}
}
diff --git a/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala b/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala
index 7499048..5d43298 100644
--- a/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/model/RequestBody.scala
@@ -1,11 +1,14 @@
package com.softwaremill.sttp.model
import java.io.InputStream
+import java.net.URLEncoder
import java.nio.ByteBuffer
import java.nio.file.Path
import com.softwaremill.sttp._
+import scala.collection.immutable.Seq
+
sealed trait RequestBody[+S]
case object NoBody extends RequestBody[Nothing]
@@ -40,3 +43,21 @@ case class PathBody(
) extends BasicRequestBody
case class StreamBody[S](s: S) extends RequestBody[S]
+
+case class MultipartBody(parts: Seq[Multipart]) extends RequestBody[Nothing]
+
+object RequestBody {
+ private[sttp] def paramsToStringBody(fs: Seq[(String, String)],
+ encoding: String): StringBody = {
+
+ val b = fs
+ .map {
+ case (key, value) =>
+ URLEncoder.encode(key, encoding) + "=" +
+ URLEncoder.encode(value, encoding)
+ }
+ .mkString("&")
+
+ StringBody(b, encoding)
+ }
+}
diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala
index 884d2f9..bbd777e 100644
--- a/core/src/main/scala/com/softwaremill/sttp/package.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/package.scala
@@ -96,14 +96,14 @@ package object sttp {
overwrite: Boolean = false): ResponseAs[Path, Nothing] =
ResponseAsFile(path.toFile, overwrite).map(_.toPath)
- // multi part factory methods
+ // multipart factory methods
/**
* Content type will be set to `text/plain` with `utf-8` encoding, can be
* overridden later using the `contentType` method.
*/
- def multiPart(name: String, data: String): MultiPart =
- MultiPart(name,
+ def multipart(name: String, data: String): Multipart =
+ Multipart(name,
StringBody(data, Utf8),
contentType =
Some(contentTypeWithEncoding(TextPlainContentType, Utf8)))
@@ -112,8 +112,8 @@ package object sttp {
* Content type will be set to `text/plain` with `utf-8` encoding, can be
* overridden later using the `contentType` method.
*/
- def multiPart(name: String, data: String, encoding: String): MultiPart =
- MultiPart(name,
+ def multipart(name: String, data: String, encoding: String): Multipart =
+ Multipart(name,
StringBody(data, encoding),
contentType =
Some(contentTypeWithEncoding(TextPlainContentType, Utf8)))
@@ -122,8 +122,8 @@ package object sttp {
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
*/
- def multiPart(name: String, data: Array[Byte]): MultiPart =
- MultiPart(name,
+ def multipart(name: String, data: Array[Byte]): Multipart =
+ Multipart(name,
ByteArrayBody(data),
contentType = Some(ApplicationOctetStreamContentType))
@@ -131,8 +131,8 @@ package object sttp {
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
*/
- def multiPart(name: String, data: ByteBuffer): MultiPart =
- MultiPart(name,
+ def multipart(name: String, data: ByteBuffer): Multipart =
+ Multipart(name,
ByteBufferBody(data),
contentType = Some(ApplicationOctetStreamContentType))
@@ -140,8 +140,8 @@ package object sttp {
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
*/
- def multiPart(name: String, data: InputStream): MultiPart =
- MultiPart(name,
+ def multipart(name: String, data: InputStream): Multipart =
+ Multipart(name,
InputStreamBody(data),
contentType = Some(ApplicationOctetStreamContentType))
@@ -149,19 +149,76 @@ package object sttp {
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
*/
- def multiPart(name: String, data: File): MultiPart =
- multiPart(name, data.toPath)
+ def multipart(name: String, data: File): Multipart =
+ multipart(name, data.toPath)
/**
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
*/
- def multiPart(name: String, data: Path): MultiPart =
- MultiPart(name,
+ def multipart(name: String, data: Path): Multipart =
+ Multipart(name,
PathBody(data),
fileName = Some(data.getFileName.toString),
contentType = Some(ApplicationOctetStreamContentType))
+ /**
+ * Encodes the given parameters as form data using `utf-8`.
+ *
+ * Content type will be set to `application/x-www-form-urlencoded`, can be
+ * overridden later using the `contentType` method.
+ */
+ def multipart(name: String, fs: Map[String, String]): Multipart =
+ Multipart(name,
+ RequestBody.paramsToStringBody(fs.toList, Utf8),
+ contentType = Some(ApplicationFormContentType))
+
+ /**
+ * Encodes the given parameters as form data.
+ *
+ * Content type will be set to `application/x-www-form-urlencoded`, can be
+ * overridden later using the `contentType` method.
+ */
+ def multipart(name: String,
+ fs: Map[String, String],
+ encoding: String): Multipart =
+ Multipart(name,
+ RequestBody.paramsToStringBody(fs.toList, encoding),
+ contentType = Some(ApplicationFormContentType))
+
+ /**
+ * Encodes the given parameters as form data using `utf-8`.
+ *
+ * Content type will be set to `application/x-www-form-urlencoded`, can be
+ * overridden later using the `contentType` method.
+ */
+ def multipart(name: String, fs: Seq[(String, String)]): Multipart =
+ Multipart(name,
+ RequestBody.paramsToStringBody(fs, Utf8),
+ contentType = Some(ApplicationFormContentType))
+
+ /**
+ * Encodes the given parameters as form data.
+ *
+ * Content type will be set to `application/x-www-form-urlencoded`, can be
+ * overridden later using the `contentType` method.
+ */
+ def multipart(name: String,
+ fs: Seq[(String, String)],
+ encoding: String): Multipart =
+ Multipart(name,
+ RequestBody.paramsToStringBody(fs, encoding),
+ contentType = Some(ApplicationFormContentType))
+
+ /**
+ * Content type will be set to `application/octet-stream`, can be
+ * overridden later using the `contentType` method.
+ */
+ def multipart[B: BodySerializer](name: String, b: B): Multipart =
+ Multipart(name,
+ implicitly[BodySerializer[B]].apply(b),
+ contentType = Some(ApplicationOctetStreamContentType))
+
// util
private[sttp] def contentTypeWithEncoding(ct: String, enc: String) =
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index 9b64271..b2918fa 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -12,6 +12,7 @@ import akka.http.scaladsl.model.headers.CacheDirectives._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.server.directives.Credentials
+import akka.util.ByteString
import com.softwaremill.sttp.akkahttp.AkkaHttpSttpHandler
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures}
@@ -26,6 +27,7 @@ import com.softwaremill.sttp.okhttp.{
OkHttpFutureClientHandler,
OkHttpSyncClientHandler
}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.language.higherKinds
@@ -136,6 +138,20 @@ class BasicTests
} ~ path("text") {
getFromFile(textFile)
}
+ } ~ pathPrefix("multipart") {
+ entity(as[akka.http.scaladsl.model.Multipart.FormData]) { fd =>
+ complete {
+ fd.parts
+ .mapAsync(1) { p =>
+ val fv = p.entity.dataBytes.runFold(ByteString())(_ ++ _)
+ fv.map(_.utf8String)
+ .map(v =>
+ p.name + "=" + v + p.filename.fold("")(fn => s" ($fn)"))
+ }
+ .runFold(Vector.empty[String])(_ :+ _)
+ .map(v => v.mkString(", "))
+ }
+ }
}
override def port = 51823
@@ -184,6 +200,7 @@ class BasicTests
authTests()
compressionTests()
downloadFileTests()
+// multipartTests()
def parseResponseTests(): Unit = {
name should "parse response as string" in {
@@ -515,6 +532,23 @@ class BasicTests
path.toFile should haveSameContentAs(textFile)
}
}
+
+ def multipartTests(): Unit = {
+ val mp = sttp.post(uri"$endpoint/multipart")
+
+ name should "send a multipart message" in {
+ val req = mp.multipartBody(multipart("p1", "v1"), multipart("p2", "v2"))
+ val resp = req.send().force()
+ resp.body should be("p1=v1, p2=v2")
+ }
+
+ name should "send a multipart message with filenames" in {
+ val req = mp.multipartBody(multipart("p1", "v1").fileName("f1"),
+ multipart("p2", "v2").fileName("f2"))
+ val resp = req.send().force()
+ resp.body should be("p1=v1 (f1), p2=v2 (f2)")
+ }
+ }
}
override protected def afterAll(): Unit = {