diff options
author | adamw <adam@warski.org> | 2017-08-30 13:33:18 +0200 |
---|---|---|
committer | adamw <adam@warski.org> | 2017-08-30 13:33:18 +0200 |
commit | 227e1d6f4433a15dfc3529bac8387540945009cf (patch) | |
tree | 7d108dec35bc99884ae1c95b0294d39f9e3c5b81 /core | |
parent | 94b9204973983b4e1620d7cdfd188ca9fd1c01ca (diff) | |
download | sttp-227e1d6f4433a15dfc3529bac8387540945009cf.tar.gz sttp-227e1d6f4433a15dfc3529bac8387540945009cf.tar.bz2 sttp-227e1d6f4433a15dfc3529bac8387540945009cf.zip |
Implementing multi-part uploads for the http url connection backend
Diffstat (limited to 'core')
3 files changed, 143 insertions, 17 deletions
diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala index 82d1c54..ed95b07 100644 --- a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala @@ -5,6 +5,7 @@ import java.net.{HttpURLConnection, URL} import java.nio.channels.Channels import java.nio.charset.CharacterCodingException import java.nio.file.Files +import java.util.concurrent.ThreadLocalRandom import java.util.zip.{GZIPInputStream, InflaterInputStream} import scala.annotation.tailrec @@ -18,7 +19,17 @@ object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] { c.setRequestMethod(r.method.m) r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) } c.setDoInput(true) - setBody(r.body, c) + + if (r.body != NoBody) { + c.setDoOutput(true) + // we need to take care to: + // (1) only call getOutputStream after the headers are set + // (2) call it ony once + writeBody(r.body, c).foreach { os => + os.flush() + os.close() + } + } try { val is = c.getInputStream @@ -33,35 +44,139 @@ object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] { override def responseMonad: MonadError[Id] = IdMonad - private def setBody(body: RequestBody[Nothing], c: HttpURLConnection): Unit = { - if (body != NoBody) c.setDoOutput(true) - + private def writeBody(body: RequestBody[Nothing], + c: HttpURLConnection): Option[OutputStream] = { body match { - case NoBody => // skip + case NoBody => + // skip + None + + case b: BasicRequestBody => + val os = c.getOutputStream + writeBasicBody(b, os) + Some(os) + + case StreamBody(s) => + // we have an instance of nothing - everything's possible! + None + case mp: MultipartBody => + setMultipartBody(mp, c) + } + } + + private def writeBasicBody(body: BasicRequestBody, os: OutputStream): Unit = { + body match { case StringBody(b, encoding, _) => - val writer = new OutputStreamWriter(c.getOutputStream, encoding) - try writer.write(b) - finally writer.close() + val writer = new OutputStreamWriter(os, encoding) + writer.write(b) + // don't close - as this will close the underlying OS and cause errors + // with multi-part + writer.flush() case ByteArrayBody(b, _) => - c.getOutputStream.write(b) + os.write(b) case ByteBufferBody(b, _) => - val channel = Channels.newChannel(c.getOutputStream) - try channel.write(b) - finally channel.close() + val channel = Channels.newChannel(os) + channel.write(b) case InputStreamBody(b, _) => - transfer(b, c.getOutputStream) + transfer(b, os) case PathBody(b, _) => - Files.copy(b, c.getOutputStream) + Files.copy(b, os) + } + } - case StreamBody(s) => - // we have an instance of nothing - everything's possible! - s + private val BoundaryChars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789".toCharArray + + private def setMultipartBody(mp: MultipartBody, + c: HttpURLConnection): Option[OutputStream] = { + val boundary = { + val tlr = ThreadLocalRandom.current() + List + .fill(32)(BoundaryChars(tlr.nextInt(BoundaryChars.length))) + .mkString + } + + // inspired by: https://github.com/scalaj/scalaj-http/blob/master/src/main/scala/scalaj/http/Http.scala#L542 + val partsWithHeaders = mp.parts.map { p => + val contentDisposition = + s"$ContentDispositionHeader: ${p.contentDispositionHeaderValue}" + val contentTypeHeader = + p.contentType.map(ct => s"$ContentTypeHeader: $ct") + val otherHeaders = p.additionalHeaders.map(h => s"${h._1}: ${h._2}") + val allHeaders = List(contentDisposition) ++ contentTypeHeader.toList ++ otherHeaders + (allHeaders.mkString(CrLf), p) } + + val dashes = "--" + + val dashesLen = dashes.length.toLong + val crLfLen = CrLf.length.toLong + val boundaryLen = boundary.length.toLong + val finalBoundaryLen = dashesLen + boundaryLen + dashesLen + crLfLen + + // https://stackoverflow.com/questions/31406022/how-is-an-http-multipart-content-length-header-value-calculated + val contentLength = partsWithHeaders + .map { + case (headers, p) => + val bodyLen: Option[Long] = p.body match { + case StringBody(b, encoding, _) => + Some(b.getBytes(encoding).length.toLong) + case ByteArrayBody(b, _) => Some(b.length.toLong) + case ByteBufferBody(b, _) => None + case InputStreamBody(b, _) => None + case PathBody(b, _) => Some(b.toFile.length()) + } + + val headersLen = headers.getBytes(Iso88591).length + + bodyLen.map(bl => + dashesLen + boundaryLen + crLfLen + headersLen + crLfLen + crLfLen + bl + crLfLen) + } + .foldLeft(Option(finalBoundaryLen)) { + case (Some(acc), Some(l)) => Some(acc + l) + case _ => None + } + + c.setRequestProperty(ContentTypeHeader, + "multipart/form-data; boundary=" + boundary) + + contentLength.foreach { cl => + c.setFixedLengthStreamingMode(cl) + c.setRequestProperty(ContentLengthHeader, cl.toString) + } + + var total = 0L + + val os = c.getOutputStream + def writeMeta(s: String): Unit = { + os.write(s.getBytes(Iso88591)) + total += s.getBytes(Iso88591).length.toLong + } + + partsWithHeaders.foreach { + case (headers, p) => + writeMeta(dashes) + writeMeta(boundary) + writeMeta(CrLf) + writeMeta(headers) + writeMeta(CrLf) + writeMeta(CrLf) + writeBasicBody(p.body, os) + writeMeta(CrLf) + } + + // final boundary + writeMeta(dashes) + writeMeta(boundary) + writeMeta(dashes) + writeMeta(CrLf) + + Some(os) } private def readResponse[T]( diff --git a/core/src/main/scala/com/softwaremill/sttp/Multipart.scala b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala index 86a4494..4b9464d 100644 --- a/core/src/main/scala/com/softwaremill/sttp/Multipart.scala +++ b/core/src/main/scala/com/softwaremill/sttp/Multipart.scala @@ -14,4 +14,12 @@ case class Multipart(name: String, def contentType(v: String): Multipart = copy(contentType = Some(v)) def header(k: String, v: String): Multipart = copy(additionalHeaders = additionalHeaders + (k -> v)) + + private[sttp] def contentDispositionHeaderValue: String = { + def encodeHeaderValue(s: String): String = + new String(s.getBytes(Utf8), Iso88591) + + s"""form-data; name="${encodeHeaderValue(name)}"""" + + fileName.fold("")(fn => s"""; filename="${encodeHeaderValue(fn)}"""") + } } diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala index 60ef6f5..a0a6c57 100644 --- a/core/src/main/scala/com/softwaremill/sttp/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/package.scala @@ -36,7 +36,10 @@ package object sttp { private[sttp] val ProxyAuthorizationHeader = "Proxy-Authorization" private[sttp] val AcceptEncodingHeader = "Accept-Encoding" private[sttp] val ContentEncodingHeader = "Content-Encoding" + private[sttp] val ContentDispositionHeader = "Content-Disposition" private[sttp] val Utf8 = "utf-8" + private[sttp] val Iso88591 = "iso-8859-1" + private[sttp] val CrLf = "\r\n" private[sttp] val ApplicationOctetStreamContentType = "application/octet-stream" |