summaryrefslogtreecommitdiff
path: root/sources/scala/tools/nsc/typechecker/ConstantFolder.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sources/scala/tools/nsc/typechecker/ConstantFolder.scala')
-rwxr-xr-xsources/scala/tools/nsc/typechecker/ConstantFolder.scala275
1 files changed, 111 insertions, 164 deletions
diff --git a/sources/scala/tools/nsc/typechecker/ConstantFolder.scala b/sources/scala/tools/nsc/typechecker/ConstantFolder.scala
index 3a78775393..7883246424 100755
--- a/sources/scala/tools/nsc/typechecker/ConstantFolder.scala
+++ b/sources/scala/tools/nsc/typechecker/ConstantFolder.scala
@@ -11,187 +11,134 @@ abstract class ConstantFolder {
import global._;
import definitions._;
- private val NoValue = new Object();
-
/** If tree is a constant operation, replace with result. */
def apply(tree: Tree): Tree = fold(tree, tree match {
case Apply(Select(Literal(x), op), List(Literal(y))) => foldBinop(op, x, y)
case Select(Literal(x), op) => foldUnop(op, x)
- case _ => NoValue
+ case _ => null
});
/** If tree is a constant value that can be converted to type `pt', perform the conversion */
def apply(tree: Tree, pt: Type): Tree = fold(tree, tree.tpe match {
- case ConstantType(base, value) => foldTyped(value, pt)
- case _ => NoValue
+ case ConstantType(x) => x convertTo pt
+ case _ => null
});
- private def fold(tree: Tree, value: Any): Tree =
- if (value != NoValue && value != ()) tree setType ConstantType(literalType(value), value)
- else tree;
+ private def fold(tree: Tree, x: Constant): Tree =
+ if (x != null && x.tag != UnitTag) tree setType ConstantType(x) else tree;
- private def foldUnop(op: Name, value: Any): Any = Pair(op, value) match {
- case Pair(nme.ZNOT, x: boolean) => !x
+ private def foldUnop(op: Name, x: Constant): Constant = Pair(op, x.tag) match {
+ case Pair(nme.ZNOT, BooleanTag) => Constant(!x.booleanValue)
- case Pair(nme.NOT , x: int ) => ~x
- case Pair(nme.NOT , x: long ) => ~x
+ case Pair(nme.NOT , IntTag ) => Constant(~x.intValue)
+ case Pair(nme.NOT , LongTag ) => Constant(~x.longValue)
- case Pair(nme.ADD , x: int ) => +x
- case Pair(nme.ADD , x: long ) => +x
- case Pair(nme.ADD , x: float ) => +x
- case Pair(nme.ADD , x: double ) => +x
+ case Pair(nme.ADD , IntTag ) => Constant(+x.intValue)
+ case Pair(nme.ADD , LongTag ) => Constant(+x.longValue)
+ case Pair(nme.ADD , FloatTag ) => Constant(+x.floatValue)
+ case Pair(nme.ADD , DoubleTag ) => Constant(+x.doubleValue)
- case Pair(nme.SUB , x: int ) => -x
- case Pair(nme.SUB , x: long ) => -x
- case Pair(nme.SUB , x: float ) => -x
- case Pair(nme.SUB , x: double ) => -x
+ case Pair(nme.SUB , IntTag ) => Constant(-x.intValue)
+ case Pair(nme.SUB , LongTag ) => Constant(-x.longValue)
+ case Pair(nme.SUB , FloatTag ) => Constant(-x.floatValue)
+ case Pair(nme.SUB , DoubleTag ) => Constant(-x.doubleValue)
- case _ => NoValue
+ case _ => null
}
- private def foldBinop(op: Name, lvalue: Any, rvalue: Any): Any = Triple(op, lvalue, rvalue) match {
- case Triple(nme.ZOR , x: boolean, y: boolean) => x | y
- case Triple(nme.OR , x: boolean, y: boolean) => x | y
- case Triple(nme.OR , x: int , y: int ) => x | y
- case Triple(nme.OR , x: long , y: long ) => x | y
-
- case Triple(nme.XOR , x: boolean, y: boolean) => x ^ y
- case Triple(nme.XOR , x: int , y: int ) => x ^ y
- case Triple(nme.XOR , x: long , y: long ) => x ^ y
-
- case Triple(nme.ZAND, x: boolean, y: boolean) => x & y
- case Triple(nme.AND , x: boolean, y: boolean) => x & y
- case Triple(nme.AND , x: int , y: int ) => x & y
- case Triple(nme.AND , x: long , y: long ) => x & y
-
- case Triple(nme.LSL , x: int , y: int ) => x << y
- case Triple(nme.LSL , x: long , y: int ) => x << y
- case Triple(nme.LSL , x: long , y: long ) => x << y
-
- case Triple(nme.LSR , x: int , y: int ) => x >>> y
- case Triple(nme.LSR , x: long , y: int ) => x >>> y
- case Triple(nme.LSR , x: long , y: long ) => x >>> y
-
- case Triple(nme.ASR , x: int , y: int ) => x >> y
- case Triple(nme.ASR , x: long , y: int ) => x >> y
- case Triple(nme.ASR , x: long , y: long ) => x >> y
-
- case Triple(nme.EQ , x: boolean, y: boolean) => x == y
- case Triple(nme.EQ , x: int , y: int ) => x == y
- case Triple(nme.EQ , x: long , y: long ) => x == y
- case Triple(nme.EQ , x: float , y: float ) => x == y
- case Triple(nme.EQ , x: double , y: double ) => x == y
-
- case Triple(nme.NE , x: boolean, y: boolean) => x != y
- case Triple(nme.NE , x: int , y: int ) => x != y
- case Triple(nme.NE , x: long , y: long ) => x != y
- case Triple(nme.NE , x: float , y: float ) => x != y
- case Triple(nme.NE , x: double , y: double ) => x != y
-
- case Triple(nme.LT , x: int , y: int ) => x < y
- case Triple(nme.LT , x: long , y: long ) => x < y
- case Triple(nme.LT , x: float , y: float ) => x < y
- case Triple(nme.LT , x: double , y: double ) => x < y
-
- case Triple(nme.GT , x: int , y: int ) => x > y
- case Triple(nme.GT , x: long , y: long ) => x > y
- case Triple(nme.GT , x: float , y: float ) => x > y
- case Triple(nme.GT , x: double , y: double ) => x > y
-
- case Triple(nme.LE , x: int , y: int ) => x <= y
- case Triple(nme.LE , x: long , y: long ) => x <= y
- case Triple(nme.LE , x: float , y: float ) => x <= y
- case Triple(nme.LE , x: double , y: double ) => x <= y
-
- case Triple(nme.GE , x: int , y: int ) => x >= y
- case Triple(nme.GE , x: long , y: long ) => x >= y
- case Triple(nme.GE , x: float , y: float ) => x >= y
- case Triple(nme.GE , x: double , y: double ) => x >= y
-
- case Triple(nme.ADD , x: int , y: int ) => x + y
- case Triple(nme.ADD , x: long , y: long ) => x + y
- case Triple(nme.ADD , x: float , y: float ) => x + y
- case Triple(nme.ADD , x: double , y: double ) => x + y
- case Triple(nme.ADD , x: String , y: String ) => x + y
-
- case Triple(nme.SUB , x: int , y: int ) => x - y
- case Triple(nme.SUB , x: long , y: long ) => x - y
- case Triple(nme.SUB , x: float , y: float ) => x - y
- case Triple(nme.SUB , x: double , y: double ) => x - y
-
- case Triple(nme.MUL , x: int , y: int ) => x * y
- case Triple(nme.MUL , x: long , y: long ) => x * y
- case Triple(nme.MUL , x: float , y: float ) => x * y
- case Triple(nme.MUL , x: double , y: double ) => x * y
-
- case Triple(nme.DIV , x: int , y: int ) => x / y
- case Triple(nme.DIV , x: long , y: long ) => x / y
- case Triple(nme.DIV , x: float , y: float ) => x / y
- case Triple(nme.DIV , x: double , y: double ) => x / y
-
- case Triple(nme.MOD , x: int , y: int ) => x % y
- case Triple(nme.MOD , x: long , y: long ) => x % y
- case Triple(nme.MOD , x: float , y: float ) => x % y
- case Triple(nme.MOD , x: double , y: double ) => x % y
-
- case _ => NoValue
- }
-
- /** Widen constant value to conform to given type */
- private def foldTyped(value: Any, pt: Type): Any = {
- val target = pt.symbol;
- value match {
- case x: byte =>
- if (target == ShortClass) x.asInstanceOf[short]
- else if (target == CharClass) x.asInstanceOf[char]
- else if (target == IntClass) x.asInstanceOf[int]
- else if (target == LongClass) x.asInstanceOf[long]
- else if (target == FloatClass) x.asInstanceOf[float]
- else if (target == DoubleClass) x.asInstanceOf[double]
- else NoValue
- case x: short =>
- if (target == IntClass) x.asInstanceOf[int]
- else if (target == LongClass) x.asInstanceOf[long]
- else if (target == FloatClass) x.asInstanceOf[float]
- else if (target == DoubleClass) x.asInstanceOf[double]
- else NoValue
- case x: char =>
- if (target == IntClass) x.asInstanceOf[int]
- else if (target == LongClass) x.asInstanceOf[long]
- else if (target == FloatClass) x.asInstanceOf[float]
- else if (target == DoubleClass) x.asInstanceOf[double]
- else NoValue
- case x: int =>
- if (target == ByteClass && -128 <= x && x <= 127) x.asInstanceOf[byte]
- else if (target == ShortClass && -32768 <= x && x <= 32767) x.asInstanceOf[short]
- else if (target == CharClass && 0 <= x && x <= 65535) x.asInstanceOf[char]
- else if (target == LongClass) x.asInstanceOf[long]
- else if (target == FloatClass) x.asInstanceOf[float]
- else if (target == DoubleClass) x.asInstanceOf[double]
- else NoValue
- case x: long =>
- if (target == FloatClass) x.asInstanceOf[float]
- else if (target == DoubleClass) x.asInstanceOf[double]
- else NoValue
- case x: float =>
- if (target == DoubleClass) x.asInstanceOf[double]
- else NoValue
- case x =>
- NoValue
+ private def foldBinop(op: Name, x: Constant, y: Constant): Constant = {
+ val optag = if (x.tag > y.tag) x.tag else y.tag;
+ optag match {
+ case BooleanTag =>
+ op match {
+ case nme.ZOR => Constant(x.booleanValue | y.booleanValue)
+ case nme.OR => Constant(x.booleanValue | y.booleanValue)
+ case nme.XOR => Constant(x.booleanValue ^ y.booleanValue)
+ case nme.ZAND => Constant(x.booleanValue & y.booleanValue)
+ case nme.AND => Constant(x.booleanValue & y.booleanValue)
+ case nme.EQ => Constant(x.booleanValue == y.booleanValue)
+ case nme.NE => Constant(x.booleanValue != y.booleanValue)
+ case _ => null
+ }
+ case ByteTag | ShortTag | LongTag | IntTag =>
+ op match {
+ case nme.OR => Constant(x.intValue | y.intValue)
+ case nme.XOR => Constant(x.intValue ^ y.intValue)
+ case nme.AND => Constant(x.intValue & y.intValue)
+ case nme.LSL => Constant(x.intValue << y.intValue)
+ case nme.LSR => Constant(x.intValue >>> y.intValue)
+ case nme.ASR => Constant(x.intValue >> y.intValue)
+ case nme.EQ => Constant(x.intValue == y.intValue)
+ case nme.NE => Constant(x.intValue != y.intValue)
+ case nme.LT => Constant(x.intValue < y.intValue)
+ case nme.GT => Constant(x.intValue > y.intValue)
+ case nme.LE => Constant(x.intValue <= y.intValue)
+ case nme.GE => Constant(x.intValue >= y.intValue)
+ case nme.ADD => Constant(x.intValue + y.intValue)
+ case nme.SUB => Constant(x.intValue - y.intValue)
+ case nme.MUL => Constant(x.intValue * y.intValue)
+ case nme.DIV => Constant(x.intValue / y.intValue)
+ case nme.MOD => Constant(x.intValue % y.intValue)
+ case _ => null
+ }
+ case LongTag =>
+ op match {
+ case nme.OR => Constant(x.longValue | y.longValue)
+ case nme.XOR => Constant(x.longValue ^ y.longValue)
+ case nme.AND => Constant(x.longValue & y.longValue)
+ case nme.LSL => Constant(x.longValue << y.longValue)
+ case nme.LSR => Constant(x.longValue >>> y.longValue)
+ case nme.ASR => Constant(x.longValue >> y.longValue)
+ case nme.EQ => Constant(x.longValue == y.longValue)
+ case nme.NE => Constant(x.longValue != y.longValue)
+ case nme.LT => Constant(x.longValue < y.longValue)
+ case nme.GT => Constant(x.longValue > y.longValue)
+ case nme.LE => Constant(x.longValue <= y.longValue)
+ case nme.GE => Constant(x.longValue >= y.longValue)
+ case nme.ADD => Constant(x.longValue + y.longValue)
+ case nme.SUB => Constant(x.longValue - y.longValue)
+ case nme.MUL => Constant(x.longValue * y.longValue)
+ case nme.DIV => Constant(x.longValue / y.longValue)
+ case nme.MOD => Constant(x.longValue % y.longValue)
+ case _ => null
+ }
+ case FloatTag =>
+ op match {
+ case nme.EQ => Constant(x.floatValue == y.floatValue)
+ case nme.NE => Constant(x.floatValue != y.floatValue)
+ case nme.LT => Constant(x.floatValue < y.floatValue)
+ case nme.GT => Constant(x.floatValue > y.floatValue)
+ case nme.LE => Constant(x.floatValue <= y.floatValue)
+ case nme.GE => Constant(x.floatValue >= y.floatValue)
+ case nme.ADD => Constant(x.floatValue + y.floatValue)
+ case nme.SUB => Constant(x.floatValue - y.floatValue)
+ case nme.MUL => Constant(x.floatValue * y.floatValue)
+ case nme.DIV => Constant(x.floatValue / y.floatValue)
+ case nme.MOD => Constant(x.floatValue % y.floatValue)
+ case _ => null
+ }
+ case DoubleTag =>
+ op match {
+ case nme.EQ => Constant(x.doubleValue == y.doubleValue)
+ case nme.NE => Constant(x.doubleValue != y.doubleValue)
+ case nme.LT => Constant(x.doubleValue < y.doubleValue)
+ case nme.GT => Constant(x.doubleValue > y.doubleValue)
+ case nme.LE => Constant(x.doubleValue <= y.doubleValue)
+ case nme.GE => Constant(x.doubleValue >= y.doubleValue)
+ case nme.ADD => Constant(x.doubleValue + y.doubleValue)
+ case nme.SUB => Constant(x.doubleValue - y.doubleValue)
+ case nme.MUL => Constant(x.doubleValue * y.doubleValue)
+ case nme.DIV => Constant(x.doubleValue / y.doubleValue)
+ case nme.MOD => Constant(x.doubleValue % y.doubleValue)
+ case _ => null
+ }
+ case StringTag =>
+ op match {
+ case nme.ADD => Constant(x.stringValue + y.stringValue)
+ case _ => null
+ }
+ case _ =>
+ null
}
}
-
- def literalType(value: Any): Type =
- if (value.isInstanceOf[unit]) UnitClass.tpe
- else if (value.isInstanceOf[boolean]) BooleanClass.tpe
- else if (value.isInstanceOf[byte]) ByteClass.tpe
- else if (value.isInstanceOf[short]) ShortClass.tpe
- else if (value.isInstanceOf[char]) CharClass.tpe
- else if (value.isInstanceOf[int]) IntClass.tpe
- else if (value.isInstanceOf[long]) LongClass.tpe
- else if (value.isInstanceOf[float]) FloatClass.tpe
- else if (value.isInstanceOf[double]) DoubleClass.tpe
- else if (value.isInstanceOf[String]) StringClass.tpe
- else if (value == null) AllRefClass.tpe
- else throw new FatalError("unexpected literal value: " + value);
}