aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoradamw <adam@warski.org>2017-08-30 13:33:18 +0200
committeradamw <adam@warski.org>2017-08-30 13:33:18 +0200
commit227e1d6f4433a15dfc3529bac8387540945009cf (patch)
tree7d108dec35bc99884ae1c95b0294d39f9e3c5b81
parent94b9204973983b4e1620d7cdfd188ca9fd1c01ca (diff)
downloadsttp-227e1d6f4433a15dfc3529bac8387540945009cf.tar.gz
sttp-227e1d6f4433a15dfc3529bac8387540945009cf.tar.bz2
sttp-227e1d6f4433a15dfc3529bac8387540945009cf.zip
Implementing multi-part uploads for the http url connection backend
-rw-r--r--README.md1
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala149
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/Multipart.scala8
-rw-r--r--core/src/main/scala/com/softwaremill/sttp/package.scala3
-rw-r--r--okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala6
-rw-r--r--tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala2
6 files changed, 145 insertions, 24 deletions
diff --git a/README.md b/README.md
index 13606de..84fe157 100644
--- a/README.md
+++ b/README.md
@@ -448,7 +448,6 @@ There are two type aliases for the request template that are used:
## TODO
-* multi-part uploads
* scalaz/fs2 streaming
* proxy support
* connection options, SSL
diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala
index 82d1c54..ed95b07 100644
--- a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala
@@ -5,6 +5,7 @@ import java.net.{HttpURLConnection, URL}
import java.nio.channels.Channels
import java.nio.charset.CharacterCodingException
import java.nio.file.Files
+import java.util.concurrent.ThreadLocalRandom
import java.util.zip.{GZIPInputStream, InflaterInputStream}
import scala.annotation.tailrec
@@ -18,7 +19,17 @@ object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] {
c.setRequestMethod(r.method.m)
r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) }
c.setDoInput(true)
- setBody(r.body, c)
+
+ if (r.body != NoBody) {
+ c.setDoOutput(true)
+ // we need to take care to:
+ // (1) only call getOutputStream after the headers are set
+ // (2) call it ony once
+ writeBody(r.body, c).foreach { os =>
+ os.flush()
+ os.close()
+ }
+ }
try {
val is = c.getInputStream
@@ -33,35 +44,139 @@ object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] {
override def responseMonad: MonadError[Id] = IdMonad
- private def setBody(body: RequestBody[Nothing], c: HttpURLConnection): Unit = {
- if (body != NoBody) c.setDoOutput(true)
-
+ private def writeBody(body: RequestBody[Nothing],
+ c: HttpURLConnection): Option[OutputStream] = {
body match {
- case NoBody => // skip
+ case NoBody =>
+ // skip
+ None
+
+ case b: BasicRequestBody =>
+ val os = c.getOutputStream
+ writeBasicBody(b, os)
+ Some(os)
+
+ case StreamBody(s) =>
+ // we have an instance of nothing - everything's possible!
+ None
+ case mp: MultipartBody =>
+ setMultipartBody(mp, c)
+ }
+ }
+
+ private def writeBasicBody(body: BasicRequestBody, os: OutputStream): Unit = {
+ body match {
case StringBody(b, encoding, _) =>
- val writer = new OutputStreamWriter(c.getOutputStream, encoding)
- try writer.write(b)
- finally writer.close()
+ val writer = new OutputStreamWriter(os, encoding)
+ writer.write(b)
+ // don't close - as this will close the underlying OS and cause errors
+ // with multi-part
+ writer.flush()
case ByteArrayBody(b, _) =>
- c.getOutputStream.write(b)
+ os.write(b)
case ByteBufferBody(b, _) =>
- val channel = Channels.newChannel(c.getOutputStream)
- try channel.write(b)
- finally channel.close()
+ val channel = Channels.newChannel(os)
+ channel.write(b)
case InputStreamBody(b, _) =>
- transfer(b, c.getOutputStream)
+ transfer(b, os)
case PathBody(b, _) =>
- Files.copy(b, c.getOutputStream)
+ Files.copy(b, os)
+ }
+ }
- case StreamBody(s) =>
- // we have an instance of nothing - everything's possible!
- s
+ private val BoundaryChars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".toCharArray
+
+ private def setMultipartBody(mp: MultipartBody,
+ c: HttpURLConnection): Option[OutputStream] = {
+ val boundary = {
+ val tlr = ThreadLocalRandom.current()
+ List
+ .fill(32)(BoundaryChars(tlr.nextInt(BoundaryChars.length)))
+ .mkString
+ }
+
+ // inspired by: https://github.com/scalaj/scalaj-http/blob/master/src/main/scala/scalaj/http/Http.scala#L542
+ val partsWithHeaders = mp.parts.map { p =>
+ val contentDisposition =
+ s"$ContentDispositionHeader: ${p.contentDispositionHeaderValue}"
+ val contentTypeHeader =
+ p.contentType.map(ct => s"$ContentTypeHeader: $ct")
+ val otherHeaders = p.additionalHeaders.map(h => s"${h._1}: ${h._2}")
+ val allHeaders = List(contentDisposition) ++ contentTypeHeader.toList ++ otherHeaders
+ (allHeaders.mkString(CrLf), p)
}
+
+ val dashes = "--"
+
+ val dashesLen = dashes.length.toLong
+ val crLfLen = CrLf.length.toLong
+ val boundaryLen = boundary.length.toLong
+ val finalBoundaryLen = dashesLen + boundaryLen + dashesLen + crLfLen
+
+ // https://stackoverflow.com/questions/31406022/how-is-an-http-multipart-content-length-header-value-calculated
+ val contentLength = partsWithHeaders
+ .map {
+ case (headers, p) =>
+ val bodyLen: Option[Long] = p.body match {
+ case StringBody(b, encoding, _) =>
+ Some(b.getBytes(encoding).length.toLong)
+ case ByteArrayBody(b, _) => Some(b.length.toLong)
+ case ByteBufferBody(b, _) => None
+ case InputStreamBody(b, _) => None
+ case PathBody(b, _) => Some(b.toFile.length())
+ }
+
+ val headersLen = headers.getBytes(Iso88591).length
+
+ bodyLen.map(bl =>
+ dashesLen + boundaryLen + crLfLen + headersLen + crLfLen + crLfLen + bl + crLfLen)
+ }
+ .foldLeft(Option(finalBoundaryLen)) {
+ case (Some(acc), Some(l)) => Some(acc + l)
+ case _ => None
+ }
+
+ c.setRequestProperty(ContentTypeHeader,
+ "multipart/form-data; boundary=" + boundary)
+
+ contentLength.foreach { cl =>
+ c.setFixedLengthStreamingMode(cl)
+ c.setRequestProperty(ContentLengthHeader, cl.toString)
+ }
+
+ var total = 0L
+
+ val os = c.getOutputStream
+ def writeMeta(s: String): Unit = {
+ os.write(s.getBytes(Iso88591))
+ total += s.getBytes(Iso88591).length.toLong
+ }
+
+ partsWithHeaders.foreach {
+ case (headers, p) =>
+ writeMeta(dashes)
+ writeMeta(boundary)
+ writeMeta(CrLf)
+ writeMeta(headers)
+ writeMeta(CrLf)
+ writeMeta(CrLf)
+ writeBasicBody(p.body, os)
+ writeMeta(CrLf)
+ }
+
+ // final boundary
+ writeMeta(dashes)
+ writeMeta(boundary)
+ writeMeta(dashes)
+ writeMeta(CrLf)
+
+ Some(os)
}
private def readResponse[T](
diff --git a/core/src/main/scala/com/softwaremill/sttp/Multipart.scala b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala
index 86a4494..4b9464d 100644
--- a/core/src/main/scala/com/softwaremill/sttp/Multipart.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala
@@ -14,4 +14,12 @@ case class Multipart(name: String,
def contentType(v: String): Multipart = copy(contentType = Some(v))
def header(k: String, v: String): Multipart =
copy(additionalHeaders = additionalHeaders + (k -> v))
+
+ private[sttp] def contentDispositionHeaderValue: String = {
+ def encodeHeaderValue(s: String): String =
+ new String(s.getBytes(Utf8), Iso88591)
+
+ s"""form-data; name="${encodeHeaderValue(name)}"""" +
+ fileName.fold("")(fn => s"""; filename="${encodeHeaderValue(fn)}"""")
+ }
}
diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala
index 60ef6f5..a0a6c57 100644
--- a/core/src/main/scala/com/softwaremill/sttp/package.scala
+++ b/core/src/main/scala/com/softwaremill/sttp/package.scala
@@ -36,7 +36,10 @@ package object sttp {
private[sttp] val ProxyAuthorizationHeader = "Proxy-Authorization"
private[sttp] val AcceptEncodingHeader = "Accept-Encoding"
private[sttp] val ContentEncodingHeader = "Content-Encoding"
+ private[sttp] val ContentDispositionHeader = "Content-Disposition"
private[sttp] val Utf8 = "utf-8"
+ private[sttp] val Iso88591 = "iso-8859-1"
+ private[sttp] val CrLf = "\r\n"
private[sttp] val ApplicationOctetStreamContentType =
"application/octet-stream"
diff --git a/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala b/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
index 6bef4fa..7388bb1 100644
--- a/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
+++ b/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
@@ -1,7 +1,6 @@
package com.softwaremill.sttp.okhttp
import java.io.IOException
-import java.net.URLEncoder
import java.nio.charset.Charset
import com.softwaremill.sttp._
@@ -77,10 +76,7 @@ abstract class OkHttpClientHandler[R[_], S](client: OkHttpClient)
private def addMultipart(builder: OkHttpMultipartBody.Builder,
mp: Multipart): Unit = {
- val disposition = s"""form-data; name="${URLEncoder.encode(mp.name, Utf8)}"""" +
- mp.fileName.fold("")(fn =>
- s"""; filename="${URLEncoder.encode(fn, Utf8)}"""")
- val allHeaders = mp.additionalHeaders + ("Content-Disposition" -> disposition)
+ val allHeaders = mp.additionalHeaders + (ContentDispositionHeader -> mp.contentDispositionHeaderValue)
val headers = Headers.of(allHeaders.asJava)
bodyToOkHttp(mp.body).foreach(builder.addPart(headers, _))
diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
index 9b686dc..aeee5ed 100644
--- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
+++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala
@@ -200,7 +200,7 @@ class BasicTests
authTests()
compressionTests()
downloadFileTests()
-// multipartTests()
+ multipartTests()
def parseResponseTests(): Unit = {
name should "parse response as string" in {