aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala
blob: 50f6d7613480ba2f568645f867f891ad702b5d6b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package com.softwaremill.sttp

import java.io._
import java.net.HttpURLConnection
import java.nio.channels.Channels
import java.nio.file.Files

import com.softwaremill.sttp.model._

import scala.annotation.tailrec
import scala.io.Source
import scala.collection.JavaConverters._

/**
  * [[SttpHandler]] backed by `java.net.HttpURLConnection`.
  *
  * `send` returns the [[Response]] directly (the effect type is `Id`) after the
  * whole response body has been read. Streaming response bodies are not
  * supported (the stream type is `Nothing`); requesting one throws an
  * `IllegalStateException`.
  */
object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] {

  /** Size of the scratch buffer used when copying request/response bodies. */
  private val TransferBufferSize = 8192

  /**
    * Sends `r` over a fresh `HttpURLConnection` and reads the response body as
    * described by `responseAs`.
    *
    * If opening the response stream fails with an `IOException` but the
    * connection still reports a status code (`getResponseCode != -1`, e.g. for
    * an error status), the body is read from the connection's error stream
    * instead, so such responses are returned rather than thrown. The stream
    * that was read is always closed afterwards.
    *
    * @throws IOException if the response stream cannot be opened and no status
    *                     code is available, or if reading the body fails
    */
  override def send[T](r: Request,
                       responseAs: ResponseAs[T, Nothing]): Response[T] = {
    val c = r.uri.toURL.openConnection().asInstanceOf[HttpURLConnection]
    c.setRequestMethod(r.method.m)
    r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) }
    c.setDoInput(true)
    setBody(r.body, c)

    // Only a failure to *open* the response stream falls back to the error
    // stream. An IOException while reading a successful body must propagate
    // instead of being retried against an unrelated (possibly null) stream.
    val is: InputStream =
      try c.getInputStream
      catch {
        case _: IOException if c.getResponseCode != -1 =>
          // getErrorStream returns null when the server sent no error body;
          // treat that as an empty body instead of passing null downstream.
          Option(c.getErrorStream)
            .getOrElse(new ByteArrayInputStream(Array.emptyByteArray))
      }

    try readResponse(c, is, responseAs)
    finally is.close()
  }

  /**
    * Writes `body` to the request output of `c`.
    *
    * Output is enabled on the connection unless `body` is [[NoBody]]. Every
    * branch that opens the connection's output stream closes it once the body
    * has been written. A user-supplied `InputStreamBody` stream is read to the
    * end but not closed here.
    */
  private def setBody(body: RequestBody, c: HttpURLConnection): Unit = {
    if (body != NoBody) c.setDoOutput(true)

    // Opens the connection's output stream, passes it to `write`, and always
    // closes it afterwards.
    def withOutputStream(write: OutputStream => Unit): Unit = {
      val os = c.getOutputStream
      try write(os)
      finally os.close()
    }

    body match {
      case NoBody => // skip

      case StringBody(b, encoding) =>
        // closing the writer also closes the underlying output stream
        val writer = new OutputStreamWriter(c.getOutputStream, encoding)
        try writer.write(b)
        finally writer.close()

      case ByteArrayBody(b) =>
        withOutputStream(os => os.write(b))

      case ByteBufferBody(b) =>
        // closing the channel also closes the underlying output stream
        val channel = Channels.newChannel(c.getOutputStream)
        try channel.write(b)
        finally channel.close()

      case InputStreamBody(b) =>
        withOutputStream(os => transfer(b, os))

      case PathBody(b) =>
        withOutputStream { os =>
          Files.copy(b, os)
          ()
        }

      case SerializableBody(f, t) =>
        setBody(f(t), c)
    }
  }

  /** Copies every remaining byte of `in` to `out`. Neither stream is closed. */
  private def transfer(in: InputStream, out: OutputStream): Unit = {
    val buf = new Array[Byte](TransferBufferSize)

    @tailrec
    def loop(): Unit = {
      val read = in.read(buf)
      if (read != -1) {
        out.write(buf, 0, read)
        loop()
      }
    }

    loop()
  }

  /**
    * Builds a [[Response]] from the body read from `is` plus the status code
    * and headers of `c`.
    *
    * Header entries whose name is null are dropped. HttpURLConnection may
    * report the status line under a null key. Multi-valued headers become one
    * `(name, value)` pair per value.
    */
  private def readResponse[T](
      c: HttpURLConnection,
      is: InputStream,
      responseAs: ResponseAs[T, Nothing]): Response[T] = {

    val headers = c.getHeaderFields.asScala.toVector
      .filter { case (k, _) => k != null }
      .flatMap { case (k, vv) => vv.asScala.map((k, _)) }
    Response(readResponseBody(is, responseAs), c.getResponseCode, headers)
  }

  /**
    * Reads the body from `is` in the form requested by `responseAs`. The
    * stream is consumed but not closed; the caller owns it.
    *
    * @throws IllegalStateException if a streaming response was requested,
    *                               which this handler does not support
    */
  private def readResponseBody[T](is: InputStream,
                                  responseAs: ResponseAs[T, Nothing]): T = {
    def asString(enc: String) = Source.fromInputStream(is, enc).mkString

    responseAs match {
      case IgnoreResponse =>
        // drain the body in chunks rather than one byte per read() call
        val buf = new Array[Byte](TransferBufferSize)
        @tailrec def consume(): Unit = if (is.read(buf) != -1) consume()

        consume()

      case ResponseAsString(enc) =>
        asString(enc)

      case ResponseAsByteArray =>
        val os = new ByteArrayOutputStream
        transfer(is, os)
        os.toByteArray

      case r @ ResponseAsParams(enc) =>
        r.parse(asString(enc))

      case ResponseAsStream() =>
        // only possible when the user requests the response as a stream of
        // Nothing. Oh well ...
        throw new IllegalStateException(
          "Streaming response bodies are not supported by HttpConnectionSttpHandler")
    }
  }
}