aboutsummaryrefslogtreecommitdiff
path: root/okhttp-backend
diff options
context:
space:
mode:
authoradamw <adam@warski.org>2017-09-14 11:03:21 +0100
committeradamw <adam@warski.org>2017-09-14 11:03:21 +0100
commitfbc71ee712635ed64c50ca694735a84ec794eb11 (patch)
treebf1dd7335306b7f320262d45d0d5b6d02f5a0b27 /okhttp-backend
parenta971d409cb1063a2089d936abf3d3ab70bbbabb6 (diff)
downloadsttp-fbc71ee712635ed64c50ca694735a84ec794eb11.tar.gz
sttp-fbc71ee712635ed64c50ca694735a84ec794eb11.tar.bz2
sttp-fbc71ee712635ed64c50ca694735a84ec794eb11.zip
Renaming "handler" to "backend"
Diffstat (limited to 'okhttp-backend')
-rw-r--r--okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala122
-rw-r--r--okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala255
2 files changed, 377 insertions, 0 deletions
diff --git a/okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala b/okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala
new file mode 100644
index 0000000..4b24e65
--- /dev/null
+++ b/okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala
@@ -0,0 +1,122 @@
+package com.softwaremill.sttp.okhttp.monix
+
+import java.nio.ByteBuffer
+import java.util.concurrent.ArrayBlockingQueue
+
+import com.softwaremill.sttp.{SttpBackend, _}
+import com.softwaremill.sttp.okhttp.{OkHttpAsyncBackend, OkHttpBackend}
+import monix.eval.Task
+import monix.execution.Ack.Continue
+import monix.execution.{Ack, Cancelable, Scheduler}
+import monix.reactive.Observable
+import monix.reactive.observers.Subscriber
+import okhttp3.{MediaType, OkHttpClient, RequestBody => OkHttpRequestBody}
+import okio.BufferedSink
+
+import scala.concurrent.Future
+import scala.concurrent.duration.FiniteDuration
+import scala.util.{Failure, Success, Try}
+
+class OkHttpMonixBackend private (client: OkHttpClient, closeClient: Boolean)(
+ implicit s: Scheduler)
+ extends OkHttpAsyncBackend[Task, Observable[ByteBuffer]](client,
+ TaskMonad,
+ closeClient) {
+
+ override def streamToRequestBody(
+ stream: Observable[ByteBuffer]): Option[OkHttpRequestBody] =
+ Some(new OkHttpRequestBody() {
+ override def writeTo(sink: BufferedSink): Unit =
+ toIterable(stream) map (_.array()) foreach sink.write
+ override def contentType(): MediaType = null
+ })
+
+ override def responseBodyToStream(
+ res: okhttp3.Response): Try[Observable[ByteBuffer]] =
+ Success(
+ Observable
+ .fromInputStream(res.body().byteStream())
+ .map(ByteBuffer.wrap)
+ .doAfterTerminate(_ => res.close()))
+
+ private def toIterable[T](observable: Observable[T])(
+ implicit s: Scheduler): Iterable[T] =
+ new Iterable[T] {
+ override def iterator: Iterator[T] = new Iterator[T] {
+ case object Completed extends Exception
+
+ val blockingQueue = new ArrayBlockingQueue[Either[Throwable, T]](1)
+
+ observable.executeWithFork.subscribe(new Subscriber[T] {
+ override implicit def scheduler: Scheduler = s
+
+ override def onError(ex: Throwable): Unit = {
+ blockingQueue.put(Left(ex))
+ }
+
+ override def onComplete(): Unit = {
+ blockingQueue.put(Left(Completed))
+ }
+
+ override def onNext(elem: T): Future[Ack] = {
+ blockingQueue.put(Right(elem))
+ Continue
+ }
+ })
+
+ var value: T = _
+
+ override def hasNext: Boolean =
+ blockingQueue.take() match {
+ case Left(Completed) => false
+ case Right(elem) =>
+ value = elem
+ true
+ case Left(ex) => throw ex
+ }
+
+ override def next(): T = value
+ }
+ }
+}
+
+object OkHttpMonixBackend {
+ private def apply(client: OkHttpClient, closeClient: Boolean)(
+ implicit s: Scheduler): SttpBackend[Task, Observable[ByteBuffer]] =
+ new FollowRedirectsBackend(new OkHttpMonixBackend(client, closeClient)(s))
+
+ def apply(connectionTimeout: FiniteDuration =
+ SttpBackend.DefaultConnectionTimeout)(
+ implicit s: Scheduler = Scheduler.Implicits.global)
+ : SttpBackend[Task, Observable[ByteBuffer]] =
+ OkHttpMonixBackend(OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis,
+ connectionTimeout.toMillis),
+ closeClient = true)(s)
+
+ def usingClient(client: OkHttpClient)(implicit s: Scheduler =
+ Scheduler.Implicits.global)
+ : SttpBackend[Task, Observable[ByteBuffer]] =
+ OkHttpMonixBackend(client, closeClient = false)(s)
+}
+
+private[monix] object TaskMonad extends MonadAsyncError[Task] {
+ override def unit[T](t: T): Task[T] = Task.now(t)
+
+ override def map[T, T2](fa: Task[T])(f: (T) => T2): Task[T2] = fa.map(f)
+
+ override def flatMap[T, T2](fa: Task[T])(f: (T) => Task[T2]): Task[T2] =
+ fa.flatMap(f)
+
+ override def async[T](
+ register: ((Either[Throwable, T]) => Unit) => Unit): Task[T] =
+ Task.async { (_, cb) =>
+ register {
+ case Left(t) => cb(Failure(t))
+ case Right(t) => cb(Success(t))
+ }
+
+ Cancelable.empty
+ }
+
+ override def error[T](t: Throwable): Task[T] = Task.raiseError(t)
+}
diff --git a/okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala b/okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
new file mode 100644
index 0000000..ead1bde
--- /dev/null
+++ b/okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala
@@ -0,0 +1,255 @@
+package com.softwaremill.sttp.okhttp
+
+import java.io.IOException
+import java.nio.charset.Charset
+import java.util.concurrent.TimeUnit
+
+import com.softwaremill.sttp._
+import ResponseAs.EagerResponseHandler
+import okhttp3.internal.http.HttpMethod
+import okhttp3.{
+ Call,
+ Callback,
+ Headers,
+ MediaType,
+ OkHttpClient,
+ MultipartBody => OkHttpMultipartBody,
+ Request => OkHttpRequest,
+ RequestBody => OkHttpRequestBody,
+ Response => OkHttpResponse
+}
+import okio.{BufferedSink, Okio}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.{ExecutionContext, Future}
+import scala.language.higherKinds
+import scala.util.{Failure, Try}
+
+abstract class OkHttpBackend[R[_], S](client: OkHttpClient,
+ closeClient: Boolean)
+ extends SttpBackend[R, S] {
+ private[okhttp] def convertRequest[T](request: Request[T, S]): OkHttpRequest = {
+ val builder = new OkHttpRequest.Builder()
+ .url(request.uri.toString)
+
+ val body = bodyToOkHttp(request.body)
+ builder.method(request.method.m, body.getOrElse {
+ if (HttpMethod.requiresRequestBody(request.method.m))
+ OkHttpRequestBody.create(null, "")
+ else null
+ })
+
+ //OkHttp support automatic gzip compression
+ request.headers
+ .filter(_._1.equalsIgnoreCase(AcceptEncodingHeader) == false)
+ .foreach {
+ case (name, value) => builder.addHeader(name, value)
+ }
+
+ builder.build()
+ }
+
+ private def bodyToOkHttp[T](body: RequestBody[S]): Option[OkHttpRequestBody] = {
+ body match {
+ case NoBody => None
+ case StringBody(b, _, _) =>
+ Some(OkHttpRequestBody.create(null, b))
+ case ByteArrayBody(b, _) =>
+ Some(OkHttpRequestBody.create(null, b))
+ case ByteBufferBody(b, _) =>
+ Some(OkHttpRequestBody.create(null, b.array()))
+ case InputStreamBody(b, _) =>
+ Some(new OkHttpRequestBody() {
+ override def writeTo(sink: BufferedSink): Unit =
+ sink.writeAll(Okio.source(b))
+ override def contentType(): MediaType = null
+ })
+ case PathBody(b, _) =>
+ Some(OkHttpRequestBody.create(null, b.toFile))
+ case StreamBody(s) =>
+ streamToRequestBody(s)
+ case MultipartBody(ps) =>
+ val b = new OkHttpMultipartBody.Builder()
+ .setType(OkHttpMultipartBody.FORM)
+ ps.foreach(addMultipart(b, _))
+ Some(b.build())
+ }
+ }
+
+ private def addMultipart(builder: OkHttpMultipartBody.Builder,
+ mp: Multipart): Unit = {
+ val allHeaders = mp.additionalHeaders + (ContentDispositionHeader -> mp.contentDispositionHeaderValue)
+ val headers = Headers.of(allHeaders.asJava)
+
+ bodyToOkHttp(mp.body).foreach(builder.addPart(headers, _))
+ }
+
+ private[okhttp] def readResponse[T](
+ res: OkHttpResponse,
+ responseAs: ResponseAs[T, S]): R[Response[T]] = {
+
+ val code = res.code()
+
+ val body = if (codeIsSuccess(code)) {
+ responseMonad.map(responseHandler(res).handle(responseAs, responseMonad))(
+ Right(_))
+ } else {
+ responseMonad.map(responseHandler(res).handle(asString, responseMonad))(
+ Left(_))
+ }
+
+ val headers = res
+ .headers()
+ .names()
+ .asScala
+ .flatMap(name => res.headers().values(name).asScala.map((name, _)))
+
+ responseMonad.map(body)(Response(_, res.code(), headers.toList, Nil))
+ }
+
+ private def responseHandler(res: OkHttpResponse) =
+ new EagerResponseHandler[S] {
+ override def handleBasic[T](bra: BasicResponseAs[T, S]): Try[T] =
+ bra match {
+ case IgnoreResponse =>
+ Try(res.close())
+ case ResponseAsString(encoding) =>
+ val body = Try(
+ res.body().source().readString(Charset.forName(encoding)))
+ res.close()
+ body
+ case ResponseAsByteArray =>
+ val body = Try(res.body().bytes())
+ res.close()
+ body
+ case ras @ ResponseAsStream() =>
+ responseBodyToStream(res).map(ras.responseIsStream)
+ case ResponseAsFile(file, overwrite) =>
+ val body = Try(
+ ResponseAs.saveFile(file, res.body().byteStream(), overwrite))
+ res.close()
+ body
+ }
+ }
+
+ def streamToRequestBody(stream: S): Option[OkHttpRequestBody] = None
+
+ def responseBodyToStream(res: OkHttpResponse): Try[S] =
+ Failure(new IllegalStateException("Streaming isn't supported"))
+
+ override def close(): Unit = if (closeClient) {
+ client.dispatcher().executorService().shutdown()
+ }
+}
+
+object OkHttpBackend {
+
+ private[okhttp] def defaultClient(readTimeout: Long,
+ connectionTimeout: Long): OkHttpClient =
+ new OkHttpClient.Builder()
+ .followRedirects(false)
+ .followSslRedirects(false)
+ .connectTimeout(connectionTimeout, TimeUnit.MILLISECONDS)
+ .readTimeout(readTimeout, TimeUnit.MILLISECONDS)
+ .build()
+
+ private[okhttp] def updateClientIfCustomReadTimeout[T, S](
+ r: Request[T, S],
+ client: OkHttpClient): OkHttpClient = {
+ val readTimeout = r.options.readTimeout
+ if (readTimeout == DefaultReadTimeout) client
+ else
+ client
+ .newBuilder()
+ .readTimeout(if (readTimeout.isFinite()) readTimeout.toMillis else 0,
+ TimeUnit.MILLISECONDS)
+ .build()
+
+ }
+}
+
+class OkHttpSyncBackend private (client: OkHttpClient, closeClient: Boolean)
+ extends OkHttpBackend[Id, Nothing](client, closeClient) {
+ override def send[T](r: Request[T, Nothing]): Response[T] = {
+ val request = convertRequest(r)
+ val response = OkHttpBackend
+ .updateClientIfCustomReadTimeout(r, client)
+ .newCall(request)
+ .execute()
+ readResponse(response, r.response)
+ }
+
+ override def responseMonad: MonadError[Id] = IdMonad
+}
+
+object OkHttpSyncBackend {
+ private def apply(client: OkHttpClient,
+ closeClient: Boolean): SttpBackend[Id, Nothing] =
+ new FollowRedirectsBackend[Id, Nothing](
+ new OkHttpSyncBackend(client, closeClient))
+
+ def apply(
+ connectionTimeout: FiniteDuration = SttpBackend.DefaultConnectionTimeout)
+ : SttpBackend[Id, Nothing] =
+ OkHttpSyncBackend(OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis,
+ connectionTimeout.toMillis),
+ closeClient = true)
+
+ def usingClient(client: OkHttpClient): SttpBackend[Id, Nothing] =
+ OkHttpSyncBackend(client, closeClient = false)
+}
+
+abstract class OkHttpAsyncBackend[R[_], S](client: OkHttpClient,
+ rm: MonadAsyncError[R],
+ closeClient: Boolean)
+ extends OkHttpBackend[R, S](client, closeClient) {
+ override def send[T](r: Request[T, S]): R[Response[T]] = {
+ val request = convertRequest(r)
+
+ rm.flatten(rm.async[R[Response[T]]] { cb =>
+ def success(r: R[Response[T]]) = cb(Right(r))
+ def error(t: Throwable) = cb(Left(t))
+
+ OkHttpBackend
+ .updateClientIfCustomReadTimeout(r, client)
+ .newCall(request)
+ .enqueue(new Callback {
+ override def onFailure(call: Call, e: IOException): Unit =
+ error(e)
+
+ override def onResponse(call: Call, response: OkHttpResponse): Unit =
+ try success(readResponse(response, r.response))
+ catch { case e: Exception => error(e) }
+ })
+ })
+ }
+
+ override def responseMonad: MonadError[R] = rm
+}
+
+class OkHttpFutureBackend private (client: OkHttpClient, closeClient: Boolean)(
+ implicit ec: ExecutionContext)
+ extends OkHttpAsyncBackend[Future, Nothing](client,
+ new FutureMonad,
+ closeClient) {}
+
+object OkHttpFutureBackend {
+ private def apply(client: OkHttpClient, closeClient: Boolean)(
+ implicit ec: ExecutionContext): SttpBackend[Future, Nothing] =
+ new FollowRedirectsBackend[Future, Nothing](
+ new OkHttpFutureBackend(client, closeClient))
+
+ def apply(connectionTimeout: FiniteDuration =
+ SttpBackend.DefaultConnectionTimeout)(
+ implicit ec: ExecutionContext = ExecutionContext.Implicits.global)
+ : SttpBackend[Future, Nothing] =
+ OkHttpFutureBackend(OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis,
+ connectionTimeout.toMillis),
+ closeClient = true)
+
+ def usingClient(client: OkHttpClient)(implicit ec: ExecutionContext =
+ ExecutionContext.Implicits.global)
+ : SttpBackend[Future, Nothing] =
+ OkHttpFutureBackend(client, closeClient = false)
+}