From d0f9a7f9ece12f18852660b42d65d76a124de53e Mon Sep 17 00:00:00 2001 From: adamw Date: Fri, 30 Jun 2017 14:52:18 +0200 Subject: Content type --- .../sttp/HttpConnectionSttpHandler.scala | 4 +- .../com/softwaremill/sttp/model/package.scala | 3 +- .../main/scala/com/softwaremill/sttp/package.scala | 120 +++++++++++++++++---- 3 files changed, 103 insertions(+), 24 deletions(-) (limited to 'core/src') diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala index 7ed7cc7..362bdab 100644 --- a/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala @@ -43,8 +43,8 @@ object HttpConnectionSttpHandler extends SttpHandler[Id] { body match { case NoBody => // skip - case StringBody(b) => - val writer = new OutputStreamWriter(c.getOutputStream) + case StringBody(b, encoding) => + val writer = new OutputStreamWriter(c.getOutputStream, encoding) try writer.write(b) finally writer.close() case ByteArrayBody(b) => diff --git a/core/src/main/scala/com/softwaremill/sttp/model/package.scala b/core/src/main/scala/com/softwaremill/sttp/model/package.scala index df9ef1b..1541398 100644 --- a/core/src/main/scala/com/softwaremill/sttp/model/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/model/package.scala @@ -27,7 +27,7 @@ package object model { case class SerializableBody[T](f: BodySerializer[T], t: T) extends RequestBody sealed trait BasicRequestBody extends RequestBody - case class StringBody(s: String) extends BasicRequestBody + case class StringBody(s: String, encoding: String) extends BasicRequestBody case class ByteArrayBody(b: Array[Byte]) extends BasicRequestBody case class ByteBufferBody(b: ByteBuffer) extends BasicRequestBody case class InputStreamBody(b: InputStream) extends BasicRequestBody @@ -38,6 +38,7 @@ package object model { object IgnoreResponse extends ResponseAs[Unit] case class ResponseAsString(encoding: String) extends ResponseAs[String] object ResponseAsByteArray extends ResponseAs[Array[Byte]] + // response as params case class ResponseAsStream[S]() } diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala index c787643..67be14a 100644 --- a/core/src/main/scala/com/softwaremill/sttp/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/package.scala @@ -9,6 +9,7 @@ import com.softwaremill.sttp.model._ import scala.annotation.implicitNotFound import scala.language.higherKinds +import scala.collection.immutable.Seq package object sttp { /* @@ -75,21 +76,48 @@ package object sttp { def header(k: String, v: String): MultiPart = copy(additionalHeaders = additionalHeaders + (k -> v)) } - def multiPart(name: String, data: String): MultiPart = MultiPart(name, StringBody(data)) - def multiPart(name: String, data: Array[Byte]): MultiPart = MultiPart(name, ByteArrayBody(data)) - def multiPart(name: String, data: ByteBuffer): MultiPart = MultiPart(name, ByteBufferBody(data)) - def multiPart(name: String, data: InputStream): MultiPart = MultiPart(name, InputStreamBody(data)) - // mandatory content type? - def multiPart(name: String, data: File): MultiPart = MultiPart(name, FileBody(data), fileName = Some(data.getName)) - def multiPart(name: String, data: Path): MultiPart = MultiPart(name, PathBody(data), fileName = Some(data.getFileName.toString)) + /** + * Content type will be set to `text/plain` with `utf-8` encoding, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: String): MultiPart = + MultiPart(name, StringBody(data, Utf8), contentType = Some(contentTypeWithEncoding(TextPlainContentType, Utf8))) + /** + * Content type will be set to `text/plain` with `utf-8` encoding, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: String, encoding: String): MultiPart = + MultiPart(name, StringBody(data, encoding), contentType = Some(contentTypeWithEncoding(TextPlainContentType, Utf8))) + /** + * Content type will be set to `application/octet-stream`, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: Array[Byte]): MultiPart = + MultiPart(name, ByteArrayBody(data), contentType = Some(ApplicationOctetStreamContentType)) + /** + * Content type will be set to `application/octet-stream`, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: ByteBuffer): MultiPart = + MultiPart(name, ByteBufferBody(data), contentType = Some(ApplicationOctetStreamContentType)) + /** + * Content type will be set to `application/octet-stream`, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: InputStream): MultiPart = + MultiPart(name, InputStreamBody(data), contentType = Some(ApplicationOctetStreamContentType)) + /** + * Content type will be set to `application/octet-stream`, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: File): MultiPart = + MultiPart(name, FileBody(data), fileName = Some(data.getName), contentType = Some(ApplicationOctetStreamContentType)) + /** + * Content type will be set to `application/octet-stream`, can be overridden later using the `contentType` method. + */ + def multiPart(name: String, data: Path): MultiPart = + MultiPart(name, PathBody(data), fileName = Some(data.getFileName.toString), contentType = Some(ApplicationOctetStreamContentType)) case class RequestTemplate[U[_]]( method: U[Method], uri: U[URI], body: RequestBody, - headers: Map[String, String] + headers: Seq[(String, String)] ) { - def get(uri: URI): Request = this.copy[Id](uri = uri, method = Method.GET) def head(uri: URI): Request = this.copy[Id](uri = uri, method = Method.HEAD) def post(uri: URI): Request = this.copy[Id](uri = uri, method = Method.POST) @@ -98,18 +126,58 @@ package object sttp { def options(uri: URI): Request = this.copy[Id](uri = uri, method = Method.OPTIONS) def patch(uri: URI): Request = this.copy[Id](uri = uri, method = Method.PATCH) - def header(k: String, v: String): RequestTemplate[U] = this.copy(headers = headers + (k -> v)) + def contentType(ct: String): RequestTemplate[U] = + header(ContentTypeHeader, ct, replaceExisting = true) + def contentType(ct: String, encoding: String): RequestTemplate[U] = + header(ContentTypeHeader, contentTypeWithEncoding(ct, encoding), replaceExisting = true) + def header(k: String, v: String, replaceExisting: Boolean = false): RequestTemplate[U] = { + val kLower = k.toLowerCase + val current = if (replaceExisting) headers.filterNot(_._1.toLowerCase.contains(kLower)) else headers + this.copy(headers = current :+ (k -> v)) + } - // automatically set the content type? - unless specified - def data(b: String): RequestTemplate[U] = this.copy(body = StringBody(b)) - def data(b: Array[Byte]): RequestTemplate[U] = this.copy(body = ByteArrayBody(b)) - def data(b: ByteBuffer): RequestTemplate[U] = this.copy(body = ByteBufferBody(b)) - def data(b: InputStream): RequestTemplate[U] = this.copy(body = InputStreamBody(b)) - // mandatory content type? - def data(b: File): RequestTemplate[U] = this.copy(body = FileBody(b)) - def data(b: Path): RequestTemplate[U] = this.copy(body = PathBody(b)) - def data[T: BodySerializer](b: T): RequestTemplate[U] = this.copy(body = SerializableBody(implicitly[BodySerializer[T]], b)) - // add serializable / deserializable? + /** + * If content type is not specified, will be set to `text/plain` with `utf-8` encoding. + */ + def data(b: String): RequestTemplate[U] = data(b, Utf8) + /** + * If content type is not specified, will be set to `text/plain` with the given encoding. + */ + def data(b: String, encoding: String): RequestTemplate[U] = + setContentTypeIfMissing(contentTypeWithEncoding(TextPlainContentType, encoding)).copy(body = StringBody(b, encoding)) + /** + * If content type is not specified, will be set to `application/octet-stream`. + */ + def data(b: Array[Byte]): RequestTemplate[U] = + setContentTypeIfMissing(ApplicationOctetStreamContentType).copy(body = ByteArrayBody(b)) + /** + * If content type is not specified, will be set to `application/octet-stream`. + */ + def data(b: ByteBuffer): RequestTemplate[U] = + setContentTypeIfMissing(ApplicationOctetStreamContentType).copy(body = ByteBufferBody(b)) + /** + * If content type is not specified, will be set to `application/octet-stream`. + */ + def data(b: InputStream): RequestTemplate[U] = + setContentTypeIfMissing(ApplicationOctetStreamContentType).copy(body = InputStreamBody(b)) + /** + * If content type is not specified, will be set to `application/octet-stream`. + */ + def data(b: File): RequestTemplate[U] = + setContentTypeIfMissing(ApplicationOctetStreamContentType).copy(body = FileBody(b)) + /** + * If content type is not specified, will be set to `application/octet-stream`. + */ + def data(b: Path): RequestTemplate[U] = + setContentTypeIfMissing(ApplicationOctetStreamContentType).copy(body = PathBody(b)) + /** + * If content type is not specified, will be set to `application/octet-stream`. + */ + def data[T: BodySerializer](b: T): RequestTemplate[U] = + setContentTypeIfMissing(ApplicationOctetStreamContentType).copy(body = SerializableBody(implicitly[BodySerializer[T]], b)) + + private def hasContentType: Boolean = headers.exists(_._1.toLowerCase.contains(ContentTypeHeader)) + private def setContentTypeIfMissing(ct: String): RequestTemplate[U] = if (hasContentType) this else contentType(ct) //def formData(fs: Map[String, Seq[String]]): RequestTemplate[U] = ??? def formData(fs: Map[String, String]): RequestTemplate[U] = ??? @@ -131,7 +199,7 @@ package object sttp { } object RequestTemplate { - val empty: RequestTemplate[Empty] = RequestTemplate[Empty](None, None, NoBody, Map.empty) + val empty: RequestTemplate[Empty] = RequestTemplate[Empty](None, None, NoBody, Vector()) } type PartialRequest = RequestTemplate[Empty] @@ -141,4 +209,14 @@ package object sttp { private type IsRequest[U[_]] = RequestTemplate[U] =:= Request val sttp: RequestTemplate[Empty] = RequestTemplate.empty + + private val ContentTypeHeader = "content-type" + private val Utf8 = "utf-8" + + private val ApplicationOctetStreamContentType = "application/octet-stream" + private val ApplicationFormContentType = "application/x-www-form-urlencoded" + private val TextPlainContentType = "text/plain" + private val MultipartFormDataContentType = "multipart/form-data" + + private def contentTypeWithEncoding(ct: String, enc: String) = s"$ct; charset=$enc" } -- cgit v1.2.3