aboutsummaryrefslogtreecommitdiff
path: root/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala
blob: c423089d07a4a49c22d289d5b5954c8e5528bc5f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
package dotty.tools
package backend.jvm

import dotc.core.Contexts.{Context, ContextBase}
import dotc.core.Phases.Phase
import dotc.Compiler

import scala.reflect.io.{VirtualDirectory => Directory}
import scala.tools.asm
import asm._
import asm.tree._
import scala.collection.JavaConverters._

import io.JavaClassPath
import scala.collection.JavaConverters._
import scala.tools.asm.{ClassWriter, ClassReader}
import scala.tools.asm.tree._
import java.io.{File => JFile, InputStream}

/** A `GenBCode` phase for tests: it is named "testGenBCode" and reports an
 *  in-memory `VirtualDirectory` named `outDir` as its output directory.
 */
class TestGenBCode(val outDir: String) extends GenBCode {
  /** Virtual (in-memory) directory returned by `outputDir`. */
  val virtualDir: Directory = new Directory(outDir, None)

  override def outputDir(implicit ctx: Context): Directory = virtualDir

  override def phaseName: String = "testGenBCode"
}

trait DottyBytecodeTest extends DottyTest {
  import AsmNode._
  import ASMConverters._

  /** Numeric JVM constants and internal class names for use in bytecode tests.
   *
   *  The numbers are fixed by the JVM specification: the first three are the
   *  opcodes of `newarray` (0xbc), `anewarray` (0xbd) and `multianewarray`
   *  (0xc5); the rest are the `atype` operand codes accepted by `newarray`.
   */
  protected object Opcode {
    val newarray       = 188
    val anewarray      = 189
    val multianewarray = 197

    // `atype` operands of `newarray` (T_BOOLEAN .. T_LONG)
    val boolean        = 4
    val char           = 5
    val float          = 6
    val double         = 7
    val byte           = 8
    val short          = 9
    val int            = 10
    val long           = 11

    // JVM internal (slash-separated) class names
    val boxedUnit      = "scala/runtime/BoxedUnit"
    val javaString     = "java/lang/String"
  }

  /** Builds a compiler in which the phase named "genBCode" is replaced by
   *  `phase`, and an extra "assertionChecker" phase is appended that passes
   *  `phase.virtualDir` to `check`.
   */
  private def bCodeCheckingComp(phase: TestGenBCode)(check: Directory => Unit) =
    new Compiler {
      override def phases = {
        val updatedPhases = {
          def replacePhase: Phase => Phase =
            { p => if (p.phaseName == "genBCode") phase else p }

          for (phaseList <- super.phases) yield phaseList.map(replacePhase)
        }

        // Appended after all other phase groups, so it runs last.
        val checkerPhase = List(List(new Phase {
          def phaseName = "assertionChecker"
          override def run(implicit ctx: Context): Unit =
            check(phase.virtualDir)
        }))

        updatedPhases ::: checkerPhase
      }
    }

  /** Name for a virtual output directory, derived from `obj`'s hash code and
   *  the current time in milliseconds.
   */
  private def outPath(obj: Any) =
      "/genBCodeTest" + math.abs(obj.hashCode) + System.currentTimeMillis

  /** Checks source code from raw string.
   *
   *  Compiles `source` and runs `assertion` on the virtual output directory
   *  after the backend phase.
   */
  def checkBCode(source: String)(assertion: Directory => Unit) = {
    val comp = bCodeCheckingComp(new TestGenBCode(outPath(source)))(assertion)
    comp.rootContext(ctx)
    comp.newRun.compile(source)
  }

  /** Checks actual _files_ referenced in `sources` list.
   *
   *  Compiles the given files and runs `assertion` on the virtual output
   *  directory after the backend phase.
   */
  def checkBCode(sources: List[String])(assertion: Directory => Unit) = {
    val comp = bCodeCheckingComp(new TestGenBCode(outPath(sources)))(assertion)
    comp.rootContext(ctx)
    comp.newRun.compile(sources)
  }

  /** Reads a class file from `input` into an ASM tree node.
   *
   *  @param skipDebugInfo when true (the default), passes `ClassReader.SKIP_DEBUG`
   *                       so debug attributes are not read
   */
  protected def loadClassNode(input: InputStream, skipDebugInfo: Boolean = true): ClassNode = {
    val cr = new ClassReader(input)
    val cn = new ClassNode()
    cr.accept(cn, if (skipDebugInfo) ClassReader.SKIP_DEBUG else 0)
    cn
  }

  /** Returns the first method of `classNode` called `name`.
   *
   *  @throws RuntimeException (via `sys.error`) when no such method exists
   */
  protected def getMethod(classNode: ClassNode, name: String): MethodNode =
    classNode.methods.asScala.find(_.name == name) getOrElse
      sys.error(s"Didn't find method '$name' in class '${classNode.name}'")

  /** Renders a side-by-side, line-numbered diff of two instruction lists.
   *
   *  Each output line holds the 1-based index, `==` when both sides render to
   *  the same text and `<>` otherwise, the left instruction padded to the
   *  width of the widest left instruction, `|`, and the right instruction.
   *  The shorter list is padded with empty strings. Returns "" when both
   *  lists are empty.
   */
  def diffInstructions(isa: List[Instruction], isb: List[Instruction]): String = {
    val len = math.max(isa.length, isb.length)
    if (len == 0) ""
    else {
      // Pad once, outside the loop; Vector gives constant-time indexing.
      val left  = isa.map(_.toString).padTo(len, "").toVector
      val right = isb.map(_.toString).padTo(len, "").toVector
      // `isa` may be empty even when `isb` is not, so the width needs a floor of 0.
      val width     = left.foldLeft(0)((w, s) => math.max(w, s.length))
      val lineWidth = len.toString.length
      val sb = new StringBuilder
      for (line <- 1 to len) {
        val a = left(line - 1)
        val b = right(line - 1)
        sb append (s"""$line${" " * (lineWidth-line.toString.length)} ${if (a==b) "==" else "<>"} $a${" " * (width-a.length)} | $b\n""")
      }
      sb.toString
    }
  }

  /**************************** Comparison Methods ****************************/

  /** Checks whether `method` contains a `TableSwitch` or `LookupSwitch`.
   *
   *  @param shouldFail when true, the expectation is inverted (no switch expected)
   *  @param debug      when true, always print the instructions to stderr
   *  @return true when the presence of a switch matches the expectation
   */
  def verifySwitch(method: MethodNode, shouldFail: Boolean = false, debug: Boolean = false): Boolean = {
    val instructions = instructionsFromMethod(method)

    val hasSwitch = instructions.exists {
      case _: TableSwitch | _: LookupSwitch => true
      case _                                => false
    }

    // Dump the instructions on request, or when the outcome is unexpected.
    if (debug || hasSwitch == shouldFail)
      instructions.foreach(Console.err.println)

    hasSwitch != shouldFail
  }

  /** Asserts that both methods produce equal instruction lists, reporting a
   *  side-by-side diff on failure.
   */
  def sameBytecode(methA: MethodNode, methB: MethodNode) = {
    val isa = instructionsFromMethod(methA)
    val isb = instructionsFromMethod(methB)
    assert(isa == isb, s"Bytecode wasn't same:\n${diffInstructions(isa, isb)}")
  }

  /** Asserts that `similar` holds for the instruction lists of both methods,
   *  reporting a side-by-side diff on failure.
   */
  def similarBytecode(
    methA:   MethodNode,
    methB:   MethodNode,
    similar: (List[Instruction], List[Instruction]) => Boolean
  ) = {
    val isa = instructionsFromMethod(methA)
    val isb = instructionsFromMethod(methB)
    assert(
      similar(isa, isb),
      s"""|Bytecode wasn't similar according to the provided predicate:
          |${diffInstructions(isa, isb)}""".stripMargin)
  }

  /** Compares the full characteristics (including generic signatures) of the
   *  fields and methods of both classes.
   *
   *  Unlike `sameMethodAndFieldDescriptors`, this does NOT assert. It returns
   *  `(success, report)`, and callers must check the result themselves.
   */
  def sameMethodAndFieldSignatures(clazzA: ClassNode, clazzB: ClassNode): (Boolean, String) =
    sameCharacteristics(clazzA, clazzB)(_.characteristics)

  /**
   * Same as sameMethodAndFieldSignatures, but ignoring generic signatures.
   * This allows for methods which receive the same descriptor but differing
   * generic signatures. In particular, this happens with value classes, which
   * get a generic signature where a method written in terms of the underlying
   * values does not.
   *
   * Unlike sameMethodAndFieldSignatures, this asserts on the result.
   */
  def sameMethodAndFieldDescriptors(clazzA: ClassNode, clazzB: ClassNode): Unit = {
    val (succ, msg) = sameCharacteristics(clazzA, clazzB)(_.erasedCharacteristics)
    assert(succ, msg)
  }

  /** Pairwise-compares `f` applied to the fields and methods of both classes,
   *  in order. Occurrences of `clazzB`'s name in its characteristics are
   *  replaced by `clazzA`'s name before comparing.
   *
   *  @return `(success, report)`; the report has one entry per line and stops
   *          at the first mismatch, because `forall` short-circuits
   */
  private def sameCharacteristics(clazzA: ClassNode, clazzB: ClassNode)(f: AsmNode[_] => String): (Boolean, String) = {
    val ms1 = clazzA.fieldsAndMethods.toIndexedSeq
    val ms2 = clazzB.fieldsAndMethods.toIndexedSeq
    val name1 = clazzA.name
    val name2 = clazzB.name

    if (ms1.length != ms2.length) {
      (false, s"Different member counts in $name1 and $name2")
    } else {
      val msg     = new StringBuilder
      val success = (ms1, ms2).zipped forall { (m1, m2) =>
        val c1 = f(m1)
        val c2 = f(m2).replaceAllLiterally(name2, name1)
        val same = c1 == c2
        // Terminate each entry so the report stays readable in assertion output.
        if (same)
          msg append (s"[ok] $m1\n")
        else
          msg append (s"[fail]\n  in $name1: $c1\n  in $name2: $c2\n")

        same
      }

      (success, msg.toString)
    }
  }

  /** Asserts that `insnList` contains exactly `expectedChecks` direct null
   *  checks (`IFNULL` / `IFNONNULL` instructions).
   */
  def correctNumberOfNullChecks(expectedChecks: Int, insnList: InsnList) = {
    /** Is given instruction a null check?
     *
     *  This will detect direct null comparison as in
     *    if (x == null) ...
     *  and not indirect as in
     *    val foo = null
     *    if (x == foo) ...
     */
    def isNullCheck(node: asm.tree.AbstractInsnNode): Boolean = {
      val opcode = node.getOpcode
      (opcode == asm.Opcodes.IFNULL) || (opcode == asm.Opcodes.IFNONNULL)
    }
    val actualChecks = insnList.iterator.asScala.count(isNullCheck)
    assert(expectedChecks == actualChecks,
      s"Wrong number of null checks ($actualChecks), expected: $expectedChecks"
    )
  }
}