aboutsummaryrefslogblamecommitdiff
path: root/test/test/transform/TreeTransformerTest.scala
blob: 2344cafe768337946c8ebf914e00696c1142ee9c (plain) (tree)
1
2
3
4
5
6




                               
                                                                                                       










                                                                                         
                                                       
                                                
                                                                  
       
                                             
                                                                
 
                                               











                                                                            
                                                          

                                                                                                                                           
                                                             
                                                                  

                                             
                                                                   
 
                                               











                                                                                                                                                    
                                                       

                                                                                                                                            
                                                       




                                                                                                                    
                                                                                        
         

                                                                  

                                             
                                                                
 
                                               
 











                                                                                    
                                                        
                                                               




                                                                                                                    
                                             





                                                                                                                    
                                                                                        
         

                                                                  
       
                                                        
                                                               



                                                                                                                    
                                                                                        
         

                                                                  

                                             
                                                                                      
 
                                               












                                                                         
                                                        
                                                           




                                                                                                                    
                                             






                                                                                                        
                                                                                                            
         

                                                                  

                          
                                                        
                             
                                                           













                                                                                                                    
                                             






                                                                                                        
                                                                                                            
         

                                                                  

                                             
                                                                                      
 
                                               






                                                                        
package test.transform


import org.junit.{Assert, Test}
import test.DottyTest
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransformer, MiniPhaseTransform}
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.core.Contexts.Context


class TreeTransformerTest extends DottyTest {

  @Test
  def shouldReturnSameTreeIfUnchanged = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx = context
      class EmptyTransform extends MiniPhaseTransform {
        override def phaseName: String = "empty"
        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def transformations = Array(new EmptyTransform)

        override def phaseName: String = "test"
      }
      val transformed = transformer.transform(tree)

      Assert.assertTrue("returns same tree if unmodified",
        tree eq transformed
      )
  }

  @Test
  def canReplaceConstant = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx = context
      class ConstantTransform extends MiniPhaseTransform {

        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = tpd.Literal(Constant(2))
        override def phaseName: String = "canReplaceConstant"
        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def transformations = Array(new ConstantTransform)

        override def phaseName: String = "test"
      }
      val transformed = transformer.transform(tree)

      Assert.assertTrue("returns same tree if unmodified",
        transformed.toString.contains("List(ValDef(Modifiers(,,List()),d,TypeTree[TypeRef(ThisType(module class scala),Int)],Literal(Constant(2)))")
      )
  }

  @Test
  def canOverwrite = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx = context
      class Transformation extends MiniPhaseTransform {

        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = tpd.Literal(Constant(-1))
        override def phaseName: String = "canOverwrite"

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          Assert.assertTrue("transformation of children succeeded",
            tree.rhs.toString == "Literal(Constant(-1))"
          )
          tpd.cpy.ValDef(tree)(tree.mods, tree.name, tree.tpt, tpd.Literal(Constant(2)))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def transformations = Array(new Transformation)

        override def phaseName: String = "test"

      }
      val tr = transformer.transform(tree).toString

      Assert.assertTrue("node can rewrite children",
        tr.contains("Literal(Constant(2))") && !tr.contains("Literal(Constant(-1))")
      )
  }

  @Test
  def transformationOrder = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx = context
      class Transformation1 extends MiniPhaseTransform {
        override def phaseName: String = "transformationOrder1"

        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          Assert.assertTrue("correct constant",
            tree.const.toString == "Constant(1)"
          )
          tpd.cpy.Literal(tree)(Constant(-1))
        }

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          Assert.assertTrue("transformation of children succeeded",
            tree.rhs.toString == "Literal(Constant(-1))"
          )
          tpd.cpy.ValDef(tree)(tree.mods, tree.name, tree.tpt, tpd.Literal(Constant(2)))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      class Transformation2 extends MiniPhaseTransform {
        override def phaseName: String = "transformationOrder2"
        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.ValDef = {
          Assert.assertTrue("transformation of children succeeded",
            tree.rhs.toString == "Literal(Constant(2))"
          )
          tpd.cpy.ValDef(tree)(tree.mods, tree.name, tree.tpt, tpd.Literal(Constant(3)))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def transformations = Array(new Transformation1, new Transformation2)

        override def phaseName: String = "test"
      }
      val tr = transformer.transform(tree).toString

      Assert.assertTrue("node can rewrite children",
        tr.contains("Literal(Constant(3))")
      )
  }

  @Test
  def invocationCount = checkCompile("frontend", "class A{ val d = 1}") {
    (tree, context) =>
      implicit val ctx = context
      var transformed1 = 0
      class Transformation1 extends MiniPhaseTransform {
        override def phaseName: String = "invocationCount1"
        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          transformed1 += 1
          Assert.assertTrue("correct constant",
            tree.const.toString == "Constant(1)"
          )
          tpd.cpy.Literal(tree)(Constant(-1))
        }

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo) = {
          transformed1 += 1
          Assert.assertTrue("transformation of children succeeded",
            tree.rhs.toString == "Literal(Constant(-3))"
          )
          tpd.cpy.ValDef(tree)(tree.mods, tree.name, tree.tpt, transformFollowing(tpd.Literal(Constant(2))))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      var transformed2 = 0
      class Transformation2 extends MiniPhaseTransform {
        var constantsSeen = 0
        override def phaseName: String = "invocationCount2"
        override def transformLiteral(tree: tpd.Literal)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
          transformed2 += 1
          constantsSeen match {
            case 0 =>
              Assert.assertTrue("correct constant",
                tree.const.toString == "Constant(-1)"
              )
            case 1 =>
              Assert.assertTrue("correct constant",
                tree.const.toString == "Constant(2)"
              )
            case _ => Assert.fail("to many constants seen")
          }
          constantsSeen += 1
          tpd.cpy.Literal(tree)(Constant(-3))
        }

        override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo) = {
          transformed2 += 1
          Assert.assertTrue("transformation of children succeeded",
            tree.rhs.toString == "Literal(Constant(-3))"
          )
          transformFollowing(tpd.cpy.ValDef(tree)(tree.mods, tree.name, tree.tpt, tpd.Literal(Constant(3))))
        }

        init(ctx, ctx.period.firstPhaseId, ctx.period.lastPhaseId)
      }
      val transformer = new TreeTransformer {
        override def transformations = Array(new Transformation1, new Transformation2)

        override def phaseName: String = "test"
      }
      val tr = transformer.transform(tree).toString
      Assert.assertTrue("transformations aren't invoked multiple times",
        transformed1 == 2 && transformed2 == 3
      )
  }
}