aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/com/softwaremill/sttp/HttpURLConnectionSttpHandler.scala
blob: dd208f416efd0765acba68c871e3629c2f7a088e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package com.softwaremill.sttp

import java.io._
import java.net.{HttpURLConnection, URL}
import java.nio.channels.Channels
import java.nio.charset.CharacterCodingException
import java.nio.file.Files
import java.util.zip.{GZIPInputStream, InflaterInputStream}

import com.softwaremill.sttp.model._

import scala.annotation.tailrec
import scala.io.Source
import scala.collection.JavaConverters._

/** [[SttpHandler]] implemented on top of `java.net.HttpURLConnection`.
  *
  * `send` runs on the calling thread and returns the [[Response]] directly
  * (the effect type is `Id`). The response body is consumed completely before
  * `send` returns; streaming bodies are not supported (the stream type is
  * `Nothing`).
  */
object HttpURLConnectionSttpHandler extends SttpHandler[Id, Nothing] {

  /** Executes the request and reads the whole response.
    *
    * @param r the request to send
    * @return the response with its body decoded as described by `r.responseAs`
    * @throws IOException if the connection fails before a status code is
    *                     available, or if reading/decoding the body fails
    */
  override def send[T](r: Request[T, Nothing]): Response[T] = {
    val c =
      new URL(r.uri.toString).openConnection().asInstanceOf[HttpURLConnection]
    c.setRequestMethod(r.method.m)
    r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) }
    c.setDoInput(true)
    setBody(r.body, c)

    // Only obtaining the stream is guarded: HttpURLConnection signals HTTP
    // error statuses by throwing from getInputStream, in which case the body
    // (if any) is available from getErrorStream. Failures while reading or
    // decoding the body must propagate instead of being mistaken for an HTTP
    // error response.
    val is =
      try c.getInputStream
      catch {
        // encoding problems are never reinterpreted as HTTP errors
        case e: CharacterCodingException     => throw e
        case e: UnsupportedEncodingException => throw e
        case _: IOException if c.getResponseCode != -1 =>
          c.getErrorStream // may be null when the server sent no error body
      }

    readResponse(c, is, r.responseAs)
  }

  override def responseMonad: MonadError[Id] = IdMonad

  /** Writes the request body (if any) to the connection's output stream.
    *
    * The output stream is always closed once the body has been written,
    * also when writing fails.
    */
  private def setBody(body: RequestBody[Nothing], c: HttpURLConnection): Unit = {
    if (body != NoBody) c.setDoOutput(true)

    // Runs `write` against the connection's output stream and closes the
    // stream afterwards.
    def withOutput(write: OutputStream => Unit): Unit = {
      val out = c.getOutputStream
      try write(out)
      finally out.close()
    }

    def copyStream(in: InputStream, out: OutputStream): Unit = {
      val buf = new Array[Byte](1024)

      @tailrec
      def doCopy(): Unit = {
        val read = in.read(buf)
        if (read != -1) {
          out.write(buf, 0, read)
          doCopy()
        }
      }

      doCopy()
    }

    body match {
      case NoBody => // skip

      case StringBody(b, encoding, _) =>
        val writer = new OutputStreamWriter(c.getOutputStream, encoding)
        try writer.write(b)
        finally writer.close()

      case ByteArrayBody(b, _) =>
        withOutput(_.write(b))

      case ByteBufferBody(b, _) =>
        val channel = Channels.newChannel(c.getOutputStream)
        try channel.write(b)
        finally channel.close()

      case InputStreamBody(b, _) =>
        // the caller-supplied input stream is left open; only the
        // connection's output stream is closed here
        withOutput(copyStream(b, _))

      case PathBody(b, _) =>
        withOutput(out => Files.copy(b, out))

      case StreamBody(s) =>
        // we have an instance of nothing - everything's possible!
        s
    }
  }

  /** Builds the [[Response]] from the connection's headers, status code and
    * the given body stream.
    *
    * @param is the body stream; may be `null` (as returned by
    *           `getErrorStream` when there is no error body), in which case
    *           the body is treated as empty and no content decoding is applied
    */
  private def readResponse[T](
      c: HttpURLConnection,
      is: InputStream,
      responseAs: ResponseAs[T, Nothing]): Response[T] = {

    val headers = c.getHeaderFields.asScala.toVector
      .filter(_._1 != null) // the status line is reported under a null key
      .flatMap { case (k, vv) => vv.asScala.map((k, _)) }
    val contentEncoding = Option(c.getHeaderField(ContentEncodingHeader))

    Option(is) match {
      case None =>
        // No body at all: do not wrap in a decompressing stream, as e.g.
        // GZIPInputStream fails on empty input.
        val empty = new ByteArrayInputStream(new Array[Byte](0))
        Response(readResponseBody(empty, responseAs), c.getResponseCode, headers)

      case Some(raw) =>
        // Close both the decoding wrapper and the underlying stream so that
        // the connection and any decompressor resources are released, even
        // if wrapping or reading fails.
        try {
          val decoded = wrapInput(contentEncoding, raw)
          try Response(readResponseBody(decoded, responseAs),
                       c.getResponseCode,
                       headers)
          finally decoded.close()
        } finally raw.close()
    }
  }

  /** Reads the body from `is` according to the `responseAs` description.
    *
    * @throws IllegalStateException if a streaming response is requested,
    *                               which this handler does not support
    */
  private def readResponseBody[T](is: InputStream,
                                  responseAs: ResponseAs[T, Nothing]): T = {

    def asString(enc: String) = Source.fromInputStream(is, enc).mkString

    responseAs match {
      case MappedResponseAs(raw, g) => g(readResponseBody(is, raw))

      case IgnoreResponse =>
        // drain the stream in chunks rather than one byte at a time
        val buf = new Array[Byte](1024)
        @tailrec def consume(): Unit = if (is.read(buf) != -1) consume()
        consume()

      case ResponseAsString(enc) =>
        asString(enc)

      case ResponseAsByteArray =>
        val os = new ByteArrayOutputStream

        transfer(is, os)

        os.toByteArray

      case ResponseAsStream() =>
        // only possible when the user requests the response as a stream of
        // Nothing. Oh well ...
        throw new IllegalStateException()

      case ResponseAsFile(input, overwrite) =>
        ResponseAs.saveFile(input, is, overwrite)

    }
  }

  /** Wraps `is` in a decompressing stream matching the `Content-Encoding`
    * header value, if there is one.
    *
    * @throws UnsupportedEncodingException for encodings other than `gzip`
    *                                      and `deflate`
    */
  private def wrapInput(contentEncoding: Option[String],
                        is: InputStream): InputStream =
    // Locale.ROOT keeps the comparison independent of the default locale
    // (e.g. "GZIP" lower-cases to "gzıp" under a Turkish locale).
    contentEncoding.map(_.toLowerCase(java.util.Locale.ROOT)) match {
      case None            => is
      case Some("gzip")    => new GZIPInputStream(is)
      case Some("deflate") => new InflaterInputStream(is)
      case Some(ce) =>
        throw new UnsupportedEncodingException(s"Unsupported encoding: $ce")
    }
}