aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/com/softwaremill/sttp/HttpConnectionSttpHandler.scala
blob: fdbb32221b5cc1f6f8144b3aeadb83775269d8b8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package com.softwaremill.sttp

import java.io.{
  ByteArrayOutputStream,
  InputStream,
  OutputStream,
  OutputStreamWriter
}
import java.net.HttpURLConnection
import java.nio.channels.Channels
import java.nio.file.Files

import com.softwaremill.sttp.model._

import scala.annotation.tailrec
import scala.io.Source

/**
  * [[SttpHandler]] implemented on top of `java.net.HttpURLConnection`.
  *
  * The request is executed on the calling thread and the response body is
  * fully read (or drained) before `send` returns. Responses cannot be
  * streamed: the stream type parameter of this handler is `Nothing`.
  */
object HttpConnectionSttpHandler extends SttpHandler[Id, Nothing] {

  /** Size, in bytes, of the scratch buffer used for every stream copy. */
  private val BufferSize = 1024

  /**
    * Sends `r` and reads the response body as described by `responseAs`.
    *
    * Responses with a status code of 400 or above are returned as a regular
    * [[Response]]; their body is taken from the connection's error stream.
    * The response stream is always closed before this method returns.
    *
    * @param r          the request to send (method, URI, headers and body)
    * @param responseAs how the response body should be interpreted
    * @return the status code together with the decoded body
    * @throws IllegalStateException if `responseAs` asks for a streamed body
    */
  override def send[T](r: Request,
                       responseAs: ResponseAs[T, Nothing]): Response[T] = {
    val c = r.uri.toURL.openConnection().asInstanceOf[HttpURLConnection]
    c.setRequestMethod(r.method.m)
    r.headers.foreach { case (k, v) => c.setRequestProperty(k, v) }
    c.setDoInput(true)
    setBody(r.body, c)

    val status = c.getResponseCode
    val is = responseStream(c, status)
    // Every supported ResponseAs variant materializes the body eagerly, so the
    // stream can be released as soon as readResponse returns (or throws).
    try Response(status, readResponse(is, responseAs))
    finally is.close()
  }

  /**
    * Picks the stream that carries the response body.
    *
    * For error statuses `getInputStream` throws on the JDK implementation, so
    * the error stream is used instead. `getErrorStream` may return `null`
    * when the server sent no body; an empty stream is substituted then.
    */
  private def responseStream(c: HttpURLConnection, status: Int): InputStream =
    if (status >= 400)
      Option(c.getErrorStream)
        .getOrElse(new java.io.ByteArrayInputStream(Array.empty[Byte]))
    else c.getInputStream

  /**
    * Writes `body` to the connection's output stream, enabling output on the
    * connection first for anything other than `NoBody`. The output stream is
    * closed after writing, whether or not writing succeeded.
    */
  private def setBody(body: RequestBody, c: HttpURLConnection): Unit = {
    if (body != NoBody) c.setDoOutput(true)

    // Runs `write` against the connection's output stream and always closes
    // the stream afterwards, so every body variant releases it the same way.
    def withOutput(write: OutputStream => Unit): Unit = {
      val os = c.getOutputStream
      try write(os)
      finally os.close()
    }

    body match {
      case NoBody => // nothing to send

      case StringBody(b, encoding) =>
        withOutput { os =>
          val writer = new OutputStreamWriter(os, encoding)
          writer.write(b)
          // push the encoder's buffered bytes out before the stream is closed
          writer.flush()
        }

      case ByteArrayBody(b) =>
        withOutput(_.write(b))

      case ByteBufferBody(b) =>
        withOutput { os =>
          val channel = Channels.newChannel(os)
          // a single write call is allowed to consume only part of the buffer
          while (b.hasRemaining) channel.write(b)
        }

      case InputStreamBody(b) =>
        // the caller-supplied input stream is read to its end but not closed
        withOutput(copyStream(b, _))

      case FileBody(b) =>
        withOutput { os =>
          Files.copy(b.toPath, os)
          ()
        }

      case PathBody(b) =>
        withOutput { os =>
          Files.copy(b, os)
          ()
        }

      case SerializableBody(f, t) =>
        // serialize to one of the basic body types and send that instead
        setBody(f(t), c)
    }
  }

  /** Copies everything remaining in `in` to `out`; closes neither stream. */
  private def copyStream(in: InputStream, out: OutputStream): Unit = {
    val buf = new Array[Byte](BufferSize)

    @tailrec
    def doCopy(): Unit = {
      val read = in.read(buf)
      if (read != -1) {
        out.write(buf, 0, read)
        doCopy()
      }
    }

    doCopy()
  }

  /**
    * Reads the response body from `is` in the form requested by `responseAs`.
    * The stream is not closed here; that is left to the caller.
    *
    * @throws IllegalStateException if a streamed response is requested
    */
  private def readResponse[T](is: InputStream,
                              responseAs: ResponseAs[T, Nothing]): T =
    responseAs match {
      case IgnoreResponse =>
        // drain the body in chunks and discard it
        val buf = new Array[Byte](BufferSize)
        @tailrec def consume(): Unit = if (is.read(buf) != -1) consume()
        consume()

      case ResponseAsString(enc) =>
        Source.fromInputStream(is, enc).mkString

      case ResponseAsByteArray =>
        val os = new ByteArrayOutputStream
        copyStream(is, os)
        os.toByteArray

      case ResponseAsStream() =>
        // Only reachable when the caller asks for a stream of Nothing, which
        // this handler has no way to produce.
        throw new IllegalStateException(
          "Streaming responses are not supported by HttpConnectionSttpHandler")
    }
}