diff options
Diffstat (limited to 'okhttp-backend')
2 files changed, 377 insertions, 0 deletions
diff --git a/okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala b/okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala new file mode 100644 index 0000000..4b24e65 --- /dev/null +++ b/okhttp-backend/monix/src/main/scala/com/softwaremill/sttp/okhttp/monix/OkHttpMonixBackend.scala @@ -0,0 +1,122 @@ +package com.softwaremill.sttp.okhttp.monix + +import java.nio.ByteBuffer +import java.util.concurrent.ArrayBlockingQueue + +import com.softwaremill.sttp.{SttpBackend, _} +import com.softwaremill.sttp.okhttp.{OkHttpAsyncBackend, OkHttpBackend} +import monix.eval.Task +import monix.execution.Ack.Continue +import monix.execution.{Ack, Cancelable, Scheduler} +import monix.reactive.Observable +import monix.reactive.observers.Subscriber +import okhttp3.{MediaType, OkHttpClient, RequestBody => OkHttpRequestBody} +import okio.BufferedSink + +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration +import scala.util.{Failure, Success, Try} + +class OkHttpMonixBackend private (client: OkHttpClient, closeClient: Boolean)( + implicit s: Scheduler) + extends OkHttpAsyncBackend[Task, Observable[ByteBuffer]](client, + TaskMonad, + closeClient) { + + override def streamToRequestBody( + stream: Observable[ByteBuffer]): Option[OkHttpRequestBody] = + Some(new OkHttpRequestBody() { + override def writeTo(sink: BufferedSink): Unit = + toIterable(stream) map (_.array()) foreach sink.write + override def contentType(): MediaType = null + }) + + override def responseBodyToStream( + res: okhttp3.Response): Try[Observable[ByteBuffer]] = + Success( + Observable + .fromInputStream(res.body().byteStream()) + .map(ByteBuffer.wrap) + .doAfterTerminate(_ => res.close())) + + private def toIterable[T](observable: Observable[T])( + implicit s: Scheduler): Iterable[T] = + new Iterable[T] { + override def iterator: Iterator[T] = new Iterator[T] { + case object Completed extends Exception + + val blockingQueue = new ArrayBlockingQueue[Either[Throwable, T]](1) + + observable.executeWithFork.subscribe(new Subscriber[T] { + override implicit def scheduler: Scheduler = s + + override def onError(ex: Throwable): Unit = { + blockingQueue.put(Left(ex)) + } + + override def onComplete(): Unit = { + blockingQueue.put(Left(Completed)) + } + + override def onNext(elem: T): Future[Ack] = { + blockingQueue.put(Right(elem)) + Continue + } + }) + + var value: T = _ + + override def hasNext: Boolean = + blockingQueue.take() match { + case Left(Completed) => false + case Right(elem) => + value = elem + true + case Left(ex) => throw ex + } + + override def next(): T = value + } + } +} + +object OkHttpMonixBackend { + private def apply(client: OkHttpClient, closeClient: Boolean)( + implicit s: Scheduler): SttpBackend[Task, Observable[ByteBuffer]] = + new FollowRedirectsBackend(new OkHttpMonixBackend(client, closeClient)(s)) + + def apply(connectionTimeout: FiniteDuration = + SttpBackend.DefaultConnectionTimeout)( + implicit s: Scheduler = Scheduler.Implicits.global) + : SttpBackend[Task, Observable[ByteBuffer]] = + OkHttpMonixBackend(OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis, + connectionTimeout.toMillis), + closeClient = true)(s) + + def usingClient(client: OkHttpClient)(implicit s: Scheduler = + Scheduler.Implicits.global) + : SttpBackend[Task, Observable[ByteBuffer]] = + OkHttpMonixBackend(client, closeClient = false)(s) +} + +private[monix] object TaskMonad extends MonadAsyncError[Task] { + override def unit[T](t: T): Task[T] = Task.now(t) + + override def map[T, T2](fa: Task[T])(f: (T) => T2): Task[T2] = fa.map(f) + + override def flatMap[T, T2](fa: Task[T])(f: (T) => Task[T2]): Task[T2] = + fa.flatMap(f) + + override def async[T]( + register: ((Either[Throwable, T]) => Unit) => Unit): Task[T] = + Task.async { (_, cb) => + register { + case Left(t) => cb(Failure(t)) + case Right(t) => cb(Success(t)) + } + + Cancelable.empty + } + + override def error[T](t: Throwable): Task[T] = Task.raiseError(t) +} diff --git a/okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala b/okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala new file mode 100644 index 0000000..ead1bde --- /dev/null +++ b/okhttp-backend/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala @@ -0,0 +1,255 @@ +package com.softwaremill.sttp.okhttp + +import java.io.IOException +import java.nio.charset.Charset +import java.util.concurrent.TimeUnit + +import com.softwaremill.sttp._ +import ResponseAs.EagerResponseHandler +import okhttp3.internal.http.HttpMethod +import okhttp3.{ + Call, + Callback, + Headers, + MediaType, + OkHttpClient, + MultipartBody => OkHttpMultipartBody, + Request => OkHttpRequest, + RequestBody => OkHttpRequestBody, + Response => OkHttpResponse +} +import okio.{BufferedSink, Okio} + +import scala.collection.JavaConverters._ +import scala.concurrent.duration.FiniteDuration +import scala.concurrent.{ExecutionContext, Future} +import scala.language.higherKinds +import scala.util.{Failure, Try} + +abstract class OkHttpBackend[R[_], S](client: OkHttpClient, + closeClient: Boolean) + extends SttpBackend[R, S] { + private[okhttp] def convertRequest[T](request: Request[T, S]): OkHttpRequest = { + val builder = new OkHttpRequest.Builder() + .url(request.uri.toString) + + val body = bodyToOkHttp(request.body) + builder.method(request.method.m, body.getOrElse { + if (HttpMethod.requiresRequestBody(request.method.m)) + OkHttpRequestBody.create(null, "") + else null + }) + + //OkHttp support automatic gzip compression + request.headers + .filter(_._1.equalsIgnoreCase(AcceptEncodingHeader) == false) + .foreach { + case (name, value) => builder.addHeader(name, value) + } + + builder.build() + } + + private def bodyToOkHttp[T](body: RequestBody[S]): Option[OkHttpRequestBody] = { + body match { + case NoBody => None + case StringBody(b, _, _) => + Some(OkHttpRequestBody.create(null, b)) + case ByteArrayBody(b, _) => + Some(OkHttpRequestBody.create(null, b)) + case ByteBufferBody(b, _) => + Some(OkHttpRequestBody.create(null, b.array())) + case InputStreamBody(b, _) => + Some(new OkHttpRequestBody() { + override def writeTo(sink: BufferedSink): Unit = + sink.writeAll(Okio.source(b)) + override def contentType(): MediaType = null + }) + case PathBody(b, _) => + Some(OkHttpRequestBody.create(null, b.toFile)) + case StreamBody(s) => + streamToRequestBody(s) + case MultipartBody(ps) => + val b = new OkHttpMultipartBody.Builder() + .setType(OkHttpMultipartBody.FORM) + ps.foreach(addMultipart(b, _)) + Some(b.build()) + } + } + + private def addMultipart(builder: OkHttpMultipartBody.Builder, + mp: Multipart): Unit = { + val allHeaders = mp.additionalHeaders + (ContentDispositionHeader -> mp.contentDispositionHeaderValue) + val headers = Headers.of(allHeaders.asJava) + + bodyToOkHttp(mp.body).foreach(builder.addPart(headers, _)) + } + + private[okhttp] def readResponse[T]( + res: OkHttpResponse, + responseAs: ResponseAs[T, S]): R[Response[T]] = { + + val code = res.code() + + val body = if (codeIsSuccess(code)) { + responseMonad.map(responseHandler(res).handle(responseAs, responseMonad))( + Right(_)) + } else { + responseMonad.map(responseHandler(res).handle(asString, responseMonad))( + Left(_)) + } + + val headers = res + .headers() + .names() + .asScala + .flatMap(name => res.headers().values(name).asScala.map((name, _))) + + responseMonad.map(body)(Response(_, res.code(), headers.toList, Nil)) + } + + private def responseHandler(res: OkHttpResponse) = + new EagerResponseHandler[S] { + override def handleBasic[T](bra: BasicResponseAs[T, S]): Try[T] = + bra match { + case IgnoreResponse => + Try(res.close()) + case ResponseAsString(encoding) => + val body = Try( + res.body().source().readString(Charset.forName(encoding))) + res.close() + body + case ResponseAsByteArray => + val body = Try(res.body().bytes()) + res.close() + body + case ras @ ResponseAsStream() => + responseBodyToStream(res).map(ras.responseIsStream) + case ResponseAsFile(file, overwrite) => + val body = Try( + ResponseAs.saveFile(file, res.body().byteStream(), overwrite)) + res.close() + body + } + } + + def streamToRequestBody(stream: S): Option[OkHttpRequestBody] = None + + def responseBodyToStream(res: OkHttpResponse): Try[S] = + Failure(new IllegalStateException("Streaming isn't supported")) + + override def close(): Unit = if (closeClient) { + client.dispatcher().executorService().shutdown() + } +} + +object OkHttpBackend { + + private[okhttp] def defaultClient(readTimeout: Long, + connectionTimeout: Long): OkHttpClient = + new OkHttpClient.Builder() + .followRedirects(false) + .followSslRedirects(false) + .connectTimeout(connectionTimeout, TimeUnit.MILLISECONDS) + .readTimeout(readTimeout, TimeUnit.MILLISECONDS) + .build() + + private[okhttp] def updateClientIfCustomReadTimeout[T, S]( + r: Request[T, S], + client: OkHttpClient): OkHttpClient = { + val readTimeout = r.options.readTimeout + if (readTimeout == DefaultReadTimeout) client + else + client + .newBuilder() + .readTimeout(if (readTimeout.isFinite()) readTimeout.toMillis else 0, + TimeUnit.MILLISECONDS) + .build() + + } +} + +class OkHttpSyncBackend private (client: OkHttpClient, closeClient: Boolean) + extends OkHttpBackend[Id, Nothing](client, closeClient) { + override def send[T](r: Request[T, Nothing]): Response[T] = { + val request = convertRequest(r) + val response = OkHttpBackend + .updateClientIfCustomReadTimeout(r, client) + .newCall(request) + .execute() + readResponse(response, r.response) + } + + override def responseMonad: MonadError[Id] = IdMonad +} + +object OkHttpSyncBackend { + private def apply(client: OkHttpClient, + closeClient: Boolean): SttpBackend[Id, Nothing] = + new FollowRedirectsBackend[Id, Nothing]( + new OkHttpSyncBackend(client, closeClient)) + + def apply( + connectionTimeout: FiniteDuration = SttpBackend.DefaultConnectionTimeout) + : SttpBackend[Id, Nothing] = + OkHttpSyncBackend(OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis, + connectionTimeout.toMillis), + closeClient = true) + + def usingClient(client: OkHttpClient): SttpBackend[Id, Nothing] = + OkHttpSyncBackend(client, closeClient = false) +} + +abstract class OkHttpAsyncBackend[R[_], S](client: OkHttpClient, + rm: MonadAsyncError[R], + closeClient: Boolean) + extends OkHttpBackend[R, S](client, closeClient) { + override def send[T](r: Request[T, S]): R[Response[T]] = { + val request = convertRequest(r) + + rm.flatten(rm.async[R[Response[T]]] { cb => + def success(r: R[Response[T]]) = cb(Right(r)) + def error(t: Throwable) = cb(Left(t)) + + OkHttpBackend + .updateClientIfCustomReadTimeout(r, client) + .newCall(request) + .enqueue(new Callback { + override def onFailure(call: Call, e: IOException): Unit = + error(e) + + override def onResponse(call: Call, response: OkHttpResponse): Unit = + try success(readResponse(response, r.response)) + catch { case e: Exception => error(e) } + }) + }) + } + + override def responseMonad: MonadError[R] = rm +} + +class OkHttpFutureBackend private (client: OkHttpClient, closeClient: Boolean)( + implicit ec: ExecutionContext) + extends OkHttpAsyncBackend[Future, Nothing](client, + new FutureMonad, + closeClient) {} + +object OkHttpFutureBackend { + private def apply(client: OkHttpClient, closeClient: Boolean)( + implicit ec: ExecutionContext): SttpBackend[Future, Nothing] = + new FollowRedirectsBackend[Future, Nothing]( + new OkHttpFutureBackend(client, closeClient)) + + def apply(connectionTimeout: FiniteDuration = + SttpBackend.DefaultConnectionTimeout)( + implicit ec: ExecutionContext = ExecutionContext.Implicits.global) + : SttpBackend[Future, Nothing] = + OkHttpFutureBackend(OkHttpBackend.defaultClient(DefaultReadTimeout.toMillis, + connectionTimeout.toMillis), + closeClient = true) + + def usingClient(client: OkHttpClient)(implicit ec: ExecutionContext = + ExecutionContext.Implicits.global) + : SttpBackend[Future, Nothing] = + OkHttpFutureBackend(client, closeClient = false) +} |