summaryrefslogblamecommitdiff
path: root/src/compiler/scala/tools/reflect/FormatInterpolator.scala
blob: bbb298767968711455c29458d1f6ad1353457078 (plain) (tree)
1
2
3
4
5
6
7
8
9

                           
                                           




                                                     
                            


                                   
                                          


                                     
                         



                                                                        
                                                                     
                                                   

                                                    


                                                         










                                                                                      
                                                       










































                                                                                                                        











                                                                      

                                                                                                             





















                                                                                      
                                                                                            













                                                                     

















































                                                                                                                         
                                                                      


                                                                       





                                                                                   


                 



                                               




                                                                        

























































                                                                                                     
                                    

















                                                                                          
                                       








































































































                                                                                                                                   
package scala.tools.reflect

import scala.reflect.macros.runtime.Context
import scala.collection.mutable.{ ListBuffer, Stack }
import scala.reflect.internal.util.Position
import scala.PartialFunction.cond
import scala.util.matching.Regex.Match

import java.util.Formattable

abstract class FormatInterpolator {
  val c: Context
  val global: c.universe.type = c.universe

  import c.universe.{ Match => _, _ }
  import definitions._
  import treeInfo.Applied

  @inline private def truly(body: => Unit): Boolean = { body ; true }
  @inline private def falsely(body: => Unit): Boolean = { body ; false }

  //private def fail(msg: String) = c.abort(c.enclosingPosition, msg)
  private def bail(msg: String) = global.abort(msg)

  def interpolate: Tree = c.macroApplication match {
    //case q"$_(..$parts).f(..$args)" =>
    case Applied(Select(Apply(_, parts), _), _, argss) =>
      val args = argss.flatten
      def badlyInvoked = (parts.length != args.length + 1) && truly {
        def because(s: String) = s"too $s arguments for interpolated string"
        val (p, msg) =
          if (parts.length == 0) (c.prefix.tree.pos, "there are no parts")
          else if (args.length + 1 < parts.length)
            (if (args.isEmpty) c.enclosingPosition else args.last.pos, because("few"))
          else (args(parts.length-1).pos, because("many"))
        c.abort(p, msg)
      }
      if (badlyInvoked) c.macroApplication else interpolated(parts, args)
    case other =>
      bail(s"Unexpected application ${showRaw(other)}")
      other
  }

  /** Every part except the first must begin with a conversion for
   *  the arg that preceded it. If the conversion is missing, "%s"
   *  is inserted.
   *
   *  In any other position, the only permissible conversions are
   *  the literals (%% and %n) or an index reference (%1$ or %<).
   *
   *  A conversion specifier has the form:
   *
   *  [index$][flags][width][.precision]conversion
   *
   *  1) "...${smth}" => okay, equivalent to "...${smth}%s"
   *  2) "...${smth}blahblah" => okay, equivalent to "...${smth}%sblahblah"
   *  3) "...${smth}%" => error
   *  4) "...${smth}%n" => okay, equivalent to "...${smth}%s%n"
   *  5) "...${smth}%%" => okay, equivalent to "...${smth}%s%%"
   *  6) "...${smth}[%legalJavaConversion]" => okay*
   *  7) "...${smth}[%illegalJavaConversion]" => error
   *  *Legal according to [[http://docs.oracle.com/javase/1.5.0/docs/api/java/util/Formatter.html]]
   */
  def interpolated(parts: List[Tree], args: List[Tree]) = {
    val fstring  = new StringBuilder
    val evals    = ListBuffer[ValDef]()
    val ids      = ListBuffer[Ident]()
    val argStack = Stack(args: _*)

    // create a tmp val and add it to the ids passed to format
    def defval(value: Tree, tpe: Type): Unit = {
      val freshName = TermName(c.freshName("arg$"))
      evals += ValDef(Modifiers(), freshName, TypeTree(tpe) setPos value.pos.focus, value) setPos value.pos
      ids += Ident(freshName)
    }
    // Append the nth part to the string builder, possibly prepending an omitted %s first.
    // Sanity-check the % fields in this part.
    def copyPart(part: Tree, n: Int): Unit = {
      import SpecifierGroups.{ Spec, Index }
      val s0 = part match {
        case Literal(Constant(x: String)) => x
        case _ => throw new IllegalArgumentException("internal error: argument parts must be a list of string literals")
      }
      def escapeHatch: PartialFunction[Throwable, String] = {
        // trailing backslash, octal escape, or other
        case e: StringContext.InvalidEscapeException =>
          def errPoint = part.pos withPoint (part.pos.point + e.index)
          def octalOf(c: Char) = Character.digit(c, 8)
          def alt = {
            def altOf(i: Int) = i match {
              case '\b' => "\\b"
              case '\t' => "\\t"
              case '\n' => "\\n"
              case '\f' => "\\f"
              case '\r' => "\\r"
              case '\"' => "$" /* avoid lint warn */ +
                "{'\"'} or a triple-quoted literal \"\"\"with embedded \" or \\u0022\"\"\""  // $" in future?
              case '\'' => "'"
              case '\\' => """\\"""
              case x    => "\\u%04x" format x
            }
            val suggest = {
              val r = "([0-7]{1,3}).*".r
              (s0 drop e.index + 1) match {
                case r(n) => altOf { (0 /: n) { case (a, o) => (8 * a) + (o - '0') } }
                case _    => ""
              }
            }
            val txt =
              if ("" == suggest) ""
              else s", use $suggest instead"
            txt
          }
          def badOctal = {
            def msg(what: String) = s"Octal escape literals are $what$alt."
            if (settings.future) {
              c.error(errPoint, msg("unsupported"))
              s0
            } else {
              currentRun.reporting.deprecationWarning(errPoint, msg("deprecated"), "2.11.0")
              try StringContext.treatEscapes(s0) catch escapeHatch
            }
          }
          if (e.index == s0.length - 1) {
            c.error(errPoint, """Trailing '\' escapes nothing.""")
            s0
          } else if (octalOf(s0(e.index + 1)) >= 0) {
            badOctal
          } else {
            c.error(errPoint, e.getMessage)
            s0
          }
      }
      val s  = try StringContext.processEscapes(s0) catch escapeHatch
      val ms = fpat findAllMatchIn s

      def errorLeading(op: Conversion) = op.errorAt(Spec, s"conversions must follow a splice; ${Conversion.literalHelp}")

      def first = n == 0
      // a conversion for the arg is required
      if (!first) {
        val arg = argStack.pop()
        def s_%() = {
          fstring append "%s"
          defval(arg, AnyTpe)
        }
        def accept(op: Conversion) = {
          if (!op.isLeading) errorLeading(op)
          op.accepts(arg) match {
            case Some(tpe) => defval(arg, tpe)
            case None      =>
          }
        }
        if (ms.hasNext) {
          Conversion(ms.next, part.pos, args.size) match {
            case Some(op) if op.isLiteral => s_%()
            case Some(op) if op.indexed =>
              if (op.index map (_ == n) getOrElse true) accept(op)
              else {
                // either some other arg num, or '<'
                c.warning(op.groupPos(Index), "Index is not this arg")
                s_%()
              }
            case Some(op) => accept(op)
            case None     =>
          }
        } else s_%()
      }
      // any remaining conversions must be either literals or indexed
      while (ms.hasNext) {
        Conversion(ms.next, part.pos, args.size) match {
          case Some(op) if first && op.hasFlag('<')   => op.badFlag('<', "No last arg")
          case Some(op) if op.isLiteral || op.indexed => // OK
          case Some(op) => errorLeading(op)
          case None     =>
        }
      }
      fstring append s
    }

    parts.zipWithIndex foreach {
      case (part, n) => copyPart(part, n)
    }

    //q"{..$evals; new StringOps(${fstring.toString}).format(..$ids)}"
    val format = fstring.toString
    if (ids.isEmpty && !format.contains("%")) Literal(Constant(format))
    else {
      val scalaPackage = Select(Ident(nme.ROOTPKG), TermName("scala"))
      val newStringOps = Select(
        New(Select(Select(Select(scalaPackage,
          TermName("collection")), TermName("immutable")), TypeName("StringOps"))),
        termNames.CONSTRUCTOR
      )
      val expr =
        Apply(
          Select(
            Apply(
              newStringOps,
              List(Literal(Constant(format)))),
            TermName("format")),
          ids.toList
        )
      val p = c.macroApplication.pos
      Block(evals.toList, atPos(p.focus)(expr)) setPos p.makeTransparent
    }
  }

  val fpat = """%(?:(\d+)\$)?([-#+ 0,(\<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r
  object SpecifierGroups extends Enumeration { val Spec, Index, Flags, Width, Precision, CC = Value }

  val stdContextTags = new { val tc: c.type = c } with StdContextTags
  import stdContextTags._
  val tagOfFormattable = typeTag[Formattable]

  /** A conversion specifier matched by `m` in the string part at `pos`,
   *  with `argc` arguments to interpolate.
   */
  sealed trait Conversion {
    def m: Match
    def pos: Position
    def argc: Int

    import SpecifierGroups.{ Value => SpecGroup, _ }
    private def maybeStr(g: SpecGroup) = Option(m group g.id)
    private def maybeInt(g: SpecGroup) = maybeStr(g) map (_.toInt)
    val index: Option[Int]     = maybeInt(Index)
    val flags: Option[String]  = maybeStr(Flags)
    val width: Option[Int]     = maybeInt(Width)
    val precision: Option[Int] = maybeStr(Precision) map (_.drop(1).toInt)
    val op: String             = maybeStr(CC) getOrElse ""

    def cc: Char = if ("tT" contains op(0)) op(1) else op(0)

    def indexed:   Boolean = index.nonEmpty || hasFlag('<')
    def isLiteral: Boolean = false
    def isLeading: Boolean = m.start(0) == 0
    def verify:    Boolean = goodFlags && goodIndex
    def accepts(arg: Tree): Option[Type]

    val allFlags = "-#+ 0,(<"
    def hasFlag(f: Char) = (flags getOrElse "") contains f
    def hasAnyFlag(fs: String) = fs exists (hasFlag)

    def badFlag(f: Char, msg: String) = {
      val i = flags map (_.indexOf(f)) filter (_ >= 0) getOrElse 0
      errorAtOffset(Flags, i, msg)
    }
    def groupPos(g: SpecGroup) = groupPosAt(g, 0)
    def groupPosAt(g: SpecGroup, i: Int) = pos withPoint (pos.point + m.start(g.id) + i)
    def errorAt(g: SpecGroup, msg: String) = c.error(groupPos(g), msg)
    def errorAtOffset(g: SpecGroup, i: Int, msg: String) = c.error(groupPosAt(g, i), msg)

    def noFlags = flags.isEmpty || falsely { errorAt(Flags, "flags not allowed") }
    def noWidth = width.isEmpty || falsely { errorAt(Width, "width not allowed") }
    def noPrecision = precision.isEmpty || falsely { errorAt(Precision, "precision not allowed") }
    def only_-(msg: String) = {
      val badFlags = (flags getOrElse "") filterNot { case '-' | '<' => true case _ => false }
      badFlags.isEmpty || falsely { badFlag(badFlags(0), s"Only '-' allowed for $msg") }
    }
    protected def okFlags: String = allFlags
    def goodFlags = {
      val badFlags = flags map (_ filterNot (okFlags contains _))
      for (bf <- badFlags; f <- bf) badFlag(f, s"Illegal flag '$f'")
      badFlags.getOrElse("").isEmpty
    }
    def goodIndex = {
      if (index.nonEmpty && hasFlag('<'))
        c.warning(groupPos(Index), "Argument index ignored if '<' flag is present")
      val okRange = index map (i => i > 0 && i <= argc) getOrElse true
      okRange || hasFlag('<') || falsely { errorAt(Index, "Argument index out of range") }
    }
    /** Pick the type of an arg to format from among the variants
     *  supported by a conversion.  This is the type of the temporary,
     *  so failure results in an erroneous assignment to the first variant.
     *  A more complete message would be nice.
     */
    def pickAcceptable(arg: Tree, variants: Type*): Option[Type] =
      variants find (arg.tpe <:< _) orElse (
        variants find (c.inferImplicitView(arg, arg.tpe, _) != EmptyTree)
      ) orElse Some(variants(0))
  }
  object Conversion {
    import SpecifierGroups.{ Spec, CC }
    def apply(m: Match, p: Position, n: Int): Option[Conversion] = {
      def badCC(msg: String) = {
        val dk = new ErrorXn(m, p)
        val at = if (dk.op.isEmpty) Spec else CC
        dk.errorAt(at, msg)
      }
      def cv(cc: Char) = cc match {
        case 'b' | 'B' | 'h' | 'H' | 's' | 'S' =>
          new GeneralXn(m, p, n)
        case 'c' | 'C' =>
          new CharacterXn(m, p, n)
        case 'd' | 'o' | 'x' | 'X' =>
          new IntegralXn(m, p, n)
        case 'e' | 'E' | 'f' | 'g' | 'G' | 'a' | 'A' =>
          new FloatingPointXn(m, p, n)
        case 't' | 'T' =>
          new DateTimeXn(m, p, n)
        case '%' | 'n' =>
          new LiteralXn(m, p, n)
        case _ =>
          badCC(s"illegal conversion character '$cc'")
          null
      }
      Option(m group CC.id) map (cc => cv(cc(0))) match {
        case Some(x) => Option(x) filter (_.verify)
        case None    =>
          badCC(s"Missing conversion operator in '${m.matched}'; $literalHelp")
          None
      }
    }
    val literalHelp = "use %% for literal %, %n for newline"
  }
  class GeneralXn(val m: Match, val pos: Position, val argc: Int) extends Conversion {
    def accepts(arg: Tree) = cc match {
      case 's' | 'S' if hasFlag('#') => pickAcceptable(arg, tagOfFormattable.tpe)
      case 'b' | 'B' => if (arg.tpe <:< NullTpe) Some(NullTpe) else Some(BooleanTpe)
      case _         => Some(AnyTpe)
    }
    override protected def okFlags = cc match {
      case 's' | 'S' => "-#<"
      case _         => "-<"
    }
  }
  class LiteralXn(val m: Match, val pos: Position, val argc: Int) extends Conversion {
    import SpecifierGroups.Width
    override val isLiteral = true
    override def verify = op match {
      case "%" => super.verify && noPrecision && truly(width foreach (_ => c.warning(groupPos(Width), "width ignored on literal")))
      case "n" => noFlags && noWidth && noPrecision
    }
    override protected val okFlags = "-"
    def accepts(arg: Tree) = None
  }
  class CharacterXn(val m: Match, val pos: Position, val argc: Int) extends Conversion {
    override def verify = super.verify && noPrecision && only_-("c conversion")
    def accepts(arg: Tree) = pickAcceptable(arg, CharTpe, ByteTpe, ShortTpe, IntTpe)
  }
  class IntegralXn(val m: Match, val pos: Position, val argc: Int) extends Conversion {
    override def verify = {
      def d_# = (cc == 'd' && hasFlag('#') &&
        truly { badFlag('#', "# not allowed for d conversion") }
      )
      def x_comma = (cc != 'd' && hasFlag(',') &&
        truly { badFlag(',', "',' only allowed for d conversion of integral types") }
      )
      super.verify && noPrecision && !d_# && !x_comma
    }
    override def accepts(arg: Tree) = {
      def isBigInt = arg.tpe <:< tagOfBigInt.tpe
      val maybeOK = "+ ("
      def bad_+ = cond(cc) {
        case 'o' | 'x' | 'X' if hasAnyFlag(maybeOK) && !isBigInt =>
          maybeOK filter hasFlag foreach (badf =>
            badFlag(badf, s"only use '$badf' for BigInt conversions to o, x, X"))
          true
      }
      if (bad_+) None else pickAcceptable(arg, IntTpe, LongTpe, ByteTpe, ShortTpe, tagOfBigInt.tpe)
    }
  }
  class FloatingPointXn(val m: Match, val pos: Position, val argc: Int) extends Conversion {
    override def verify = super.verify && (cc match {
      case 'a' | 'A' =>
        val badFlags = ",(" filter hasFlag
        noPrecision && badFlags.isEmpty || falsely {
          badFlags foreach (badf => badFlag(badf, s"'$badf' not allowed for a, A"))
        }
      case _ => true
    })
    def accepts(arg: Tree) = pickAcceptable(arg, DoubleTpe, FloatTpe, tagOfBigDecimal.tpe)
  }
  class DateTimeXn(val m: Match, val pos: Position, val argc: Int) extends Conversion {
    import SpecifierGroups.CC
    def hasCC = (op.length == 2 ||
      falsely { errorAt(CC, "Date/time conversion must have two characters") })
    def goodCC = ("HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc" contains cc) ||
      falsely { errorAtOffset(CC, 1, s"'$cc' doesn't seem to be a date or time conversion") }
    override def verify = super.verify && hasCC && goodCC && noPrecision && only_-("date/time conversions")
    def accepts(arg: Tree) = pickAcceptable(arg, LongTpe, tagOfCalendar.tpe, tagOfDate.tpe)
  }
  class ErrorXn(val m: Match, val pos: Position) extends Conversion {
    val argc = 0
    override def verify = false
    def accepts(arg: Tree) = None
  }
}