aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala45
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/package.scala4
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala10
3 files changed, 59 insertions, 0 deletions
diff --git a/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala b/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala
index c8f0144..b29eb97 100644
--- a/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala
+++ b/async-http-client-handler/src/main/scala/com/softwaremill/sttp/asynchttpclient/AsyncHttpClientHandler.scala
@@ -1,5 +1,6 @@
package com.softwaremill.sttp.asynchttpclient
+import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.charset.Charset
@@ -7,6 +8,11 @@ import com.softwaremill.sttp.ResponseAs.EagerResponseHandler
import com.softwaremill.sttp._
import org.asynchttpclient.AsyncHandler.State
import org.asynchttpclient.handler.StreamedAsyncHandler
+import org.asynchttpclient.request.body.multipart.{
+ ByteArrayPart,
+ FilePart,
+ StringPart
+}
import org.asynchttpclient.{
AsyncCompletionHandler,
AsyncHandler,
@@ -14,6 +20,7 @@ import org.asynchttpclient.{
HttpResponseBodyPart,
HttpResponseHeaders,
HttpResponseStatus,
+ Param,
RequestBuilder,
Request => AsyncRequest,
Response => AsyncResponse
@@ -172,7 +179,45 @@ abstract class AsyncHttpClientHandler[R[_], S](asyncHttpClient: AsyncHttpClient,
.map(_._2.toLong)
.getOrElse(-1L)
rb.setBody(streamBodyToPublisher(s), cl)
+
+ case MultipartBody(ps) =>
+ ps.foreach(addMultipartBody(rb, _))
+ }
+ }
+
+ private def addMultipartBody(rb: RequestBuilder, mp: Multipart): Unit = {
+ // async http client only supports setting file names on file parts. To
+ // set a file name on an arbitrary part we have to use a small "work
+ // around", combining the file name with the name (surrounding quotes
+ // are added by ahc).
+ def nameWithFilename = mp.fileName.fold(mp.name) { fn =>
+ s"""${mp.name}"; filename="$fn"""
+ }
+
+ val bodyPart = mp.body match {
+ case StringBody(b, encoding, _) =>
+ new StringPart(nameWithFilename,
+ b,
+ mp.contentType.getOrElse(TextPlainContentType),
+ Charset.forName(encoding))
+ case ByteArrayBody(b, _) =>
+ new ByteArrayPart(nameWithFilename, b)
+ case ByteBufferBody(b, _) =>
+ new ByteArrayPart(nameWithFilename, b.array())
+ case InputStreamBody(b, _) =>
+ // sadly async http client only supports parts that are strings,
+ // byte arrays or files
+ val baos = new ByteArrayOutputStream()
+ transfer(b, baos)
+ new ByteArrayPart(nameWithFilename, baos.toByteArray)
+ case PathBody(b, _) =>
+ new FilePart(mp.name, b.toFile, null, null, mp.fileName.orNull)
}
+
+ bodyPart.setCustomHeaders(
+ mp.additionalHeaders.map(h => new Param(h._1, h._2)).toList.asJava)
+
+ rb.addBodyPart(bodyPart)
}
private def readEagerResponse[T](
diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala
index 69f20f0..60ef6f5 100644
--- a/core/src/main/scala/com/softwaremill/sttp/package.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/package.scala
@@ -146,6 +146,8 @@ package object sttp {
/**
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
+ *
+ * File name will be set to the name of the file.
*/
def multipart(name: String, data: File): Multipart =
multipart(name, data.toPath)
@@ -153,6 +155,8 @@ package object sttp {
/**
* Content type will be set to `application/octet-stream`, can be overridden
* later using the `contentType` method.
+ *
+ * File name will be set to the name of the file.
*/
def multipart(name: String, data: Path): Multipart =
Multipart(name,
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index b2918fa..9b686dc 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -548,6 +548,16 @@ class BasicTests
val resp = req.send().force()
resp.body should be("p1=v1 (f1), p2=v2 (f2)")
}
+
+ name should "send a multipart message with a file" in {
+ val f = File.newTemporaryFile().write(testBody)
+ try {
+ val req =
+ mp.multipartBody(multipart("p1", f.toJava), multipart("p2", "v2"))
+ val resp = req.send().force()
+ resp.body should be(s"p1=$testBody (${f.name}), p2=v2")
+ } finally f.delete()
+ }
}
}