aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/com/softwaremill/sttp/Uri.scala
blob: 7ae001318d630c412d3f78c51c40b2815a6f440a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
package com.softwaremill.sttp

import java.net.URLEncoder
import java.net.URI
import Uri.{QueryFragment, QueryFragmentEncoding, UserInfo}
import Uri.QueryFragment.{KeyValue, Plain, Value}

import scala.annotation.tailrec
import scala.collection.immutable.Seq

/**
  * A [[https://en.wikipedia.org/wiki/Uniform_Resource_Identifier URI]].
  * All components (scheme, host, query, ...) are stored unencoded, and
  * become encoded upon serialization (using [[toString]]).
  *
  * @param queryFragments Either key-value pairs, single values, or plain
  * query fragments. Key value pairs will be serialized as `k=v`, and blocks
  * of key-value pairs/single values will be combined using `&`. Note that no
  * `&` or other separators are added around plain query fragments - if
  * required, they need to be added manually as part of the plain query
  * fragment.
  */
case class Uri(scheme: String,
               userInfo: Option[UserInfo],
               host: String,
               port: Option[Int],
               path: Seq[String],
               queryFragments: Seq[QueryFragment],
               fragment: Option[String]) {

  private val AllowedSchemeCharacters = "[a-zA-Z][a-zA-Z0-9+-.]*".r

  require(host.nonEmpty, "Host cannot be empty")
  require(AllowedSchemeCharacters.unapplySeq(scheme).isDefined,
          "Scheme can only contain alphanumeric characters, +, - and .")

  def scheme(s: String): Uri = this.copy(scheme = s)

  def userInfo(username: String): Uri =
    this.copy(userInfo = Some(UserInfo(username, None)))

  def userInfo(username: String, password: String): Uri =
    this.copy(userInfo = Some(UserInfo(username, Some(password))))

  def host(h: String): Uri = this.copy(host = h)

  def port(p: Int): Uri = this.copy(port = Some(p))

  def port(p: Option[Int]): Uri = this.copy(port = p)

  def path(p: String): Uri = {
    // removing the leading slash, as it is added during serialization anyway
    val pWithoutLeadingSlash = if (p.startsWith("/")) p.substring(1) else p
    val ps = pWithoutLeadingSlash.split("/", -1).toList
    this.copy(path = ps)
  }

  def path(p1: String, p2: String, ps: String*): Uri =
    this.copy(path = p1 :: p2 :: ps.toList)

  def path(ps: scala.collection.Seq[String]): Uri = this.copy(path = ps.toList)

  /**
    * Adds the given parameter to the query.
    */
  def param(k: String, v: String): Uri = params(k -> v)

  /**
    * Adds the given parameters to the query.
    */
  def params(ps: Map[String, String]): Uri = params(ps.toSeq: _*)

  /**
    * Adds the given parameters to the query.
    */
  def params(ps: (String, String)*): Uri = {
    this.copy(queryFragments = queryFragments ++ ps.map {
      case (k, v) => KeyValue(k, v)
    })
  }

  def paramsMap: Map[String, String] = paramsSeq.toMap

  def paramsSeq: Seq[(String, String)] = queryFragments.collect {
    case KeyValue(k, v, _, _) => k -> v
  }

  def queryFragment(qf: QueryFragment): Uri =
    this.copy(queryFragments = queryFragments :+ qf)

  def fragment(f: String): Uri = this.copy(fragment = Some(f))

  def fragment(f: Option[String]): Uri = this.copy(fragment = f)

  def toJavaUri: URI = new URI(toString())

  override def toString: String = {
    def encodeUserInfo(ui: UserInfo): String =
      encode(Rfc3986.UserInfo)(ui.username) + ui.password.fold("")(":" + encode(Rfc3986.UserInfo)(_))

    @tailrec
    def encodeQueryFragments(qfs: List[QueryFragment], previousWasPlain: Boolean, sb: StringBuilder): String =
      qfs match {
        case Nil => sb.toString()

        case Plain(v, re) :: t =>
          encodeQueryFragments(t, previousWasPlain = true, sb.append(encodeQuery(v, re)))

        case Value(v, re) :: t =>
          if (!previousWasPlain) sb.append("&")
          sb.append(encodeQuery(v, re))
          encodeQueryFragments(t, previousWasPlain = false, sb)

        case KeyValue(k, v, reK, reV) :: t =>
          if (!previousWasPlain) sb.append("&")
          sb.append(encodeQuery(k, reK)).append("=").append(encodeQuery(v, reV))
          encodeQueryFragments(t, previousWasPlain = false, sb)
      }

    val schemeS = encode(Rfc3986.Scheme)(scheme)
    val userInfoS = userInfo.fold("")(encodeUserInfo(_) + "@")
    val hostS = encodeHost
    val portS = port.fold("")(":" + _)
    val pathPrefixS = if (path.isEmpty) "" else "/"
    val pathS = path.map(encode(Rfc3986.PathSegment)).mkString("/")
    val queryPrefixS = if (queryFragments.isEmpty) "" else "?"

    val queryS = encodeQueryFragments(queryFragments.toList, previousWasPlain = true, new StringBuilder())

    // https://stackoverflow.com/questions/2053132/is-a-colon-safe-for-friendly-url-use/2053640#2053640
    val fragS = fragment.fold("")("#" + encode(Rfc3986.Fragment)(_))

    s"$schemeS://$userInfoS$hostS$portS$pathPrefixS$pathS$queryPrefixS$queryS$fragS"
  }

  private def encodeQuery(s: String, e: QueryFragmentEncoding): String =
    e match {
      case QueryFragmentEncoding.All => URLEncoder.encode(s, "UTF-8")
      case QueryFragmentEncoding.Standard =>
        encode(Rfc3986.QueryNoStandardDelims, spaceAsPlus = true, encodePlus = true)(s)
      case QueryFragmentEncoding.Relaxed =>
        encode(Rfc3986.Query, spaceAsPlus = true)(s)
    }

  private object Rfc3986 {
    val AlphaNum: Set[Char] =
      (('a' to 'z') ++ ('A' to 'Z') ++ ('0' to '9')).toSet
    val Unreserved: Set[Char] = AlphaNum ++ Set('-', '.', '_', '~')
    val SubDelims: Set[Char] =
      Set('!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=')
    val PChar: Set[Char] = Unreserved ++ SubDelims ++ Set(':', '@')

    val Scheme: Set[Char] = AlphaNum ++ Set('+', '-', '.')
    val UserInfo: Set[Char] = Unreserved ++ SubDelims
    val Host: Set[Char] = Unreserved ++ SubDelims
    val PathSegment: Set[Char] = PChar
    val Query: Set[Char] = PChar ++ Set('/', '?')
    val Fragment: Set[Char] = Query

    val QueryNoStandardDelims: Set[Char] = Query -- Set('&', '=')
  }

  private val IpV6Pattern = "[0-9a-fA-F:]+".r

  private def encodeHost: String =
    host match {
      case IpV6Pattern() => s"[$host]"
      case _             => encode(Rfc3986.Host)(java.net.IDN.toASCII(host))
    }

  /**
    * @param spaceAsPlus In the query, space is encoded as a `+`. In other
    * contexts, it should be %-encoded as `%20`.
    * @param encodePlus Should `+` (which is the encoded form of space
    * in the query) be %-encoded.
    */
  private def encode(allowedCharacters: Set[Char], spaceAsPlus: Boolean = false, encodePlus: Boolean = false)(
      s: String): String = {
    val sb = new StringBuilder()
    // based on https://gist.github.com/teigen/5865923
    for (c <- s) {
      if (c == '+' && encodePlus) sb.append("%2B") // #48
      else if (allowedCharacters(c)) sb.append(c)
      else if (c == ' ' && spaceAsPlus) sb.append('+')
      else {
        for (b <- c.toString.getBytes("UTF-8")) {
          sb.append("%")
          sb.append("%02X".format(b))
        }
      }
    }
    sb.toString
  }
}

object Uri {
  def apply(host: String): Uri =
    Uri("http", None, host, None, Vector.empty, Vector.empty, None)
  def apply(host: String, port: Int): Uri =
    Uri("http", None, host, Some(port), Vector.empty, Vector.empty, None)
  def apply(host: String, port: Int, path: Seq[String]): Uri =
    Uri("http", None, host, Some(port), path, Vector.empty, None)
  def apply(scheme: String, host: String): Uri =
    Uri(scheme, None, host, None, Vector.empty, Vector.empty, None)
  def apply(scheme: String, host: String, port: Int): Uri =
    Uri(scheme, None, host, Some(port), Vector.empty, Vector.empty, None)
  def apply(scheme: String, host: String, port: Int, path: Seq[String]): Uri =
    Uri(scheme, None, host, Some(port), path, Vector.empty, None)
  def apply(scheme: String, host: String, path: Seq[String]): Uri =
    Uri(scheme, None, host, None, path, Vector.empty, None)
  def apply(javaUri: URI): Uri = uri"${javaUri.toString}"

  sealed trait QueryFragment
  object QueryFragment {

    /**
      * @param keyEncoding See [[Plain.encoding]]
      * @param valueEncoding See [[Plain.encoding]]
      */
    case class KeyValue(k: String,
                        v: String,
                        keyEncoding: QueryFragmentEncoding = QueryFragmentEncoding.Standard,
                        valueEncoding: QueryFragmentEncoding = QueryFragmentEncoding.Standard)
        extends QueryFragment

    /**
      * A query fragment which contains only the value, without a key.
      */
    case class Value(v: String, relaxedEncoding: QueryFragmentEncoding = QueryFragmentEncoding.Standard)
        extends QueryFragment

    /**
      * A query fragment which will be inserted into the query, without and
      * preceding or following separators. Allows constructing query strings
      * which are not (only) &-separated key-value pairs.
      *
      * @param encoding Should reserved characters (in the RFC3986 sense),
      * which are allowed in the query string, but can be also escaped be
      * left unchanged. These characters are:
      * {{{
      * /?:@-._~!$&()*+,;=
      * }}}
      * See:
      * [[https://stackoverflow.com/questions/2322764/what-characters-must-be-escaped-in-an-http-query-string]]
      * [[https://stackoverflow.com/questions/2366260/whats-valid-and-whats-not-in-a-uri-query]]
      */
    case class Plain(v: String, encoding: QueryFragmentEncoding = QueryFragmentEncoding.Standard) extends QueryFragment
  }

  sealed trait QueryFragmentEncoding
  object QueryFragmentEncoding {

    /**
      * Encodes all reserved characters using [[java.net.URLEncoder.encode()]].
      */
    case object All extends QueryFragmentEncoding

    /**
      * Encodes only the `&` and `=` reserved characters, which are usually
      * used to separate query parameter names and values.
      */
    case object Standard extends QueryFragmentEncoding

    /**
      * Doesn't encode any of the reserved characters, leaving intact all
      * characters allow in the query string as defined by RFC3986.
      */
    case object Relaxed extends QueryFragmentEncoding
  }

  case class UserInfo(username: String, password: Option[String])
}