// Source: compiler/test/dotty/tools/dotc/transform/TreeTransformerTest.scala
// (blob d72980d80b64143c0d956305b94974c230a6c986)


                   

                               



                                                                            






                                                                                         
                                                       
                                                
                                                                  
       
                                             
                                                           
 
                                               
       
                                                        





                                                          
                                         


                                                                            
                                                          

                                                                                                                                           
                                                             
                                                                  

                                             
                                                              
 
                                               
       
                                                        









                                                                                                                                                    
                                                       

                                                                                                                                            
                                                       




                                                                                                                    
                                                              
         

                                                                  

                                             
                                                           
 
                                               
 
       
                                                        









                                                                                    
                                                        
                                                               




                                                                                                                    
                                             





                                                                                                                    
                                                              
         

                                                                  
       
                                                        
                                                               



                                                                                                                    
                                                              
         

                                                                  

                                             
                                                                                 
 
                                               
       
                                                        










                                                                         
                                                        
                                                           




                                                                                                                    
                                             






                                                                                                        
                                                                                  
         

                                                                  

                          
                                                        
                             
                                                           













                                                                                                                    
                                             






                                                                                                        
                                                                                  
         

                                                                  

                                             
                                                                                 
 
                                               
       
                                                        




                                                                        
package dotty.tools
package dotc
package transform

import org.junit.{Assert, Test}
import TreeTransforms.{TransformerInfo, TreeTransformer, MiniPhaseTransform}
import ast.tpd
import core.Constants.Constant
import core.Contexts.Context

/** Tests for [[TreeTransforms.TreeTransformer]] driven by small, locally defined
  * [[TreeTransforms.MiniPhaseTransform]]s.
  *
  * Every test feeds the snippet `class A{ val d = 1}` through `checkCompile` and
  * runs a hand-built transformer over the tree handed to the callback
  * (presumably the tree as it looks after the `frontend` phase -- see `DottyTest`).
  *
  * Note that `miniPhases` is deliberately kept as a `def` that builds fresh phase
  * instances, exactly as the tests were originally written.
  */
class TreeTransformerTest extends DottyTest {

  /** A mini-phase that overrides no `transformXxx` hook must leave the tree
    * untouched: the transformer has to return the very same instance.
    */
  @Test
  def shouldReturnSameTreeIfUnchanged: Unit = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx: Context = context
      class EmptyTransform extends MiniPhaseTransform {
        override def phaseName: String = "empty"
        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def miniPhases = Array(new EmptyTransform)

        override def phaseName: String = "test"
      }
      val transformed = transformer.macroTransform(tree)

      // Reference equality, not structural equality, is what is being checked.
      Assert.assertSame("returns same tree if unmodified", tree, transformed)
  }

  /** Replacing every literal by `2` should show up in the printed tree.
    *
    * Disabled, awaiting resolution: the `@Test` annotation is intentionally left
    * off. NOTE(review): the expectation is tied to the exact `toString`
    * rendering of the tree, which may be the reason -- confirm before re-enabling.
    */
  // @Test
  def canReplaceConstant: Unit = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx: Context = context
      class ConstantTransform extends MiniPhaseTransform {

        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = tpd.Literal(Constant(2))
        override def phaseName: String = "canReplaceConstant"
        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def miniPhases = Array(new ConstantTransform)

        override def phaseName: String = "test"
      }
      val transformed = transformer.macroTransform(tree)

      Assert.assertTrue("literal constant is replaced",
        transformed.toString.contains("List(ValDef(Modifiers(,,List()),d,TypeTree[TypeRef(ThisType(module class scala),Int)],Literal(Constant(2)))")
      )
  }

  /** Within a single mini-phase the hook for a parent node (`ValDef`) sees the
    * already transformed child (`Literal`) and may overwrite it again.
    */
  @Test
  def canOverwrite: Unit = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx: Context = context
      class Transformation extends MiniPhaseTransform {

        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = tpd.Literal(Constant(-1))
        override def phaseName: String = "canOverwrite"

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          // The literal hook above must already have run on the right-hand side.
          Assert.assertEquals("transformation of children succeeded",
            "Literal(Constant(-1))", tree.rhs.toString
          )
          tpd.cpy.ValDef(tree)(rhs = tpd.Literal(Constant(2)))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def miniPhases = Array(new Transformation)

        override def phaseName: String = "test"

      }
      val printed = transformer.macroTransform(tree).toString

      // Two separate assertions so a failure says which half went wrong.
      Assert.assertTrue("node can rewrite children",
        printed.contains("Literal(Constant(2))")
      )
      Assert.assertFalse("node can rewrite children",
        printed.contains("Literal(Constant(-1))")
      )
  }

  /** With two mini-phases in one transformer, the second phase's `ValDef` hook
    * sees the result of the first phase's `ValDef` hook, and its own result is
    * what ends up in the final tree.
    */
  @Test
  def transformationOrder: Unit = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx: Context = context
      class Transformation1 extends MiniPhaseTransform {
        override def phaseName: String = "transformationOrder1"

        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          Assert.assertEquals("correct constant",
            "Constant(1)", tree.const.toString
          )
          tpd.cpy.Literal(tree)(Constant(-1))
        }

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          Assert.assertEquals("transformation of children succeeded",
            "Literal(Constant(-1))", tree.rhs.toString
          )
          tpd.cpy.ValDef(tree)(rhs = tpd.Literal(Constant(2)))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      class Transformation2 extends MiniPhaseTransform {
        override def phaseName: String = "transformationOrder2"
        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          // Must observe what Transformation1.transformValDef produced.
          Assert.assertEquals("transformation of children succeeded",
            "Literal(Constant(2))", tree.rhs.toString
          )
          tpd.cpy.ValDef(tree)(rhs = tpd.Literal(Constant(3)))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def miniPhases = Array(new Transformation1, new Transformation2)

        override def phaseName: String = "test"
      }
      val printed = transformer.macroTransform(tree).toString

      Assert.assertTrue("result of the last transformation ends up in the tree",
        printed.contains("Literal(Constant(3))")
      )
  }

  /** Counts how often the hooks of two chained mini-phases are invoked.
    *
    * Expected, as encoded by the assertions below:
    *  - phase 1: one `transformLiteral` (for `1`) + one `transformValDef` = 2 calls;
    *  - phase 2: `transformLiteral` for `-1` (phase 1's output) and for `2`
    *    (the literal phase 1 passes to `transformFollowing`), plus one
    *    `transformValDef` = 3 calls.
    */
  @Test
  def invocationCount: Unit = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx: Context = context
      var transformed1 = 0
      class Transformation1 extends MiniPhaseTransform {
        override def phaseName: String = "invocationCount1"
        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          transformed1 += 1
          Assert.assertEquals("correct constant",
            "Constant(1)", tree.const.toString
          )
          tpd.cpy.Literal(tree)(Constant(-1))
        }

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          transformed1 += 1
          // By now phase 2 has already rewritten the literal to -3.
          Assert.assertEquals("transformation of children succeeded",
            "Literal(Constant(-3))", tree.rhs.toString
          )
          // The fresh literal is explicitly pushed through the following phase(s).
          tpd.cpy.ValDef(tree)(rhs = transformFollowing(tpd.Literal(Constant(2))))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      var transformed2 = 0
      class Transformation2 extends MiniPhaseTransform {
        var constantsSeen = 0
        override def phaseName: String = "invocationCount2"
        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          transformed2 += 1
          constantsSeen match {
            case 0 =>
              Assert.assertEquals("correct constant",
                "Constant(-1)", tree.const.toString
              )
            case 1 =>
              Assert.assertEquals("correct constant",
                "Constant(2)", tree.const.toString
              )
            case _ => Assert.fail("too many constants seen")
          }
          constantsSeen += 1
          tpd.cpy.Literal(tree)(Constant(-3))
        }

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          transformed2 += 1
          Assert.assertEquals("transformation of children succeeded",
            "Literal(Constant(-3))", tree.rhs.toString
          )
          transformFollowing(tpd.cpy.ValDef(tree)(rhs = tpd.Literal(Constant(3))))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def miniPhases = Array(new Transformation1, new Transformation2)

        override def phaseName: String = "test"
      }
      // Only the side effects on the counters matter; the resulting tree is not inspected.
      transformer.macroTransform(tree)

      Assert.assertEquals("transformations aren't invoked multiple times (phase 1)",
        2, transformed1
      )
      Assert.assertEquals("transformations aren't invoked multiple times (phase 2)",
        3, transformed2
      )
  }
}