diff options
-rw-r--r-- | build.sbt | 29 | ||||
-rw-r--r-- | okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala | 133 | ||||
-rw-r--r-- | tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala | 8 |
3 files changed, 162 insertions, 8 deletions
@@ -38,13 +38,16 @@ val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3" lazy val rootProject = (project in file(".")) .settings(commonSettings: _*) .settings(publishArtifact := false, name := "sttp") - .aggregate(core, - akkaHttpHandler, - asyncHttpClientHandler, - futureAsyncHttpClientHandler, - scalazAsyncHttpClientHandler, - monixAsyncHttpClientHandler, - tests) + .aggregate( + core, + akkaHttpHandler, + asyncHttpClientHandler, + futureAsyncHttpClientHandler, + scalazAsyncHttpClientHandler, + monixAsyncHttpClientHandler, + okhttpClientHandler, + tests + ) lazy val core: Project = (project in file("core")) .settings(commonSettings: _*) @@ -102,6 +105,16 @@ lazy val monixAsyncHttpClientHandler: Project = (project in file( ) ) dependsOn asyncHttpClientHandler +lazy val okhttpClientHandler: Project = (project in file( + "okhttp-client-handler")) + .settings(commonSettings: _*) + .settings( + name := "okhttp-client-handler", + libraryDependencies ++= Seq( + "com.squareup.okhttp3" % "okhttp" % "3.8.1" + ) + ) dependsOn core + lazy val tests: Project = (project in file("tests")) .settings(commonSettings: _*) .settings( @@ -116,4 +129,4 @@ lazy val tests: Project = (project in file("tests")) ).map(_ % "test"), libraryDependencies += "org.scala-lang" % "scala-compiler" % scalaVersion.value % "test" ) dependsOn (core, akkaHttpHandler, futureAsyncHttpClientHandler, scalazAsyncHttpClientHandler, -monixAsyncHttpClientHandler) +monixAsyncHttpClientHandler, okhttpClientHandler) diff --git a/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala b/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala new file mode 100644 index 0000000..534bd44 --- /dev/null +++ b/okhttp-client-handler/src/main/scala/com/softwaremill/sttp/okhttp/OkHttpClientHandler.scala @@ -0,0 +1,133 @@ +package com.softwaremill.sttp.okhttp + +import java.io.IOException +import java.nio.charset.Charset + +import com.softwaremill.sttp._ +import com.softwaremill.sttp.model._ +import okhttp3.internal.http.HttpMethod +import okhttp3.{ + Call, + Callback, + MediaType, + OkHttpClient, + Request => OkHttpRequest, + RequestBody => OkHttpRequestBody, + Response => OkHttpResponse +} +import okio.{BufferedSink, Okio} + +import scala.collection.JavaConverters._ +import scala.concurrent.{Future, Promise} +import scala.language.higherKinds + +abstract class OkHttpClientHandler[R[_], S](client: OkHttpClient) + extends SttpHandler[R, S] { + private[okhttp] def convertRequest[T]( + request: Request[T, S]): OkHttpRequest = { + val builder = new OkHttpRequest.Builder() + .url(request.uri.toURL) + + val body = setBody(request.body) + builder.method(request.method.m, body.getOrElse { + if (HttpMethod.requiresRequestBody(request.method.m)) + OkHttpRequestBody.create(null, "") + else null + }) + + //OkHttp support automatic gzip compression + request.headers.filter(_._1 != "Accept-Encoding").foreach { + case (name, value) => builder.addHeader(name, value) + } + + builder.build() + } + + private def setBody(requestBody: RequestBody[S]): Option[OkHttpRequestBody] = { + requestBody match { + case NoBody => None + case StringBody(b, encoding) => + Some(OkHttpRequestBody.create(MediaType.parse(encoding), b)) + case ByteArrayBody(b) => Some(OkHttpRequestBody.create(null, b)) + case ByteBufferBody(b) => Some(OkHttpRequestBody.create(null, b.array())) + case InputStreamBody(b) => + Some(new OkHttpRequestBody() { + override def writeTo(sink: BufferedSink): Unit = + sink.writeAll(Okio.source(b)) + override def contentType(): MediaType = null + }) + case PathBody(b) => Some(OkHttpRequestBody.create(null, b.toFile)) + case SerializableBody(f, t) => setBody(f(t)) + case StreamBody(s) => None + } + } + + private[okhttp] def readResponse[T]( + res: OkHttpResponse, + responseAs: ResponseAs[T, S]): Response[T] = { + val body = readResponseBody(res, responseAs) + + val headers = res + .headers() + .names() + .asScala + .flatMap(name => res.headers().values(name).asScala.map((name, _))) + Response(body, res.code(), headers.toList) + } + + private[okhttp] def readResponseBody[T](res: OkHttpResponse, + responseAs: ResponseAs[T, S]): T = { + responseAs match { + case IgnoreResponse => res.body().close() + case ResponseAsString(encoding) => + res.body().source().readString(Charset.forName(encoding)) + case ResponseAsByteArray => res.body().bytes() + case MappedResponseAs(raw, g) => g(readResponseBody(res, raw)) + case r @ ResponseAsParams(enc) => + r.parse(res.body().source().readString(Charset.forName(enc))) + case ResponseAsStream() => throw new IllegalStateException() + } + } +} + +class OkHttpSyncClientHandler(client: OkHttpClient) + extends OkHttpClientHandler[Id, Nothing](client) { + override def send[T](r: Request[T, Nothing]): Response[T] = { + val request = convertRequest(r) + val response = client.newCall(request).execute() + readResponse(response, r.responseAs) + } +} + +object OkHttpSyncClientHandler { + def apply(okhttpClient: OkHttpClient = new OkHttpClient()) + : OkHttpSyncClientHandler = + new OkHttpSyncClientHandler(okhttpClient) +} + +class OkHttpFutureClientHandler(client: OkHttpClient) + extends OkHttpClientHandler[Future, Nothing](client) { + + override def send[T](r: Request[T, Nothing]): Future[Response[T]] = { + val request = convertRequest(r) + val promise = Promise[Response[T]]() + + client + .newCall(request) + .enqueue(new Callback { + override def onFailure(call: Call, e: IOException): Unit = + promise.failure(e) + + override def onResponse(call: Call, response: OkHttpResponse): Unit = + promise.success(readResponse(response, r.responseAs)) + }) + + promise.future + } +} + +object OkHttpFutureClientHandler { + def apply(okhttpClient: OkHttpClient = new OkHttpClient()) + : OkHttpFutureClientHandler = + new OkHttpFutureClientHandler(okhttpClient) +} diff --git a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala index 90aab02..5f17a7a 100644 --- a/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala +++ b/tests/src/test/scala/com/softwaremill/sttp/BasicTests.scala @@ -19,6 +19,10 @@ import better.files._ import com.softwaremill.sttp.asynchttpclient.future.FutureAsyncHttpClientHandler import com.softwaremill.sttp.asynchttpclient.monix.MonixAsyncHttpClientHandler import com.softwaremill.sttp.asynchttpclient.scalaz.ScalazAsyncHttpClientHandler +import com.softwaremill.sttp.okhttp.{ + OkHttpFutureClientHandler, + OkHttpSyncClientHandler +} import scala.language.higherKinds @@ -126,6 +130,10 @@ class BasicTests ForceWrappedValue.scalazTask) runTests("Async Http Client - Monix")(MonixAsyncHttpClientHandler(), ForceWrappedValue.monixTask) + runTests("OkHttpSyncClientHandler")(OkHttpSyncClientHandler(), + ForceWrappedValue.id) + runTests("OkHttpSyncClientHandler - Future")(OkHttpFutureClientHandler(), + ForceWrappedValue.future) def runTests[R[_]](name: String)( implicit handler: SttpHandler[R, Nothing], |