summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/reflect/MacroImplementations.scala
blob: 4e8f02084db2b0a261a61b65c5cac842c78a1b3c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package scala.tools.reflect

import scala.reflect.macros.runtime.Context
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.Stack

/** Hosts compiler-provided macro expansions, evaluated against the macro context `c`.
 *
 *  Currently this contains only the expansion behind the `f` string interpolator
 *  (`macro_StringInterpolation_f`).
 */
abstract class MacroImplementations {
  /** The macro context in which all expansions of this class are performed. */
  val c: Context

  import c.universe._
  import definitions._

  /** Expands an `f`-interpolated string into a call to `String.format`.
   *
   *  The parts are concatenated into a `java.util.Formatter` pattern:
   *   - An argument whose following part starts with `%` is checked against the
   *     `[flags][width][.precision]conversion` specifier found there.
   *   - Any other argument gets an implicit `%s`.
   *   - Every other `%` in the parts is escaped to `%%` (so it prints literally),
   *     except `%n`, which is passed through unchanged.
   *
   *  Each accepted argument is bound to a fresh `val` whose declared type is chosen
   *  from its conversion character. An argument that neither conforms to, nor is
   *  implicitly viewable as, any allowed type is ascribed the first allowed type,
   *  so the mismatch surfaces when the generated `val` is typechecked.
   *
   *  @param parts        the string-literal parts; must have exactly `args.length + 1` elements
   *  @param args         the interpolated argument trees, in source order
   *  @param origApplyPos position of the original application; used (focused) for the
   *                      `format` call and (transparent) for the resulting block
   *  @return a `Block` that first evaluates the argument `val`s, then calls `format`
   *  @throws IllegalArgumentException if an element of `parts` is not a string literal
   *  (arity mismatches are reported through `c.abort` instead)
   */
  def macro_StringInterpolation_f(parts: List[Tree], args: List[Tree], origApplyPos: c.universe.Position): Tree = {
    // the parts all have the same position information (as the expression is generated by the compiler)
    // the args have correct position information

    // The following conditions can only be violated if the macro is invoked directly.
    // When there are too many arguments, the first surplus one is at index `parts.length - 1`.
    if (parts.length != args.length + 1) {
      if (parts.length == 0)
        c.abort(c.prefix.tree.pos, "too few parts")
      else if (args.length + 1 < parts.length)
        c.abort(if (args.length == 0) c.enclosingPosition else args.last.pos,
            "too few arguments for interpolated string")
      else
        c.abort(args(parts.length - 1).pos,
            "too many arguments for interpolated string")
    }
    val stringParts = parts map {
      case Literal(Constant(s: String)) => s
      case _ => throw new IllegalArgumentException("argument parts must be a list of string literals")
    }

    val pi = stringParts.iterator
    val bldr = new java.lang.StringBuilder
    // One fresh `val` per accepted argument, plus the identifiers referring to them,
    // in the order the arguments appear in the source.
    val evals = ListBuffer[ValDef]()
    val ids = ListBuffer[Ident]()
    // Arguments are consumed strictly in source order, one per non-first part.
    val argsIter = args.iterator

    /** Binds `value` to a fresh `val` of declared type `tpe` and records an `Ident` for it. */
    def defval(value: Tree, tpe: Type): Unit = {
      val freshName = newTermName(c.freshName("arg$"))
      evals += ValDef(Modifiers(), freshName, TypeTree(tpe) setPos value.pos.focus, value) setPos value.pos
      ids += Ident(freshName)
    }

    /** True for the `java.util.Formatter` flag characters accepted before the width. */
    def isFlag(ch: Char): Boolean = {
      ch match {
        case '-' | '#' | '+' | ' ' | '0' | ',' | '(' => true
        case _ => false
      }
    }

    /** Picks the type `arg` will be ascribed.
     *
     *  Tries, in order: the first variant `arg` already conforms to; the first variant
     *  reachable through an implicit view; otherwise the first variant. Never returns `None`.
     */
    def checkType(arg: Tree, variants: Type*): Option[Type] = {
      variants.find(arg.tpe <:< _).orElse(
        variants.find(c.inferImplicitView(arg, arg.tpe, _) != EmptyTree).orElse(
            Some(variants(0))
        )
      )
    }

    val stdContextTags = new { val tc: c.type = c } with StdContextTags
    import stdContextTags._

    /** Maps a conversion character to the type its argument is ascribed.
     *
     *  Returns `None` for characters this macro does not support.
     */
    def conversionType(ch: Char, arg: Tree): Option[Type] = {
      ch match {
        case 'b' | 'B' =>
          if (arg.tpe <:< NullTpe) Some(NullTpe) else Some(BooleanTpe)
        case 'h' | 'H' =>
          Some(AnyTpe)
        case 's' | 'S' =>
          Some(AnyTpe)
        case 'c' | 'C' =>
          checkType(arg, CharTpe, ByteTpe, ShortTpe, IntTpe)
        case 'd' | 'o' | 'x' | 'X' =>
          checkType(arg, IntTpe, LongTpe, ByteTpe, ShortTpe, tagOfBigInt.tpe)
        case 'e' | 'E' | 'g' | 'G' | 'f' | 'a' | 'A'  =>
          checkType(arg, DoubleTpe, FloatTpe, tagOfBigDecimal.tpe)
        case 't' | 'T' =>
          checkType(arg, LongTpe, tagOfCalendar.tpe, tagOfDate.tpe)
        case _ => None
      }
    }

    /** Appends the next part (after escape processing) to `bldr` as a Formatter pattern.
     *
     *  For every part except the first, the preceding argument is consumed here:
     *   - If the part starts with `%`, the specifier that follows is validated against
     *     the argument.
     *   - Otherwise `%s` is emitted for the argument.
     *
     *  Every remaining `%` in the part is escaped as `%%`, except one immediately
     *  followed by `n`.
     */
    def copyString(first: Boolean): Unit = {
      val str = StringContext.treatEscapes(pi.next())
      val strLen = str.length
      val strIsEmpty = strLen == 0
      var start = 0
      var idx = 0

      if (!first) {
        val arg = argsIter.next()
        if (strIsEmpty || (str charAt 0) != '%') {
          bldr append "%s"
          defval(arg, AnyTpe)
        } else {
          // PRE str is not empty and str(0) == '%'
          // argument index parameter is not allowed, thus parse
          //    [flags][width][.precision]conversion
          var pos = 1
          while (pos < strLen && isFlag(str charAt pos)) pos += 1
          while (pos < strLen && Character.isDigit(str charAt pos)) pos += 1
          if (pos < strLen && str.charAt(pos) == '.') {
            pos += 1
            while (pos < strLen && Character.isDigit(str charAt pos)) pos += 1
          }
          if (pos < strLen) {
            conversionType(str charAt pos, arg) match {
              case Some(tpe) => defval(arg, tpe)
              case None => c.error(arg.pos, "illegal conversion character")
            }
          } else {
            // TODO: place error message on conversion string
            c.error(arg.pos, "wrong conversion string")
          }
        }
        // str(0) was either the '%' starting the argument's specifier or is not a '%',
        // so escaping can start at index 1.
        idx = 1
      }
      if (!strIsEmpty) {
        while (idx < strLen) {
          // Only a '%' directly followed by 'n' is kept as-is. A '%' that ends the part
          // is a literal too; previously it was mistaken for "%n" and left unescaped.
          def notPercentN = str(idx) != '%' || idx + 1 >= strLen || str(idx + 1) != 'n'
          if (str(idx) == '%' && notPercentN) {
            bldr append (str substring (start, idx)) append "%%"
            start = idx + 1
          }
          idx += 1
        }
        bldr append (str substring (start, idx))
      }
    }

    copyString(first = true)
    while (pi.hasNext) {
      copyString(first = false)
    }

    val fstring = bldr.toString
    // Ideally: c.reify(fstring.format((ids.map(id => Expr(id).eval)) : _*))
    // but that is blocked by https://issues.scala-lang.org/browse/SI-5824,
    // so the `format` call is built by hand.
    val expr =
      Apply(
        Select(
          Literal(Constant(fstring)),
          newTermName("format")),
        ids.toList
      )

    Block(evals.toList, atPos(origApplyPos.focus)(expr)) setPos origApplyPos.makeTransparent
  }

}