summaryrefslogblamecommitdiff
path: root/sources/scala/tools/nsc/typechecker/ConstantFolder.scala
blob: 3a78775393ec3bab858d16b49d6386b7662168f2 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16















                                                              




                                                                               

                                                                                               

                                                                     


                     


                                                                                             





















                                                                                                     
                                                          































































































































                                                                                         













                                                                    
 
/* NSC -- new scala compiler
 * Copyright 2005 LAMP/EPFL
 * @author  Martin Odersky
 */
// $Id$
package scala.tools.nsc.typechecker;

abstract class ConstantFolder {

  val global: Global;
  import global._;
  import definitions._;

  /** Sentinel returned by the folding helpers when no constant result can be
   *  computed; `fold` leaves the tree untouched when it sees this value. */
  private val NoValue = new Object();

  /** If tree is a constant operation, replace with result.
   *  Recognized shapes are a binary operation `lit1.op(lit2)` and a unary
   *  selection `lit.op` whose operands are literals; any other tree, or an
   *  operator/operand-type combination not handled below, is returned unchanged.
   */
  def apply(tree: Tree): Tree = fold(tree, tree match {
    case Apply(Select(Literal(x), op), List(Literal(y))) => foldBinop(op, x, y)
    case Select(Literal(x), op) => foldUnop(op, x)
    case _ => NoValue
  });

  /** If tree is a constant value that can be converted to type `pt', perform the conversion.
   *  Only trees whose type is a `ConstantType` are considered; the permitted
   *  conversions are those listed in `foldTyped`.
   */
  def apply(tree: Tree, pt: Type): Tree = fold(tree, tree.tpe match {
    case ConstantType(base, value) => foldTyped(value, pt)
    case _ => NoValue
  });

  /** Returns `tree` with its type set (via `setType`) to the constant type of
   *  `value`; if `value` is the `NoValue` sentinel or the unit value, `tree`
   *  is returned as is.
   */
  private def fold(tree: Tree, value: Any): Tree =
    if (value != NoValue && value != ()) tree setType ConstantType(literalType(value), value)
    else tree;

  /** Evaluates the unary operator named `op` on the constant `value`.
   *  Returns `NoValue` for any operator/operand-type pair not listed.
   */
  private def foldUnop(op: Name, value: Any): Any = Pair(op, value) match {
    case Pair(nme.ZNOT, x: boolean) => !x

    case Pair(nme.NOT , x: int    ) => ~x
    case Pair(nme.NOT , x: long   ) => ~x

    case Pair(nme.ADD , x: int    ) => +x
    case Pair(nme.ADD , x: long   ) => +x
    case Pair(nme.ADD , x: float  ) => +x
    case Pair(nme.ADD , x: double ) => +x

    case Pair(nme.SUB , x: int    ) => -x
    case Pair(nme.SUB , x: long   ) => -x
    case Pair(nme.SUB , x: float  ) => -x
    case Pair(nme.SUB , x: double ) => -x

    case _ => NoValue
  }

  /** Evaluates the binary operator named `op` on the constants `lvalue` and
   *  `rvalue`. Both operands must have the same type, except for shifts, which
   *  also accept a `long` left operand with an `int` shift count.
   *  Returns `NoValue` for unsupported combinations, and also when evaluation
   *  throws an `ArithmeticException`.
   */
  private def foldBinop(op: Name, lvalue: Any, rvalue: Any): Any = try {
    Triple(op, lvalue, rvalue) match {
      // Both operands are constants, so the non-short-circuit `|` / `&` are
      // equivalent to `||` / `&&` here.
      case Triple(nme.ZOR , x: boolean, y: boolean) => x | y
      case Triple(nme.OR  , x: boolean, y: boolean) => x | y
      case Triple(nme.OR  , x: int    , y: int    ) => x | y
      case Triple(nme.OR  , x: long   , y: long   ) => x | y

      case Triple(nme.XOR , x: boolean, y: boolean) => x ^ y
      case Triple(nme.XOR , x: int    , y: int    ) => x ^ y
      case Triple(nme.XOR , x: long   , y: long   ) => x ^ y

      case Triple(nme.ZAND, x: boolean, y: boolean) => x & y
      case Triple(nme.AND , x: boolean, y: boolean) => x & y
      case Triple(nme.AND , x: int    , y: int    ) => x & y
      case Triple(nme.AND , x: long   , y: long   ) => x & y

      case Triple(nme.LSL , x: int    , y: int    ) => x << y
      case Triple(nme.LSL , x: long   , y: int    ) => x << y
      case Triple(nme.LSL , x: long   , y: long   ) => x << y

      case Triple(nme.LSR , x: int    , y: int    ) => x >>> y
      case Triple(nme.LSR , x: long   , y: int    ) => x >>> y
      case Triple(nme.LSR , x: long   , y: long   ) => x >>> y

      case Triple(nme.ASR , x: int    , y: int    ) => x >> y
      case Triple(nme.ASR , x: long   , y: int    ) => x >> y
      case Triple(nme.ASR , x: long   , y: long   ) => x >> y

      case Triple(nme.EQ  , x: boolean, y: boolean) => x == y
      case Triple(nme.EQ  , x: int    , y: int    ) => x == y
      case Triple(nme.EQ  , x: long   , y: long   ) => x == y
      case Triple(nme.EQ  , x: float  , y: float  ) => x == y
      case Triple(nme.EQ  , x: double , y: double ) => x == y

      case Triple(nme.NE  , x: boolean, y: boolean) => x != y
      case Triple(nme.NE  , x: int    , y: int    ) => x != y
      case Triple(nme.NE  , x: long   , y: long   ) => x != y
      case Triple(nme.NE  , x: float  , y: float  ) => x != y
      case Triple(nme.NE  , x: double , y: double ) => x != y

      case Triple(nme.LT  , x: int    , y: int    ) => x < y
      case Triple(nme.LT  , x: long   , y: long   ) => x < y
      case Triple(nme.LT  , x: float  , y: float  ) => x < y
      case Triple(nme.LT  , x: double , y: double ) => x < y

      case Triple(nme.GT  , x: int    , y: int    ) => x > y
      case Triple(nme.GT  , x: long   , y: long   ) => x > y
      case Triple(nme.GT  , x: float  , y: float  ) => x > y
      case Triple(nme.GT  , x: double , y: double ) => x > y

      case Triple(nme.LE  , x: int    , y: int    ) => x <= y
      case Triple(nme.LE  , x: long   , y: long   ) => x <= y
      case Triple(nme.LE  , x: float  , y: float  ) => x <= y
      case Triple(nme.LE  , x: double , y: double ) => x <= y

      case Triple(nme.GE  , x: int    , y: int    ) => x >= y
      case Triple(nme.GE  , x: long   , y: long   ) => x >= y
      case Triple(nme.GE  , x: float  , y: float  ) => x >= y
      case Triple(nme.GE  , x: double , y: double ) => x >= y

      case Triple(nme.ADD , x: int    , y: int    ) => x + y
      case Triple(nme.ADD , x: long   , y: long   ) => x + y
      case Triple(nme.ADD , x: float  , y: float  ) => x + y
      case Triple(nme.ADD , x: double , y: double ) => x + y
      case Triple(nme.ADD , x: String , y: String ) => x + y

      case Triple(nme.SUB , x: int    , y: int    ) => x - y
      case Triple(nme.SUB , x: long   , y: long   ) => x - y
      case Triple(nme.SUB , x: float  , y: float  ) => x - y
      case Triple(nme.SUB , x: double , y: double ) => x - y

      case Triple(nme.MUL , x: int    , y: int    ) => x * y
      case Triple(nme.MUL , x: long   , y: long   ) => x * y
      case Triple(nme.MUL , x: float  , y: float  ) => x * y
      case Triple(nme.MUL , x: double , y: double ) => x * y

      case Triple(nme.DIV , x: int    , y: int    ) => x / y
      case Triple(nme.DIV , x: long   , y: long   ) => x / y
      case Triple(nme.DIV , x: float  , y: float  ) => x / y
      case Triple(nme.DIV , x: double , y: double ) => x / y

      case Triple(nme.MOD , x: int    , y: int    ) => x % y
      case Triple(nme.MOD , x: long   , y: long   ) => x % y
      case Triple(nme.MOD , x: float  , y: float  ) => x % y
      case Triple(nme.MOD , x: double , y: double ) => x % y

      case _ => NoValue
    }
  } catch {
    // Integral division or remainder by zero (e.g. `1 / 0`) throws here, inside
    // the compiler. Leave such expressions unfolded so the exception is raised
    // when the program runs instead of aborting compilation.
    case ex: ArithmeticException => NoValue
  }

  /** Widen constant value to conform to given type.
   *  Narrowing from `int` to `byte`/`short`/`char`, and from `byte` to `char`,
   *  is done only when the value is representable in the target type.
   *  Returns `NoValue` when no conversion applies.
   */
  private def foldTyped(value: Any, pt: Type): Any = {
    val target = pt.symbol;
    value match {
      case x: byte =>
        if (target == ShortClass) x.asInstanceOf[short]
        // A negative byte has no char representation; converting it would
        // silently change the value (e.g. -1 would become '\uffff').
        else if (target == CharClass && x >= 0) x.asInstanceOf[char]
        else if (target == IntClass) x.asInstanceOf[int]
        else if (target == LongClass) x.asInstanceOf[long]
        else if (target == FloatClass) x.asInstanceOf[float]
        else if (target == DoubleClass) x.asInstanceOf[double]
        else NoValue
      case x: short =>
        if (target == IntClass) x.asInstanceOf[int]
        else if (target == LongClass) x.asInstanceOf[long]
        else if (target == FloatClass) x.asInstanceOf[float]
        else if (target == DoubleClass) x.asInstanceOf[double]
        else NoValue
      case x: char =>
        if (target == IntClass) x.asInstanceOf[int]
        else if (target == LongClass) x.asInstanceOf[long]
        else if (target == FloatClass) x.asInstanceOf[float]
        else if (target == DoubleClass) x.asInstanceOf[double]
        else NoValue
      case x: int =>
        if (target == ByteClass && -128 <= x && x <= 127) x.asInstanceOf[byte]
        else if (target == ShortClass && -32768 <= x && x <= 32767) x.asInstanceOf[short]
        else if (target == CharClass && 0 <= x && x <= 65535) x.asInstanceOf[char]
        else if (target == LongClass) x.asInstanceOf[long]
        else if (target == FloatClass) x.asInstanceOf[float]
        else if (target == DoubleClass) x.asInstanceOf[double]
        else NoValue
      case x: long =>
        if (target == FloatClass) x.asInstanceOf[float]
        else if (target == DoubleClass) x.asInstanceOf[double]
        else NoValue
      case x: float =>
        if (target == DoubleClass) x.asInstanceOf[double]
        else NoValue
      case x =>
        NoValue
    }
  }

  /** Maps a constant value to the type of the corresponding primitive or
   *  `String` class. `null` maps to `AllRefClass.tpe`.
   *  @throws FatalError if `value` is of any other runtime type.
   */
  def literalType(value: Any): Type =
    if (value.isInstanceOf[unit]) UnitClass.tpe
    else if (value.isInstanceOf[boolean]) BooleanClass.tpe
    else if (value.isInstanceOf[byte]) ByteClass.tpe
    else if (value.isInstanceOf[short]) ShortClass.tpe
    else if (value.isInstanceOf[char]) CharClass.tpe
    else if (value.isInstanceOf[int]) IntClass.tpe
    else if (value.isInstanceOf[long]) LongClass.tpe
    else if (value.isInstanceOf[float]) FloatClass.tpe
    else if (value.isInstanceOf[double]) DoubleClass.tpe
    else if (value.isInstanceOf[String]) StringClass.tpe
    else if (value == null) AllRefClass.tpe
    else throw new FatalError("unexpected literal value: " + value);
}