diff options
author | Eugene Burmako <xeno.by@gmail.com> | 2014-02-09 12:15:13 +0100 |
---|---|---|
committer | Eugene Burmako <xeno.by@gmail.com> | 2014-02-09 12:15:13 +0100 |
commit | 08e51dfec50842253afb87cc5ae3c7400dc18ced (patch) | |
tree | 7655825529e7690195bd92e13494d44c6d4bf223 /src | |
parent | 21a765feb0efeeecd671ede637a12f5088ac8257 (diff) | |
parent | ab7a8bcdb50128bfe2ac6d9d04e560593a3131ef (diff) | |
download | scala-08e51dfec50842253afb87cc5ae3c7400dc18ced.tar.gz scala-08e51dfec50842253afb87cc5ae3c7400dc18ced.tar.bz2 scala-08e51dfec50842253afb87cc5ae3c7400dc18ced.zip |
Merge pull request #3420 from som-snytt/issue/8092-f-parsing
SI-8092 More verify for f-interpolator
Diffstat (limited to 'src')
3 files changed, 332 insertions, 174 deletions
diff --git a/src/compiler/scala/tools/reflect/FastTrack.scala b/src/compiler/scala/tools/reflect/FastTrack.scala index bb0bbd79a3..8630ecf69e 100644 --- a/src/compiler/scala/tools/reflect/FastTrack.scala +++ b/src/compiler/scala/tools/reflect/FastTrack.scala @@ -20,8 +20,8 @@ trait FastTrack { private implicit def context2taggers(c0: MacroContext): Taggers { val c: c0.type } = new { val c: c0.type = c0 } with Taggers - private implicit def context2macroimplementations(c0: MacroContext): MacroImplementations { val c: c0.type } = - new { val c: c0.type = c0 } with MacroImplementations + private implicit def context2macroimplementations(c0: MacroContext): FormatInterpolator { val c: c0.type } = + new { val c: c0.type = c0 } with FormatInterpolator private implicit def context2quasiquote(c0: MacroContext): QuasiquoteImpls { val c: c0.type } = new { val c: c0.type = c0 } with QuasiquoteImpls private def makeBlackbox(sym: Symbol)(pf: PartialFunction[Applied, MacroContext => Tree]) = @@ -48,7 +48,7 @@ trait FastTrack { makeBlackbox( materializeWeakTypeTag) { case Applied(_, ttag :: Nil, (u :: _) :: _) => _.materializeTypeTag(u, EmptyTree, ttag.tpe, concrete = false) }, makeBlackbox( materializeTypeTag) { case Applied(_, ttag :: Nil, (u :: _) :: _) => _.materializeTypeTag(u, EmptyTree, ttag.tpe, concrete = true) }, makeBlackbox( ApiUniverseReify) { case Applied(_, ttag :: Nil, (expr :: _) :: _) => c => c.materializeExpr(c.prefix.tree, EmptyTree, expr) }, - makeBlackbox( StringContext_f) { case Applied(Select(Apply(_, ps), _), _, args) => c => c.macro_StringInterpolation_f(ps, args.flatten, c.expandee.pos) }, + makeBlackbox( StringContext_f) { case _ => _.interpolate }, makeBlackbox(ReflectRuntimeCurrentMirror) { case _ => c => currentMirror(c).tree }, makeWhitebox( QuasiquoteClass_api_apply) { case _ => _.expandQuasiquote }, makeWhitebox(QuasiquoteClass_api_unapply) { case _ => _.expandQuasiquote } diff --git a/src/compiler/scala/tools/reflect/FormatInterpolator.scala b/src/compiler/scala/tools/reflect/FormatInterpolator.scala new file mode 100644 index 0000000000..d5e674ebae --- /dev/null +++ b/src/compiler/scala/tools/reflect/FormatInterpolator.scala @@ -0,0 +1,329 @@ +package scala.tools.reflect + +import scala.reflect.macros.runtime.Context +import scala.collection.mutable.{ ListBuffer, Stack } +import scala.reflect.internal.util.Position +import scala.PartialFunction.cond +import scala.util.matching.Regex.Match + +import java.util.{ Formatter, Formattable, IllegalFormatException } + +abstract class FormatInterpolator { + val c: Context + val global: c.universe.type = c.universe + + import c.universe.{ Match => _, _ } + import definitions._ + import treeInfo.Applied + + @inline private def truly(body: => Unit): Boolean = { body ; true } + @inline private def falsely(body: => Unit): Boolean = { body ; false } + + private def fail(msg: String) = c.abort(c.enclosingPosition, msg) + private def bail(msg: String) = global.abort(msg) + + def interpolate: Tree = c.macroApplication match { + //case q"$_(..$parts).f(..$args)" => + case Applied(Select(Apply(_, parts), _), _, argss) => + val args = argss.flatten + def badlyInvoked = (parts.length != args.length + 1) && truly { + def because(s: String) = s"too $s arguments for interpolated string" + val (p, msg) = + if (parts.length == 0) (c.prefix.tree.pos, "there are no parts") + else if (args.length + 1 < parts.length) + (if (args.isEmpty) c.enclosingPosition else args.last.pos, because("few")) + else (args(parts.length-1).pos, because("many")) + c.abort(p, msg) + } + if (badlyInvoked) c.macroApplication else interpolated(parts, args) + case other => + bail(s"Unexpected application ${showRaw(other)}") + other + } + + /** Every part except the first must begin with a conversion for + * the arg that preceded it. If the conversion is missing, "%s" + * is inserted. + * + * In any other position, the only permissible conversions are + * the literals (%% and %n) or an index reference (%1$ or %<). + * + * A conversion specifier has the form: + * + * [index$][flags][width][.precision]conversion + * + * 1) "...${smth}" => okay, equivalent to "...${smth}%s" + * 2) "...${smth}blahblah" => okay, equivalent to "...${smth}%sblahblah" + * 3) "...${smth}%" => error + * 4) "...${smth}%n" => okay, equivalent to "...${smth}%s%n" + * 5) "...${smth}%%" => okay, equivalent to "...${smth}%s%%" + * 6) "...${smth}[%legalJavaConversion]" => okay* + * 7) "...${smth}[%illegalJavaConversion]" => error + * *Legal according to [[http://docs.oracle.com/javase/1.5.0/docs/api/java/util/Formatter.html]] + */ + def interpolated(parts: List[Tree], args: List[Tree]) = { + val fstring = new StringBuilder + val evals = ListBuffer[ValDef]() + val ids = ListBuffer[Ident]() + val argStack = Stack(args: _*) + + // create a tmp val and add it to the ids passed to format + def defval(value: Tree, tpe: Type): Unit = { + val freshName = TermName(c.freshName("arg$")) + evals += ValDef(Modifiers(), freshName, TypeTree(tpe) setPos value.pos.focus, value) setPos value.pos + ids += Ident(freshName) + } + // Append the nth part to the string builder, possibly prepending an omitted %s first. + // Sanity-check the % fields in this part. + def copyPart(part: Tree, n: Int): Unit = { + import SpecifierGroups.{ Spec, Index } + val s0 = part match { + case Literal(Constant(x: String)) => x + case _ => throw new IllegalArgumentException("internal error: argument parts must be a list of string literals") + } + val s = StringContext.treatEscapes(s0) + val ms = fpat findAllMatchIn s + + def errorLeading(op: Conversion) = op.errorAt(Spec, s"conversions must follow a splice; ${Conversion.literalHelp}") + + def first = n == 0 + // a conversion for the arg is required + if (!first) { + val arg = argStack.pop() + def s_%() = { + fstring append "%s" + defval(arg, AnyTpe) + } + def accept(op: Conversion) = { + if (!op.isLeading) errorLeading(op) + op.accepts(arg) match { + case Some(tpe) => defval(arg, tpe) + case None => + } + } + if (ms.hasNext) { + Conversion(ms.next, part.pos, args.size) match { + case Some(op) if op.isLiteral => s_%() + case Some(op) if op.indexed => + if (op.index map (_ == n) getOrElse true) accept(op) + else { + // either some other arg num, or '<' + c.warning(op.groupPos(Index), "Index is not this arg") + s_%() + } + case Some(op) => accept(op) + case None => + } + } else s_%() + } + // any remaining conversions must be either literals or indexed + while (ms.hasNext) { + Conversion(ms.next, part.pos, args.size) match { + case Some(op) if first && op.hasFlag('<') => op.badFlag('<', "No last arg") + case Some(op) if op.isLiteral || op.indexed => // OK + case Some(op) => errorLeading(op) + case None => + } + } + fstring append s + } + + parts.zipWithIndex foreach { + case (part, n) => copyPart(part, n) + } + + //q"{..$evals; ${fstring.toString}.format(..$ids)}" + locally { + val expr = + Apply( + Select( + Literal(Constant(fstring.toString)), + newTermName("format")), + ids.toList + ) + val p = c.macroApplication.pos + Block(evals.toList, atPos(p.focus)(expr)) setPos p.makeTransparent + } + } + + val fpat = """%(?:(\d+)\$)?([-#+ 0,(\<]+)?(\d+)?(\.\d+)?([tT]?[%a-zA-Z])?""".r + object SpecifierGroups extends Enumeration { val Spec, Index, Flags, Width, Precision, CC = Value } + + val stdContextTags = new { val tc: c.type = c } with StdContextTags + import stdContextTags._ + val tagOfFormattable = typeTag[Formattable] + + /** A conversion specifier matched by `m` in the string part at `pos`, + * with `argc` arguments to interpolate. + */ + sealed trait Conversion { + def m: Match + def pos: Position + def argc: Int + + import SpecifierGroups.{ Value => SpecGroup, _ } + private def maybeStr(g: SpecGroup) = Option(m group g.id) + private def maybeInt(g: SpecGroup) = maybeStr(g) map (_.toInt) + val index: Option[Int] = maybeInt(Index) + val flags: Option[String] = maybeStr(Flags) + val width: Option[Int] = maybeInt(Width) + val precision: Option[Int] = maybeStr(Precision) map (_.drop(1).toInt) + val op: String = maybeStr(CC) getOrElse "" + + def cc: Char = if ("tT" contains op(0)) op(1) else op(0) + + def indexed: Boolean = index.nonEmpty || hasFlag('<') + def isLiteral: Boolean = false + def isLeading: Boolean = m.start(0) == 0 + def verify: Boolean = goodFlags && goodIndex + def accepts(arg: Tree): Option[Type] + + val allFlags = "-#+ 0,(<" + def hasFlag(f: Char) = (flags getOrElse "") contains f + def hasAnyFlag(fs: String) = fs exists (hasFlag) + + def badFlag(f: Char, msg: String) = { + val i = flags map (_.indexOf(f)) filter (_ >= 0) getOrElse 0 + errorAtOffset(Flags, i, msg) + } + def groupPos(g: SpecGroup) = groupPosAt(g, 0) + def groupPosAt(g: SpecGroup, i: Int) = pos withPoint (pos.point + m.start(g.id) + i) + def errorAt(g: SpecGroup, msg: String) = c.error(groupPos(g), msg) + def errorAtOffset(g: SpecGroup, i: Int, msg: String) = c.error(groupPosAt(g, i), msg) + + def noFlags = flags.isEmpty || falsely { errorAt(Flags, "flags not allowed") } + def noWidth = width.isEmpty || falsely { errorAt(Width, "width not allowed") } + def noPrecision = precision.isEmpty || falsely { errorAt(Precision, "precision not allowed") } + def only_-(msg: String) = { + val badFlags = (flags getOrElse "") filterNot { case '-' | '<' => true case _ => false } + badFlags.isEmpty || falsely { badFlag(badFlags(0), s"Only '-' allowed for $msg") } + } + protected def okFlags: String = allFlags + def goodFlags = { + val badFlags = flags map (_ filterNot (okFlags contains _)) + for (bf <- badFlags; f <- bf) badFlag(f, s"Illegal flag '$f'") + badFlags.getOrElse("").isEmpty + } + def goodIndex = { + if (index.nonEmpty && hasFlag('<')) + c.warning(groupPos(Index), "Argument index ignored if '<' flag is present") + val okRange = index map (i => i > 0 && i <= argc) getOrElse true + okRange || hasFlag('<') || falsely { errorAt(Index, "Argument index out of range") } + } + /** Pick the type of an arg to format from among the variants + * supported by a conversion. This is the type of the temporary, + * so failure results in an erroneous assignment to the first variant. + * A more complete message would be nice. + */ + def pickAcceptable(arg: Tree, variants: Type*): Option[Type] = + variants find (arg.tpe <:< _) orElse ( + variants find (c.inferImplicitView(arg, arg.tpe, _) != EmptyTree) + ) orElse Some(variants(0)) + } + object Conversion { + import SpecifierGroups.{ Spec, CC, Width } + def apply(m: Match, p: Position, n: Int): Option[Conversion] = { + def badCC(msg: String) = { + val dk = new ErrorXn(m, p) + val at = if (dk.op.isEmpty) Spec else CC + dk.errorAt(at, msg) + } + def cv(cc: Char) = cc match { + case 'b' | 'B' | 'h' | 'H' | 's' | 'S' => + new GeneralXn(m, p, n) + case 'c' | 'C' => + new CharacterXn(m, p, n) + case 'd' | 'o' | 'x' | 'X' => + new IntegralXn(m, p, n) + case 'e' | 'E' | 'f' | 'g' | 'G' | 'a' | 'A' => + new FloatingPointXn(m, p, n) + case 't' | 'T' => + new DateTimeXn(m, p, n) + case '%' | 'n' => + new LiteralXn(m, p, n) + case _ => + badCC(s"illegal conversion character '$cc'") + null + } + Option(m group CC.id) map (cc => cv(cc(0))) match { + case Some(x) => Option(x) filter (_.verify) + case None => + badCC(s"Missing conversion operator in '${m.matched}'; $literalHelp") + None + } + } + val literalHelp = "use %% for literal %, %n for newline" + } + class GeneralXn(val m: Match, val pos: Position, val argc: Int) extends Conversion { + def accepts(arg: Tree) = cc match { + case 's' | 'S' if hasFlag('#') => pickAcceptable(arg, tagOfFormattable.tpe) + case 'b' | 'B' => if (arg.tpe <:< NullTpe) Some(NullTpe) else Some(BooleanTpe) + case _ => Some(AnyTpe) + } + override protected def okFlags = cc match { + case 's' | 'S' => "-#<" + case _ => "-<" + } + } + class LiteralXn(val m: Match, val pos: Position, val argc: Int) extends Conversion { + import SpecifierGroups.Width + override val isLiteral = true + override def verify = op match { + case "%" => super.verify && noPrecision && truly(width foreach (_ => c.warning(groupPos(Width), "width ignored on literal"))) + case "n" => noFlags && noWidth && noPrecision + } + override protected val okFlags = "-" + def accepts(arg: Tree) = None + } + class CharacterXn(val m: Match, val pos: Position, val argc: Int) extends Conversion { + override def verify = super.verify && noPrecision && only_-("c conversion") + def accepts(arg: Tree) = pickAcceptable(arg, CharTpe, ByteTpe, ShortTpe, IntTpe) + } + class IntegralXn(val m: Match, val pos: Position, val argc: Int) extends Conversion { + override def verify = { + def d_# = (cc == 'd' && hasFlag('#') && + truly { badFlag('#', "# not allowed for d conversion") } + ) + def x_comma = (cc != 'd' && hasFlag(',') && + truly { badFlag(',', "',' only allowed for d conversion of integral types") } + ) + super.verify && noPrecision && !d_# && !x_comma + } + override def accepts(arg: Tree) = { + def isBigInt = arg.tpe <:< tagOfBigInt.tpe + val maybeOK = "+ (" + def bad_+ = cond(cc) { + case 'o' | 'x' | 'X' if hasAnyFlag(maybeOK) && !isBigInt => + maybeOK filter hasFlag foreach (badf => + badFlag(badf, s"only use '$badf' for BigInt conversions to o, x, X")) + true + } + if (bad_+) None else pickAcceptable(arg, IntTpe, LongTpe, ByteTpe, ShortTpe, tagOfBigInt.tpe) + } + } + class FloatingPointXn(val m: Match, val pos: Position, val argc: Int) extends Conversion { + override def verify = super.verify && (cc match { + case 'a' | 'A' => + val badFlags = ",(" filter hasFlag + noPrecision && badFlags.isEmpty || falsely { + badFlags foreach (badf => badFlag(badf, s"'$badf' not allowed for a, A")) + } + case _ => true + }) + def accepts(arg: Tree) = pickAcceptable(arg, DoubleTpe, FloatTpe, tagOfBigDecimal.tpe) + } + class DateTimeXn(val m: Match, val pos: Position, val argc: Int) extends Conversion { + import SpecifierGroups.CC + def hasCC = (op.length == 2 || + falsely { errorAt(CC, "Date/time conversion must have two characters") }) + def goodCC = ("HIklMSLNpzZsQBbhAaCYyjmdeRTrDFc" contains cc) || + falsely { errorAtOffset(CC, 1, s"'$cc' doesn't seem to be a date or time conversion") } + override def verify = super.verify && hasCC && goodCC && noPrecision && only_-("date/time conversions") + def accepts(arg: Tree) = pickAcceptable(arg, LongTpe, tagOfCalendar.tpe, tagOfDate.tpe) + } + class ErrorXn(val m: Match, val pos: Position) extends Conversion { + val argc = 0 + override def verify = false + def accepts(arg: Tree) = None + } +} diff --git a/src/compiler/scala/tools/reflect/MacroImplementations.scala b/src/compiler/scala/tools/reflect/MacroImplementations.scala deleted file mode 100644 index a9ed419b1e..0000000000 --- a/src/compiler/scala/tools/reflect/MacroImplementations.scala +++ /dev/null @@ -1,171 +0,0 @@ -package scala.tools.reflect - -import scala.reflect.macros.contexts.Context -import scala.collection.mutable.ListBuffer -import scala.collection.mutable.Stack -import scala.reflect.internal.util.Position - -abstract class MacroImplementations { - val c: Context - - import c.universe._ - import definitions._ - - def macro_StringInterpolation_f(parts: List[Tree], args: List[Tree], origApplyPos: c.universe.Position): Tree = { - // the parts all have the same position information (as the expression is generated by the compiler) - // the args have correct position information - - // the following conditions can only be violated if invoked directly - if (parts.length != args.length + 1) { - if(parts.length == 0) - c.abort(c.prefix.tree.pos, "too few parts") - else if(args.length + 1 < parts.length) - c.abort(if(args.length==0) c.enclosingPosition else args.last.pos, - "too few arguments for interpolated string") - else - c.abort(args(parts.length-1).pos, - "too many arguments for interpolated string") - } - - val pi = parts.iterator - val bldr = new java.lang.StringBuilder - val evals = ListBuffer[ValDef]() - val ids = ListBuffer[Ident]() - val argStack = Stack(args : _*) - - def defval(value: Tree, tpe: Type): Unit = { - val freshName = newTermName(c.freshName("arg$")) - evals += ValDef(Modifiers(), freshName, TypeTree(tpe) setPos value.pos.focus, value) setPos value.pos - ids += Ident(freshName) - } - - def isFlag(ch: Char): Boolean = { - ch match { - case '-' | '#' | '+' | ' ' | '0' | ',' | '(' => true - case _ => false - } - } - - def checkType(arg: Tree, variants: Type*): Option[Type] = { - variants.find(arg.tpe <:< _).orElse( - variants.find(c.inferImplicitView(arg, arg.tpe, _) != EmptyTree).orElse( - Some(variants(0)) - ) - ) - } - - val stdContextTags = new { val tc: c.type = c } with StdContextTags - import stdContextTags._ - - def conversionType(ch: Char, arg: Tree): Option[Type] = { - ch match { - case 'b' | 'B' => - if(arg.tpe <:< NullTpe) Some(NullTpe) else Some(BooleanTpe) - case 'h' | 'H' => - Some(AnyTpe) - case 's' | 'S' => - Some(AnyTpe) - case 'c' | 'C' => - checkType(arg, CharTpe, ByteTpe, ShortTpe, IntTpe) - case 'd' | 'o' | 'x' | 'X' => - checkType(arg, IntTpe, LongTpe, ByteTpe, ShortTpe, tagOfBigInt.tpe) - case 'e' | 'E' | 'g' | 'G' | 'f' | 'a' | 'A' => - checkType(arg, DoubleTpe, FloatTpe, tagOfBigDecimal.tpe) - case 't' | 'T' => - checkType(arg, LongTpe, tagOfCalendar.tpe, tagOfDate.tpe) - case _ => None - } - } - - def copyString(first: Boolean): Unit = { - val strTree = pi.next() - val rawStr = strTree match { - case Literal(Constant(str: String)) => str - case _ => throw new IllegalArgumentException("internal error: argument parts must be a list of string literals") - } - val str = StringContext.treatEscapes(rawStr) - val strLen = str.length - val strIsEmpty = strLen == 0 - def charAtIndexIs(idx: Int, ch: Char) = idx < strLen && str(idx) == ch - def isPercent(idx: Int) = charAtIndexIs(idx, '%') - def isConversion(idx: Int) = isPercent(idx) && !charAtIndexIs(idx + 1, 'n') && !charAtIndexIs(idx + 1, '%') - var idx = 0 - - def errorAtIndex(idx: Int, msg: String) = c.error(Position.offset(strTree.pos.source, strTree.pos.point + idx), msg) - def wrongConversionString(idx: Int) = errorAtIndex(idx, "wrong conversion string") - def illegalConversionCharacter(idx: Int) = errorAtIndex(idx, "illegal conversion character") - def nonEscapedPercent(idx: Int) = errorAtIndex(idx, - "conversions must follow a splice; use %% for literal %, %n for newline") - - // STEP 1: handle argument conversion - // 1) "...${smth}" => okay, equivalent to "...${smth}%s" - // 2) "...${smth}blahblah" => okay, equivalent to "...${smth}%sblahblah" - // 3) "...${smth}%" => error - // 4) "...${smth}%n" => okay, equivalent to "...${smth}%s%n" - // 5) "...${smth}%%" => okay, equivalent to "...${smth}%s%%" - // 6) "...${smth}[%legalJavaConversion]" => okay, according to http://docs.oracle.com/javase/1.5.0/docs/api/java/util/Formatter.html - // 7) "...${smth}[%illegalJavaConversion]" => error - if (!first) { - val arg = argStack.pop() - if (isConversion(0)) { - // PRE str is not empty and str(0) == '%' - // argument index parameter is not allowed, thus parse - // [flags][width][.precision]conversion - var pos = 1 - while (pos < strLen && isFlag(str charAt pos)) pos += 1 - while (pos < strLen && Character.isDigit(str charAt pos)) pos += 1 - if (pos < strLen && str.charAt(pos) == '.') { - pos += 1 - while (pos < strLen && Character.isDigit(str charAt pos)) pos += 1 - } - if (pos < strLen) { - conversionType(str charAt pos, arg) match { - case Some(tpe) => defval(arg, tpe) - case None => illegalConversionCharacter(pos) - } - } else { - wrongConversionString(pos - 1) - } - idx = 1 - } else { - bldr append "%s" - defval(arg, AnyTpe) - } - } - - // STEP 2: handle the rest of the text - // 1) %n tokens are left as is - // 2) %% tokens are left as is - // 3) other usages of percents are reported as errors - if (!strIsEmpty) { - while (idx < strLen) { - if (isPercent(idx)) { - if (isConversion(idx)) nonEscapedPercent(idx) - else idx += 1 // skip n and % in %n and %% - } - idx += 1 - } - bldr append (str take idx) - } - } - - copyString(first = true) - while (pi.hasNext) { - copyString(first = false) - } - - val fstring = bldr.toString -// val expr = c.reify(fstring.format((ids.map(id => Expr(id).eval)) : _*)) -// https://issues.scala-lang.org/browse/SI-5824, therefore - val expr = - Apply( - Select( - Literal(Constant(fstring)), - newTermName("format")), - List(ids: _* ) - ) - - Block(evals.toList, atPos(origApplyPos.focus)(expr)) setPos origApplyPos.makeTransparent - } - -} |