diff options
Diffstat (limited to 'sources/scala/tools/nsc/typechecker/ConstantFolder.scala')
-rwxr-xr-x | sources/scala/tools/nsc/typechecker/ConstantFolder.scala | 275 |
1 files changed, 111 insertions, 164 deletions
diff --git a/sources/scala/tools/nsc/typechecker/ConstantFolder.scala b/sources/scala/tools/nsc/typechecker/ConstantFolder.scala index 3a78775393..7883246424 100755 --- a/sources/scala/tools/nsc/typechecker/ConstantFolder.scala +++ b/sources/scala/tools/nsc/typechecker/ConstantFolder.scala @@ -11,187 +11,134 @@ abstract class ConstantFolder { import global._; import definitions._; - private val NoValue = new Object(); - /** If tree is a constant operation, replace with result. */ def apply(tree: Tree): Tree = fold(tree, tree match { case Apply(Select(Literal(x), op), List(Literal(y))) => foldBinop(op, x, y) case Select(Literal(x), op) => foldUnop(op, x) - case _ => NoValue + case _ => null }); /** If tree is a constant value that can be converted to type `pt', perform the conversion */ def apply(tree: Tree, pt: Type): Tree = fold(tree, tree.tpe match { - case ConstantType(base, value) => foldTyped(value, pt) - case _ => NoValue + case ConstantType(x) => x convertTo pt + case _ => null }); - private def fold(tree: Tree, value: Any): Tree = - if (value != NoValue && value != ()) tree setType ConstantType(literalType(value), value) - else tree; + private def fold(tree: Tree, x: Constant): Tree = + if (x != null && x.tag != UnitTag) tree setType ConstantType(x) else tree; - private def foldUnop(op: Name, value: Any): Any = Pair(op, value) match { - case Pair(nme.ZNOT, x: boolean) => !x + private def foldUnop(op: Name, x: Constant): Constant = Pair(op, x.tag) match { + case Pair(nme.ZNOT, BooleanTag) => Constant(!x.booleanValue) - case Pair(nme.NOT , x: int ) => ~x - case Pair(nme.NOT , x: long ) => ~x + case Pair(nme.NOT , IntTag ) => Constant(~x.intValue) + case Pair(nme.NOT , LongTag ) => Constant(~x.longValue) - case Pair(nme.ADD , x: int ) => +x - case Pair(nme.ADD , x: long ) => +x - case Pair(nme.ADD , x: float ) => +x - case Pair(nme.ADD , x: double ) => +x + case Pair(nme.ADD , IntTag ) => Constant(+x.intValue) + case Pair(nme.ADD , LongTag ) => Constant(+x.longValue) + case Pair(nme.ADD , FloatTag ) => Constant(+x.floatValue) + case Pair(nme.ADD , DoubleTag ) => Constant(+x.doubleValue) - case Pair(nme.SUB , x: int ) => -x - case Pair(nme.SUB , x: long ) => -x - case Pair(nme.SUB , x: float ) => -x - case Pair(nme.SUB , x: double ) => -x + case Pair(nme.SUB , IntTag ) => Constant(-x.intValue) + case Pair(nme.SUB , LongTag ) => Constant(-x.longValue) + case Pair(nme.SUB , FloatTag ) => Constant(-x.floatValue) + case Pair(nme.SUB , DoubleTag ) => Constant(-x.doubleValue) - case _ => NoValue + case _ => null } - private def foldBinop(op: Name, lvalue: Any, rvalue: Any): Any = Triple(op, lvalue, rvalue) match { - case Triple(nme.ZOR , x: boolean, y: boolean) => x | y - case Triple(nme.OR , x: boolean, y: boolean) => x | y - case Triple(nme.OR , x: int , y: int ) => x | y - case Triple(nme.OR , x: long , y: long ) => x | y - - case Triple(nme.XOR , x: boolean, y: boolean) => x ^ y - case Triple(nme.XOR , x: int , y: int ) => x ^ y - case Triple(nme.XOR , x: long , y: long ) => x ^ y - - case Triple(nme.ZAND, x: boolean, y: boolean) => x & y - case Triple(nme.AND , x: boolean, y: boolean) => x & y - case Triple(nme.AND , x: int , y: int ) => x & y - case Triple(nme.AND , x: long , y: long ) => x & y - - case Triple(nme.LSL , x: int , y: int ) => x << y - case Triple(nme.LSL , x: long , y: int ) => x << y - case Triple(nme.LSL , x: long , y: long ) => x << y - - case Triple(nme.LSR , x: int , y: int ) => x >>> y - case Triple(nme.LSR , x: long , y: int ) => x >>> y - case Triple(nme.LSR , x: long , y: long ) => x >>> y - - case Triple(nme.ASR , x: int , y: int ) => x >> y - case Triple(nme.ASR , x: long , y: int ) => x >> y - case Triple(nme.ASR , x: long , y: long ) => x >> y - - case Triple(nme.EQ , x: boolean, y: boolean) => x == y - case Triple(nme.EQ , x: int , y: int ) => x == y - case Triple(nme.EQ , x: long , y: long ) => x == y - case Triple(nme.EQ , x: float , y: float ) => x == y - case Triple(nme.EQ , x: double , y: double ) => x == y - - case Triple(nme.NE , x: boolean, y: boolean) => x != y - case Triple(nme.NE , x: int , y: int ) => x != y - case Triple(nme.NE , x: long , y: long ) => x != y - case Triple(nme.NE , x: float , y: float ) => x != y - case Triple(nme.NE , x: double , y: double ) => x != y - - case Triple(nme.LT , x: int , y: int ) => x < y - case Triple(nme.LT , x: long , y: long ) => x < y - case Triple(nme.LT , x: float , y: float ) => x < y - case Triple(nme.LT , x: double , y: double ) => x < y - - case Triple(nme.GT , x: int , y: int ) => x > y - case Triple(nme.GT , x: long , y: long ) => x > y - case Triple(nme.GT , x: float , y: float ) => x > y - case Triple(nme.GT , x: double , y: double ) => x > y - - case Triple(nme.LE , x: int , y: int ) => x <= y - case Triple(nme.LE , x: long , y: long ) => x <= y - case Triple(nme.LE , x: float , y: float ) => x <= y - case Triple(nme.LE , x: double , y: double ) => x <= y - - case Triple(nme.GE , x: int , y: int ) => x >= y - case Triple(nme.GE , x: long , y: long ) => x >= y - case Triple(nme.GE , x: float , y: float ) => x >= y - case Triple(nme.GE , x: double , y: double ) => x >= y - - case Triple(nme.ADD , x: int , y: int ) => x + y - case Triple(nme.ADD , x: long , y: long ) => x + y - case Triple(nme.ADD , x: float , y: float ) => x + y - case Triple(nme.ADD , x: double , y: double ) => x + y - case Triple(nme.ADD , x: String , y: String ) => x + y - - case Triple(nme.SUB , x: int , y: int ) => x - y - case Triple(nme.SUB , x: long , y: long ) => x - y - case Triple(nme.SUB , x: float , y: float ) => x - y - case Triple(nme.SUB , x: double , y: double ) => x - y - - case Triple(nme.MUL , x: int , y: int ) => x * y - case Triple(nme.MUL , x: long , y: long ) => x * y - case Triple(nme.MUL , x: float , y: float ) => x * y - case Triple(nme.MUL , x: double , y: double ) => x * y - - case Triple(nme.DIV , x: int , y: int ) => x / y - case Triple(nme.DIV , x: long , y: long ) => x / y - case Triple(nme.DIV , x: float , y: float ) => x / y - case Triple(nme.DIV , x: double , y: double ) => x / y - - case Triple(nme.MOD , x: int , y: int ) => x % y - case Triple(nme.MOD , x: long , y: long ) => x % y - case Triple(nme.MOD , x: float , y: float ) => x % y - case Triple(nme.MOD , x: double , y: double ) => x % y - - case _ => NoValue - } - - /** Widen constant value to conform to given type */ - private def foldTyped(value: Any, pt: Type): Any = { - val target = pt.symbol; - value match { - case x: byte => - if (target == ShortClass) x.asInstanceOf[short] - else if (target == CharClass) x.asInstanceOf[char] - else if (target == IntClass) x.asInstanceOf[int] - else if (target == LongClass) x.asInstanceOf[long] - else if (target == FloatClass) x.asInstanceOf[float] - else if (target == DoubleClass) x.asInstanceOf[double] - else NoValue - case x: short => - if (target == IntClass) x.asInstanceOf[int] - else if (target == LongClass) x.asInstanceOf[long] - else if (target == FloatClass) x.asInstanceOf[float] - else if (target == DoubleClass) x.asInstanceOf[double] - else NoValue - case x: char => - if (target == IntClass) x.asInstanceOf[int] - else if (target == LongClass) x.asInstanceOf[long] - else if (target == FloatClass) x.asInstanceOf[float] - else if (target == DoubleClass) x.asInstanceOf[double] - else NoValue - case x: int => - if (target == ByteClass && -128 <= x && x <= 127) x.asInstanceOf[byte] - else if (target == ShortClass && -32768 <= x && x <= 32767) x.asInstanceOf[short] - else if (target == CharClass && 0 <= x && x <= 65535) x.asInstanceOf[char] - else if (target == LongClass) x.asInstanceOf[long] - else if (target == FloatClass) x.asInstanceOf[float] - else if (target == DoubleClass) x.asInstanceOf[double] - else NoValue - case x: long => - if (target == FloatClass) x.asInstanceOf[float] - else if (target == DoubleClass) x.asInstanceOf[double] - else NoValue - case x: float => - if (target == DoubleClass) x.asInstanceOf[double] - else NoValue - case x => - NoValue + private def foldBinop(op: Name, x: Constant, y: Constant): Constant = { + val optag = if (x.tag > y.tag) x.tag else y.tag; + optag match { + case BooleanTag => + op match { + case nme.ZOR => Constant(x.booleanValue | y.booleanValue) + case nme.OR => Constant(x.booleanValue | y.booleanValue) + case nme.XOR => Constant(x.booleanValue ^ y.booleanValue) + case nme.ZAND => Constant(x.booleanValue & y.booleanValue) + case nme.AND => Constant(x.booleanValue & y.booleanValue) + case nme.EQ => Constant(x.booleanValue == y.booleanValue) + case nme.NE => Constant(x.booleanValue != y.booleanValue) + case _ => null + } + case ByteTag | ShortTag | LongTag | IntTag => + op match { + case nme.OR => Constant(x.intValue | y.intValue) + case nme.XOR => Constant(x.intValue ^ y.intValue) + case nme.AND => Constant(x.intValue & y.intValue) + case nme.LSL => Constant(x.intValue << y.intValue) + case nme.LSR => Constant(x.intValue >>> y.intValue) + case nme.ASR => Constant(x.intValue >> y.intValue) + case nme.EQ => Constant(x.intValue == y.intValue) + case nme.NE => Constant(x.intValue != y.intValue) + case nme.LT => Constant(x.intValue < y.intValue) + case nme.GT => Constant(x.intValue > y.intValue) + case nme.LE => Constant(x.intValue <= y.intValue) + case nme.GE => Constant(x.intValue >= y.intValue) + case nme.ADD => Constant(x.intValue + y.intValue) + case nme.SUB => Constant(x.intValue - y.intValue) + case nme.MUL => Constant(x.intValue * y.intValue) + case nme.DIV => Constant(x.intValue / y.intValue) + case nme.MOD => Constant(x.intValue % y.intValue) + case _ => null + } + case LongTag => + op match { + case nme.OR => Constant(x.longValue | y.longValue) + case nme.XOR => Constant(x.longValue ^ y.longValue) + case nme.AND => Constant(x.longValue & y.longValue) + case nme.LSL => Constant(x.longValue << y.longValue) + case nme.LSR => Constant(x.longValue >>> y.longValue) + case nme.ASR => Constant(x.longValue >> y.longValue) + case nme.EQ => Constant(x.longValue == y.longValue) + case nme.NE => Constant(x.longValue != y.longValue) + case nme.LT => Constant(x.longValue < y.longValue) + case nme.GT => Constant(x.longValue > y.longValue) + case nme.LE => Constant(x.longValue <= y.longValue) + case nme.GE => Constant(x.longValue >= y.longValue) + case nme.ADD => Constant(x.longValue + y.longValue) + case nme.SUB => Constant(x.longValue - y.longValue) + case nme.MUL => Constant(x.longValue * y.longValue) + case nme.DIV => Constant(x.longValue / y.longValue) + case nme.MOD => Constant(x.longValue % y.longValue) + case _ => null + } + case FloatTag => + op match { + case nme.EQ => Constant(x.floatValue == y.floatValue) + case nme.NE => Constant(x.floatValue != y.floatValue) + case nme.LT => Constant(x.floatValue < y.floatValue) + case nme.GT => Constant(x.floatValue > y.floatValue) + case nme.LE => Constant(x.floatValue <= y.floatValue) + case nme.GE => Constant(x.floatValue >= y.floatValue) + case nme.ADD => Constant(x.floatValue + y.floatValue) + case nme.SUB => Constant(x.floatValue - y.floatValue) + case nme.MUL => Constant(x.floatValue * y.floatValue) + case nme.DIV => Constant(x.floatValue / y.floatValue) + case nme.MOD => Constant(x.floatValue % y.floatValue) + case _ => null + } + case DoubleTag => + op match { + case nme.EQ => Constant(x.doubleValue == y.doubleValue) + case nme.NE => Constant(x.doubleValue != y.doubleValue) + case nme.LT => Constant(x.doubleValue < y.doubleValue) + case nme.GT => Constant(x.doubleValue > y.doubleValue) + case nme.LE => Constant(x.doubleValue <= y.doubleValue) + case nme.GE => Constant(x.doubleValue >= y.doubleValue) + case nme.ADD => Constant(x.doubleValue + y.doubleValue) + case nme.SUB => Constant(x.doubleValue - y.doubleValue) + case nme.MUL => Constant(x.doubleValue * y.doubleValue) + case nme.DIV => Constant(x.doubleValue / y.doubleValue) + case nme.MOD => Constant(x.doubleValue % y.doubleValue) + case _ => null + } + case StringTag => + op match { + case nme.ADD => Constant(x.stringValue + y.stringValue) + case _ => null + } + case _ => + null } } - - def literalType(value: Any): Type = - if (value.isInstanceOf[unit]) UnitClass.tpe - else if (value.isInstanceOf[boolean]) BooleanClass.tpe - else if (value.isInstanceOf[byte]) ByteClass.tpe - else if (value.isInstanceOf[short]) ShortClass.tpe - else if (value.isInstanceOf[char]) CharClass.tpe - else if (value.isInstanceOf[int]) IntClass.tpe - else if (value.isInstanceOf[long]) LongClass.tpe - else if (value.isInstanceOf[float]) FloatClass.tpe - else if (value.isInstanceOf[double]) DoubleClass.tpe - else if (value.isInstanceOf[String]) StringClass.tpe - else if (value == null) AllRefClass.tpe - else throw new FatalError("unexpected literal value: " + value); } |