From e6f0ac0289ad3685e2af5dfc17ee79c3c1170bdf Mon Sep 17 00:00:00 2001 From: Piotr Buda Date: Mon, 31 Jul 2017 11:59:08 +0200 Subject: #7: Add asFile and asPath responses --- .../sttp/HttpURLConnectionSttpHandler.scala | 17 ++++--------- .../com/softwaremill/sttp/model/ResponseAs.scala | 24 ++++++++++++++++++- .../main/scala/com/softwaremill/sttp/package.scala | 28 ++++++++++++++++++++-- 3 files changed, 54 insertions(+), 15 deletions(-) (limited to 'core') diff --git a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala index d5c2ccd..29da886 100644 --- a/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala +++ b/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala @@ -115,19 +115,8 @@ object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] { case ResponseAsByteArray => val os = new ByteArrayOutputStream - var read = 0 - val buf = new Array[Byte](1024) - - @tailrec - def transfer(): Unit = { - read = is.read(buf, 0, buf.length) - if (read != -1) { - os.write(buf, 0, read) - transfer() - } - } - transfer() + transfer(is, os) os.toByteArray @@ -135,6 +124,10 @@ object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] { // only possible when the user requests the response as a stream of // Nothing. Oh well ... throw new IllegalStateException() + + case ResponseAsFile(input, overwrite) => + ResponseAs.saveFile(input, is, overwrite) + } } diff --git a/core/src/main/scala/com/softwaremill/sttp/model/ResponseAs.scala b/core/src/main/scala/com/softwaremill/sttp/model/ResponseAs.scala index 3862483..24b9b2b 100644 --- a/core/src/main/scala/com/softwaremill/sttp/model/ResponseAs.scala +++ b/core/src/main/scala/com/softwaremill/sttp/model/ResponseAs.scala @@ -1,8 +1,10 @@ package com.softwaremill.sttp.model +import java.io.{File, FileOutputStream, IOException, InputStream} import java.net.URLDecoder +import java.nio.file.Path -import com.softwaremill.sttp.MonadError +import com.softwaremill.sttp.{MonadError, transfer} import scala.collection.immutable.Seq import scala.language.higherKinds @@ -38,6 +40,9 @@ case class MappedResponseAs[T, T2, S](raw: BasicResponseAs[T, S], g: T => T2) MappedResponseAs[T, T3, S](raw, g andThen f) } +case class ResponseAsFile(input: File, overwrite: Boolean) + extends BasicResponseAs[File, Nothing] + object ResponseAs { private[sttp] def parseParams(s: String, encoding: String): Seq[(String, String)] = { @@ -52,6 +57,23 @@ object ResponseAs { }) } + private[sttp] def saveFile(file: File, + is: InputStream, + overwrite: Boolean): File = { + if (!file.exists()) { + file.getParentFile.mkdirs() + file.createNewFile() + } else if (!overwrite) { + throw new IOException( + s"File ${file.getAbsolutePath} exists - overwriting prohibited") + } + + val os = new FileOutputStream(file) + + transfer(is, os) + file + } + /** * Handles responses according to the given specification when basic * response specifications can be handled eagerly, that is without diff --git a/core/src/main/scala/com/softwaremill/sttp/package.scala b/core/src/main/scala/com/softwaremill/sttp/package.scala index daba574..884d2f9 100644 --- a/core/src/main/scala/com/softwaremill/sttp/package.scala +++ b/core/src/main/scala/com/softwaremill/sttp/package.scala @@ -1,12 +1,12 @@ package com.softwaremill -import java.io.{File, InputStream} +import java.io._ import java.nio.ByteBuffer import java.nio.file.Path import com.softwaremill.sttp.model._ -import scala.annotation.implicitNotFound +import scala.annotation.{implicitNotFound, tailrec} import scala.language.higherKinds import scala.collection.immutable.Seq @@ -88,6 +88,14 @@ package object sttp { def asStream[S]: ResponseAs[S, S] = ResponseAsStream[S, S]() + def asFile(file: File, + overwrite: Boolean = false): ResponseAs[File, Nothing] = + ResponseAsFile(file, overwrite) + + def asPath(path: Path, + overwrite: Boolean = false): ResponseAs[Path, Nothing] = + ResponseAsFile(path.toFile, overwrite).map(_.toPath) + // multi part factory methods /** @@ -159,6 +167,22 @@ package object sttp { private[sttp] def contentTypeWithEncoding(ct: String, enc: String) = s"$ct; charset=$enc" + private[sttp] def transfer(is: InputStream, os: OutputStream) { + var read = 0 + val buf = new Array[Byte](1024) + + @tailrec + def transfer(): Unit = { + read = is.read(buf, 0, buf.length) + if (read != -1) { + os.write(buf, 0, read) + transfer() + } + } + + transfer() + } + // uri interpolator implicit class UriContext(val sc: StringContext) extends AnyVal { -- cgit v1.2.3