package com.softwaremill.sttp
import java.io.{ByteArrayInputStream, IOException}
import java.nio.ByteBuffer
import java.nio.file.Paths
import java.time.{ZoneId, ZonedDateTime}
import akka.http.scaladsl.coding.{Deflate, Gzip, NoCoding}
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.CacheDirectives._
import akka.http.scaladsl.model.headers._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.server.directives.Credentials
import akka.util.ByteString
import better.files._
import com.softwaremill.sttp.akkahttp.AkkaHttpBackend
import com.softwaremill.sttp.asynchttpclient.cats.AsyncHttpClientCatsBackend
import com.softwaremill.sttp.asynchttpclient.future.AsyncHttpClientFutureBackend
import com.softwaremill.sttp.asynchttpclient.monix.AsyncHttpClientMonixBackend
import com.softwaremill.sttp.asynchttpclient.scalaz.AsyncHttpClientScalazBackend
import com.softwaremill.sttp.impl.cats.convertCatsIOToFuture
import com.softwaremill.sttp.impl.monix.convertMonixTaskToFuture
import com.softwaremill.sttp.impl.scalaz.convertScalazTaskToFuture
import com.softwaremill.sttp.okhttp.monix.OkHttpMonixBackend
import com.softwaremill.sttp.okhttp.{OkHttpFutureBackend, OkHttpSyncBackend}
import com.softwaremill.sttp.testing.streaming.ConvertToFuture
import com.typesafe.scalalogging.StrictLogging
import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures}
import org.scalatest.{path => _, _}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.language.higherKinds
/**
 * Runs one shared battery of HTTP client tests against every sttp backend registered in the class body,
 * using the routes defined in `serverRoutes` as the server side.
 *
 * NOTE(review): `endpoint`, `actorSystem` and `.force()` are used below but not defined in this file;
 * presumably they come from the mixed-in `TestHttpServer` / `ForceWrapped` traits - confirm there.
 */
class BasicTests
extends FlatSpec
with Matchers
with BeforeAndAfterAll
with ScalaFutures
with OptionValues
with StrictLogging
with IntegrationPatience
with TestHttpServer
with ForceWrapped
with BeforeAndAfterEach {
/**
 * Removes the download output location (`outPath`) after every test so that files written by one
 * test cannot influence the next one.
 *
 * `super.afterEach()` is invoked first, inside `try`, so that this trait stays stackable with any
 * other `BeforeAndAfterEach` mix-in; the cleanup runs in `finally` regardless of its outcome.
 */
override def afterEach(): Unit = {
  try super.afterEach()
  finally {
    val file = File(outPath)
    if (file.exists) file.delete()
  }
}
/**
 * Renders a parameter map as space-separated `key=value` pairs, ordered by key.
 *
 * @param m parameters to render
 * @return e.g. `"p1=v1 p2=v2"`; the empty string for an empty map
 */
private def paramsToString(m: Map[String, String]): String = {
  val orderedByKey = m.toList.sortBy { case (key, _) => key }
  val rendered = orderedByKey.map { case (key, value) => s"$key=$value" }
  rendered.mkString(" ")
}
// Text fixture served by the /download/text route and used as the expected content of downloads.
private val textFile: java.io.File =
  new java.io.File("tests/src/test/resources/textfile.txt")

// Binary fixture served by the /download/binary route and used as the expected content of downloads.
private val binaryFile: java.io.File =
  new java.io.File("tests/src/test/resources/binaryfile.jpg")

// Location the download tests write into; removed again in afterEach.
private val outPath: java.nio.file.Path = Paths.get("out")

// Body returned by the /respond_with_iso_8859_2 route.
private val textWithSpecialCharacters: String = "Żółć!"
/**
 * Routes served to the backends under test. Each top-level path backs one group of tests
 * registered by `runTests`.
 */
override val serverRoutes: Route =
pathPrefix("echo") {
// /echo/form_params/...: echoes the submitted form fields, either rendered by paramsToString
// (as_string) or re-encoded as form data (as_params).
pathPrefix("form_params") {
formFieldMap { params =>
path("as_string") {
complete(paramsToString(params))
} ~
path("as_params") {
complete(FormData(params))
}
}
} ~ get {
// GET under /echo: answers "GET /echo <query params>"; empty parts are dropped before joining.
parameterMap { params =>
complete(List("GET", "/echo", paramsToString(params))
.filter(_.nonEmpty)
.mkString(" "))
}
} ~
post {
// POST under /echo: same as GET, with the request body appended as the last part.
parameterMap { params =>
entity(as[String]) { body: String =>
complete(List("POST", "/echo", paramsToString(params), body)
.filter(_.nonEmpty)
.mkString(" "))
}
}
}
} ~ path("set_headers") {
// GET only: answers "ok" with two Cache-Control headers (max-age=1000 and no-cache).
get {
respondWithHeader(`Cache-Control`(`max-age`(1000L))) {
respondWithHeader(`Cache-Control`(`no-cache`)) {
complete("ok")
}
}
}
} ~ pathPrefix("set_cookies") {
// with_expires: sets a single cookie c=v carrying an expires attribute of 1997-12-08 12:49:12.
path("with_expires") {
setCookie(HttpCookie("c", "v", expires = Some(DateTime(1997, 12, 8, 12, 49, 12)))) {
complete("ok")
}
} ~ get {
// Any other GET under /set_cookies: sets three cookies with differing attributes.
setCookie(HttpCookie("cookie1", "value1", secure = true, httpOnly = true, maxAge = Some(123L))) {
setCookie(HttpCookie("cookie2", "value2")) {
setCookie(HttpCookie("cookie3", "", domain = Some("xyz"), path = Some("a/b/c"))) {
complete("ok")
}
}
}
}
} ~ path("secure_basic") {
// HTTP basic authentication in realm "test realm"; only user "adam" with password "1234" is accepted.
authenticateBasic("test realm", {
case c @ Credentials.Provided(un) if un == "adam" && c.verify("1234") =>
Some(un)
case _ => None
}) { userName =>
complete(s"Hello, $userName!")
}
} ~ path("compress") {
// Response is encoded with one of Gzip, Deflate or NoCoding.
// NOTE(review): presumably chosen from the request's Accept-Encoding header - confirm in akka-http docs.
encodeResponseWith(Gzip, Deflate, NoCoding) {
complete("I'm compressed!")
}
} ~ pathPrefix("download") {
// Serves the binary and text fixture files.
path("binary") {
getFromFile(binaryFile)
} ~ path("text") {
getFromFile(textFile)
}
} ~ pathPrefix("multipart") {
// Reads the multipart form one part at a time (mapAsync(1)) and answers with
// "name=value" per part, plus " (filename)" when a filename is present, joined by ", ".
entity(as[akka.http.scaladsl.model.Multipart.FormData]) { fd =>
complete {
fd.parts
.mapAsync(1) { p =>
val fv = p.entity.dataBytes.runFold(ByteString())(_ ++ _)
fv.map(_.utf8String)
.map(v => p.name + "=" + v + p.filename.fold("")(fn => s" ($fn)"))
}
.runFold(Vector.empty[String])(_ :+ _)
.map(v => v.mkString(", "))
}
}
} ~ pathPrefix("redirect") {
// Redirect chain: r1 -(307)-> r2 -(308)-> r3 -(302)-> r4, which answers "819".
// "loop" redirects to itself with a 302.
path("r1") {
redirect("/redirect/r2", StatusCodes.TemporaryRedirect)
} ~
path("r2") {
redirect("/redirect/r3", StatusCodes.PermanentRedirect)
} ~
path("r3") {
redirect("/redirect/r4", StatusCodes.Found)
} ~
path("r4") {
complete("819")
} ~
path("loop") {
redirect("/redirect/loop", StatusCodes.Found)
}
} ~ pathPrefix("timeout") {
// Completes with "Done" only after a 1 second delay, so read timeouts can be exercised.
complete {
akka.pattern.after(1.second, using = actorSystem.scheduler)(Future.successful("Done"))
}
} ~ path("empty_unauthorized_response") {
// POST only: answers 401 with no headers and an empty entity.
post {
import akka.http.scaladsl.model._
complete(
HttpResponse(
status = StatusCodes.Unauthorized,
headers = Nil,
entity = HttpEntity.Empty,
protocol = HttpProtocols.`HTTP/1.1`
))
}
} ~ path("respond_with_iso_8859_2") {
// GET only: answers textWithSpecialCharacters as text/plain declared with the ISO-8859-2 charset.
get { ctx =>
val entity =
HttpEntity(MediaTypes.`text/plain`.withCharset(HttpCharset.custom("ISO-8859-2")), textWithSpecialCharacters)
ctx.complete(HttpResponse(200, entity = entity))
}
}
// Fixed port for the test server.
// NOTE(review): presumably read by TestHttpServer when binding - confirm there.
override def port = 51823
// Close callbacks, one per backend: prepended by runTests and invoked in afterAll.
// Must be initialised before the runTests calls below, which mutate it during class construction.
var closeBackends: List[() => Unit] = Nil
// Register the same tests once per backend; the second argument is the ConvertToFuture instance
// matching that backend's response wrapper type.
runTests("HttpURLConnection")(HttpURLConnectionBackend(), ConvertToFuture.id)
runTests("TryHttpURLConnection")(TryHttpURLConnectionBackend(), ConvertToFuture.scalaTry)
runTests("Akka HTTP")(AkkaHttpBackend.usingActorSystem(actorSystem), ConvertToFuture.future)
runTests("Async Http Client - Future")(AsyncHttpClientFutureBackend(), ConvertToFuture.future)
runTests("Async Http Client - Scalaz")(AsyncHttpClientScalazBackend(), convertScalazTaskToFuture)
runTests("Async Http Client - Monix")(AsyncHttpClientMonixBackend(), convertMonixTaskToFuture)
runTests("Async Http Client - Cats IO")(AsyncHttpClientCatsBackend[cats.effect.IO](), convertCatsIOToFuture)
runTests("OkHttpSyncClientHandler")(OkHttpSyncBackend(), ConvertToFuture.id)
runTests("OkHttpAsyncClientHandler - Future")(OkHttpFutureBackend(), ConvertToFuture.future)
runTests("OkHttpAsyncClientHandler - Monix")(OkHttpMonixBackend(), convertMonixTaskToFuture)
/**
 * Registers the full set of tests for a single backend.
 *
 * @param name            label used as the subject of every registered test (`name should "..."`)
 * @param backend         backend under test; its `close()` is queued in `closeBackends`
 * @param convertToFuture conversion for the backend's wrapper type `R`
 *                        (NOTE(review): presumably consumed implicitly by `.force()` - confirm in ForceWrapped)
 * @tparam R wrapper type in which the backend returns its responses
 */
def runTests[R[_]](name: String)(implicit
backend: SttpBackend[R, Nothing],
convertToFuture: ConvertToFuture[R]): Unit = {
// Queue the backend for closing in afterAll.
closeBackends = (() => backend.close()) :: closeBackends
// Fixtures shared by the test groups below.
val postEcho = sttp.post(uri"$endpoint/echo")
val testBody = "this is the body"
val testBodyBytes = testBody.getBytes("UTF-8")
val expectedPostEchoResponse = "POST /echo this is the body"
// Local alias for sttp's `ignore` response handler.
// NOTE(review): presumably needed because ScalaTest also brings an `ignore` into scope - confirm.
val sttpIgnore = com.softwaremill.sttp.ignore
// Register every test group; the groups are defined as nested methods further down.
parseResponseTests()
parameterTests()
bodyTests()
headerTests()
errorsTests()
cookiesTests()
authTests()
compressionTests()
downloadFileTests()
multipartTests()
redirectTests()
timeoutTests()
emptyResponseTests()
encodingTests()
/** Registers tests for reading the response body as a string, a mapped value, bytes and form parameters. */
def parseResponseTests(): Unit = {
  name should "parse response as string" in {
    val echoed = postEcho.body(testBody).send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse
  }

  name should "parse response as string with mapping using map" in {
    val request = postEcho.body(testBody).response(asString.map(_.length))
    val echoed = request.send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse.length
  }

  name should "parse response as string with mapping using mapResponse" in {
    val request = postEcho.body(testBody).mapResponse(_.length)
    val echoed = request.send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse.length
  }

  name should "parse response as a byte array" in {
    val request = postEcho.body(testBody).response(asByteArray)
    val echoed = request.send().force()
    // Decode explicitly so the comparison is made on text rather than on array identity.
    val decoded = new String(echoed.unsafeBody, "UTF-8")
    decoded shouldBe expectedPostEchoResponse
  }

  name should "parse response as parameters" in {
    // The last pair contains characters that must survive form encoding and decoding.
    val formFields = List("a" -> "b", "c" -> "d", "e=" -> "&f")
    val request = sttp
      .post(uri"$endpoint/echo/form_params/as_params")
      .body(formFields: _*)
      .response(asParams)
    val echoed = request.send().force()
    echoed.unsafeBody.toList shouldBe formFields
  }
}
/** Registers the test for sending query parameters with a GET request. */
def parameterTests(): Unit = {
  name should "make a get request with parameters" in {
    val request = sttp.get(uri"$endpoint/echo?p2=v2&p1=v1")
    val echoed = request.send().force()
    // The echo route renders parameters sorted by key, hence p1 before p2.
    echoed.unsafeBody shouldBe "GET /echo p1=v1 p2=v2"
  }
}
/** Registers tests for every supported request body type, all posted to the echo route. */
def bodyTests(): Unit = {
  // Runs `test` against a temporary file holding `testBody`; the file is deleted afterwards in all cases.
  def withBodyFile[T](test: File => T): T = {
    val tmp = File.newTemporaryFile().write(testBody)
    try test(tmp)
    finally tmp.delete()
  }

  name should "post a string" in {
    val echoed = postEcho.body(testBody).send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse
  }

  name should "post a byte array" in {
    val echoed = postEcho.body(testBodyBytes).send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse
  }

  name should "post an input stream" in {
    val stream = new ByteArrayInputStream(testBodyBytes)
    val echoed = postEcho.body(stream).send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse
  }

  name should "post a byte buffer" in {
    val buffer = ByteBuffer.wrap(testBodyBytes)
    val echoed = postEcho.body(buffer).send().force()
    echoed.unsafeBody shouldBe expectedPostEchoResponse
  }

  name should "post a file" in {
    withBodyFile { tmp =>
      val echoed = postEcho.body(tmp.toJava).send().force()
      echoed.unsafeBody shouldBe expectedPostEchoResponse
    }
  }

  name should "post a path" in {
    withBodyFile { tmp =>
      val echoed = postEcho.body(tmp.toJava.toPath).send().force()
      echoed.unsafeBody shouldBe expectedPostEchoResponse
    }
  }

  name should "post form data" in {
    val request = sttp
      .post(uri"$endpoint/echo/form_params/as_string")
      .body("a" -> "b", "c" -> "d")
    request.send().force().unsafeBody shouldBe "a=b c=d"
  }

  name should "post form data with special characters" in {
    val request = sttp
      .post(uri"$endpoint/echo/form_params/as_string")
      .body("a=" -> "/b", "c:" -> "/d")
    request.send().force().unsafeBody shouldBe "a==/b c:=/d"
  }

  name should "post without a body" in {
    val echoed = postEcho.send().force()
    echoed.unsafeBody shouldBe "POST /echo"
  }
}
/**
 * Registers the test for reading response headers from the `/set_headers` route.
 *
 * Header presence is checked with the type-checked `defined` matcher and `OptionValues.value`
 * rather than reflection-based symbol matchers and `Option.get`, so a missing header is reported
 * as a test failure and not as a `NoSuchElementException`.
 */
def headerTests(): Unit = {
  val getHeaders = sttp.get(uri"$endpoint/set_headers")
  name should "read response headers" in {
    val response = getHeaders.response(sttpIgnore).send().force()
    response.headers should have length (6)
    // The route sets Cache-Control twice; both values must be visible.
    response.headers("Cache-Control").toSet should be(Set("no-cache", "max-age=1000"))
    // Header lookup is expected to succeed regardless of the case of the name.
    response.header("Server") shouldBe defined
    response.header("server") shouldBe defined
    response.header("Server").value should startWith("akka-http")
    response.contentType should be(Some("text/plain; charset=UTF-8"))
    response.contentLength should be(Some(2L))
  }
}
/** Registers the test for error responses: a POST to a route that only accepts GET. */
def errorsTests(): Unit = {
  // /set_headers only handles GET, so this POST is expected to be rejected.
  val postToGetOnlyRoute = sttp.post(uri"$endpoint/set_headers")

  name should "return 405 when method not allowed" in {
    val rejected = postToGetOnlyRoute.response(sttpIgnore).send().force()
    rejected.code shouldBe 405
    rejected.isClientError shouldBe true
    rejected.body should be('left)
  }
}
/** Registers tests for parsing `Set-Cookie` response headers, including the `expires` attribute. */
def cookiesTests(): Unit = {
  name should "read response cookies" in {
    // Mirrors the three cookies set by the /set_cookies route.
    val expectedCookies = Set(
      Cookie("cookie1", "value1", secure = true, httpOnly = true, maxAge = Some(123L)),
      Cookie("cookie2", "value2"),
      Cookie("cookie3", "", domain = Some("xyz"), path = Some("a/b/c"))
    )
    val request = sttp.get(uri"$endpoint/set_cookies").response(sttpIgnore)
    val response = request.send().force()
    response.cookies should have length (3)
    response.cookies.toSet shouldBe expectedCookies
  }

  name should "read response cookies with the expires attribute" in {
    // Same instant the route puts in the cookie, expressed as epoch milliseconds in GMT.
    val expectedExpiryMillis = ZonedDateTime
      .of(1997, 12, 8, 12, 49, 12, 0, ZoneId.of("GMT"))
      .toInstant
      .toEpochMilli
    val request = sttp.get(uri"$endpoint/set_cookies/with_expires").response(sttpIgnore)
    val response = request.send().force()
    response.cookies should have length (1)
    val cookie = response.cookies(0)
    cookie.name shouldBe "c"
    cookie.value shouldBe "v"
    cookie.expires.map(_.toInstant.toEpochMilli) shouldBe Some(expectedExpiryMillis)
  }
}
/** Registers tests for HTTP basic authentication against the `/secure_basic` route. */
def authTests(): Unit = {
  val secureBasic = sttp.get(uri"$endpoint/secure_basic")

  name should "return a 401 when authorization fails" in {
    // No credentials are sent at all.
    val denied = secureBasic.send().force()
    denied.code shouldBe 401
    denied.header("WWW-Authenticate") shouldBe Some("""Basic realm="test realm",charset=UTF-8""")
  }

  name should "perform basic authorization" in {
    val granted = secureBasic.auth.basic("adam", "1234").send().force()
    granted.code shouldBe 200
    granted.unsafeBody shouldBe "Hello, adam!"
  }
}
/** Registers tests for transparent response decompression under different `Accept-Encoding` values. */
def compressionTests(): Unit = {
  val compress = sttp.get(uri"$endpoint/compress")
  val decompressedBody = "I'm compressed!"

  // Fetches /compress with the Accept-Encoding header replaced by the given value.
  def bodyWhenAccepting(encoding: String) =
    compress
      .header("Accept-Encoding", encoding, replaceExisting = true)
      .send()
      .force()
      .unsafeBody

  name should "decompress using the default accept encoding header" in {
    compress.send().force().unsafeBody shouldBe decompressedBody
  }

  name should "decompress using gzip" in {
    bodyWhenAccepting("gzip") shouldBe decompressedBody
  }

  name should "decompress using deflate" in {
    bodyWhenAccepting("deflate") shouldBe decompressedBody
  }

  name should "work despite providing an unsupported encoding" in {
    bodyWhenAccepting("br") shouldBe decompressedBody
  }
}
/** Registers tests for saving response bodies to a file or path, including the overwrite rules. */
def downloadFileTests(): Unit = {
  import CustomMatchers._

  // Requests are built lazily, i.e. when a test runs rather than when it is registered.
  def getBinary = sttp.get(uri"$endpoint/download/binary")
  def getText = sttp.get(uri"$endpoint/download/text")

  // Creates an empty file at the text download target so that the download has something to collide with.
  def existingTextTarget() = {
    val target = outPath.resolve("textfile.txt")
    target.toFile.getParentFile.mkdirs()
    target.toFile.createNewFile()
    target
  }

  name should "download a binary file using asFile" in {
    val file = outPath.resolve("binaryfile.jpg").toFile
    val resp = getBinary.response(asFile(file)).send().force()
    resp.unsafeBody shouldBe file
    file should exist
    file should haveSameContentAs(binaryFile)
  }

  name should "download a text file using asFile" in {
    val file = outPath.resolve("textfile.txt").toFile
    val resp = getText.response(asFile(file)).send().force()
    resp.unsafeBody shouldBe file
    file should exist
    file should haveSameContentAs(textFile)
  }

  name should "download a binary file using asPath" in {
    val path = outPath.resolve("binaryfile.jpg")
    val resp = getBinary.response(asPath(path)).send().force()
    resp.unsafeBody shouldBe path
    path.toFile should exist
    path.toFile should haveSameContentAs(binaryFile)
  }

  name should "download a text file using asPath" in {
    val path = outPath.resolve("textfile.txt")
    val resp = getText.response(asPath(path)).send().force()
    resp.unsafeBody shouldBe path
    path.toFile should exist
    path.toFile should haveSameContentAs(textFile)
  }

  name should "fail at trying to save file to a restricted location" in {
    // Targets the filesystem root, which the test expects not to be writable.
    val path = Paths.get("/").resolve("textfile.txt")
    val req = getText.response(asPath(path))
    val caught = intercept[IOException] {
      req.send().force()
    }
    caught.getMessage shouldBe "Permission denied"
  }

  name should "fail when file exists and overwrite flag is false" in {
    val path = existingTextTarget()
    val req = getText.response(asPath(path))
    val caught = intercept[IOException] {
      req.send().force()
    }
    caught.getMessage shouldBe s"File ${path.toFile.getAbsolutePath} exists - overwriting prohibited"
  }

  name should "not fail when file exists and overwrite flag is true" in {
    val path = existingTextTarget()
    val resp = getText.response(asPath(path, overwrite = true)).send().force()
    resp.unsafeBody shouldBe path
    path.toFile should exist
    path.toFile should haveSameContentAs(textFile)
  }
}
/** Registers tests for multipart request bodies; the server echoes each part as `name=value (filename)`. */
def multipartTests(): Unit = {
  val mp = sttp.post(uri"$endpoint/multipart")

  name should "send a multipart message" in {
    val parts = List(multipart("p1", "v1"), multipart("p2", "v2"))
    val resp = mp.multipartBody(parts: _*).send().force()
    resp.unsafeBody shouldBe "p1=v1, p2=v2"
  }

  name should "send a multipart message with filenames" in {
    val parts = List(multipart("p1", "v1").fileName("f1"), multipart("p2", "v2").fileName("f2"))
    val resp = mp.multipartBody(parts: _*).send().force()
    resp.unsafeBody shouldBe "p1=v1 (f1), p2=v2 (f2)"
  }

  name should "send a multipart message with a file" in {
    val tmp = File.newTemporaryFile().write(testBody)
    try {
      val resp = mp
        .multipartBody(multipart("p1", tmp.toJava), multipart("p2", "v2"))
        .send()
        .force()
      // The file part is expected to carry the file's own name as its filename.
      resp.unsafeBody shouldBe s"p1=$testBody (${tmp.name}), p2=v2"
    } finally tmp.delete()
  }
}
/**
 * Registers tests for redirect handling. The server chain is r1 -(307)-> r2 -(308)-> r3 -(302)-> r4,
 * and `loop` redirects to itself.
 */
def redirectTests(): Unit = {
  val r1 = sttp.post(uri"$endpoint/redirect/r1")
  val r2 = sttp.post(uri"$endpoint/redirect/r2")
  val r3 = sttp.post(uri"$endpoint/redirect/r3")
  val r4response = "819"
  val loop = sttp.post(uri"$endpoint/redirect/loop")

  name should "not redirect when redirects shouldn't be followed (temporary)" in {
    val notFollowed = r1.followRedirects(false).send().force()
    notFollowed.code shouldBe 307
    notFollowed.body should be('left)
    notFollowed.history should be('empty)
  }

  name should "not redirect when redirects shouldn't be followed (permanent)" in {
    val notFollowed = r2.followRedirects(false).send().force()
    notFollowed.code shouldBe 308
    notFollowed.body should be('left)
  }

  name should "redirect when redirects should be followed" in {
    val followed = r2.send().force()
    followed.code shouldBe 200
    followed.unsafeBody shouldBe r4response
  }

  name should "redirect twice when redirects should be followed" in {
    val followed = r1.send().force()
    followed.code shouldBe 200
    followed.unsafeBody shouldBe r4response
  }

  name should "redirect when redirects should be followed, and the response is parsed" in {
    val followed = r2.response(asString.map(_.toInt)).send().force()
    followed.code shouldBe 200
    followed.unsafeBody shouldBe r4response.toInt
  }

  name should "keep a single history entry of redirect responses" in {
    val followed = r3.send().force()
    followed.code shouldBe 200
    followed.unsafeBody shouldBe r4response
    followed.history should have size (1)
    followed.history(0).code shouldBe 302
  }

  name should "keep whole history of redirect responses" in {
    val followed = r1.send().force()
    followed.code shouldBe 200
    followed.unsafeBody shouldBe r4response
    // One history entry per hop, in the order the hops were taken.
    followed.history should have size (3)
    followed.history(0).code shouldBe 307
    followed.history(1).code shouldBe 308
    followed.history(2).code shouldBe 302
  }

  name should "break redirect loops" in {
    val aborted = loop.send().force()
    aborted.code shouldBe 0
    aborted.history should have size (FollowRedirectsBackend.MaxRedirects)
  }

  name should "break redirect loops after user-specified count" in {
    val maxRedirects = 10
    val aborted = loop.maxRedirects(maxRedirects).send().force()
    aborted.code shouldBe 0
    aborted.history should have size (maxRedirects)
  }

  name should "not redirect when maxRedirects is less than or equal to 0" in {
    val notFollowed = loop.maxRedirects(-1).send().force()
    notFollowed.code shouldBe 302
    notFollowed.body should be('left)
    notFollowed.history should be('empty)
  }
}
/** Registers tests for the read timeout, using the `/timeout` route which answers after one second. */
def timeoutTests(): Unit = {
  // Builds a request to the delayed route with the given read timeout.
  def delayedRequest(limit: Duration) =
    sttp
      .get(uri"$endpoint/timeout")
      .readTimeout(limit)
      .response(asString)

  name should "fail if read timeout is not big enough" in {
    val request = delayedRequest(200.milliseconds)
    intercept[Throwable] {
      request.send().force()
    }
  }

  name should "not fail if read timeout is big enough" in {
    val request = delayedRequest(5.seconds)
    request.send().force().unsafeBody shouldBe "Done"
  }
}
/** Registers the test for an error response that carries no body at all. */
def emptyResponseTests(): Unit = {
  val postEmptyResponse = sttp
    .post(uri"$endpoint/empty_unauthorized_response")
    .body("{}")
    .contentType("application/json")

  name should "parse an empty error response as empty string" in {
    val unauthorized = postEmptyResponse.send().force()
    unauthorized.body shouldBe Left("")
  }
}
/** Registers the test for decoding a response body using the charset declared by the server. */
def encodingTests(): Unit = {
  name should "read response body encoded using ISO-8859-2, as specified in the header, overriding the default" in {
    val request = sttp.get(uri"$endpoint/respond_with_iso_8859_2")
    val decoded = request.send().force().unsafeBody
    decoded shouldBe textWithSpecialCharacters
  }
}
}
/**
 * Closes every backend registered by `runTests`, then delegates to `super.afterAll()`.
 *
 * The delegation happens in `finally` so that the parent traits' teardown still runs when closing
 * a backend throws; otherwise whatever they release would be leaked.
 */
override protected def afterAll(): Unit = {
  try closeBackends.foreach(_())
  finally super.afterAll()
}
}