summaryrefslogtreecommitdiff
path: root/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala
diff options
context:
space:
mode:
Diffstat (limited to 'test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala')
-rw-r--r--test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala195
1 files changed, 132 insertions, 63 deletions
diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala
index 9fda034a04..b37b5efa7e 100644
--- a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala
+++ b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala
@@ -6,6 +6,7 @@ import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
import scala.collection.generic.Clearable
+import scala.collection.immutable.IntMap
import scala.tools.asm.Opcodes._
import org.junit.Assert._
@@ -21,25 +22,56 @@ import AsmUtils._
import BackendReporting._
import scala.collection.convert.decorateAsScala._
+import scala.tools.testing.ClearAfterClass
-@RunWith(classOf[JUnit4])
-class CallGraphTest {
- val compiler = newCompiler(extraArgs = "-Ybackend:GenBCode -Yopt:inline-global -Yopt-warnings")
- import compiler.genBCode.bTypes._
+object CallGraphTest extends ClearAfterClass.Clearable {
+ var compiler = newCompiler(extraArgs = "-Yopt:inline-global -Yopt-warnings")
+ def clear(): Unit = { compiler = null }
// allows inspecting the caches after a compilation run
- val notPerRun: List[Clearable] = List(classBTypeFromInternalName, byteCodeRepository.classes, callGraph.callsites)
+ val notPerRun: List[Clearable] = List(
+ compiler.genBCode.bTypes.classBTypeFromInternalName,
+ compiler.genBCode.bTypes.byteCodeRepository.compilingClasses,
+ compiler.genBCode.bTypes.byteCodeRepository.parsedClasses,
+ compiler.genBCode.bTypes.callGraph.callsites)
notPerRun foreach compiler.perRunCaches.unrecordCache
+}
+
+@RunWith(classOf[JUnit4])
+class CallGraphTest extends ClearAfterClass {
+ ClearAfterClass.stateToClear = CallGraphTest
+
+ val compiler = CallGraphTest.compiler
+ import compiler.genBCode.bTypes._
+ import callGraph._
- def compile(code: String, allowMessage: StoreReporter#Info => Boolean): List[ClassNode] = {
- notPerRun.foreach(_.clear())
- compileClasses(compiler)(code, allowMessage = allowMessage)
+ def compile(code: String, allowMessage: StoreReporter#Info => Boolean = _ => false): List[ClassNode] = {
+ CallGraphTest.notPerRun.foreach(_.clear())
+ compileClasses(compiler)(code, allowMessage = allowMessage).map(c => byteCodeRepository.classNode(c.name).get)
}
def callsInMethod(methodNode: MethodNode): List[MethodInsnNode] = methodNode.instructions.iterator.asScala.collect({
case call: MethodInsnNode => call
}).toList
+ def checkCallsite(call: MethodInsnNode, callsiteMethod: MethodNode, target: MethodNode, calleeDeclClass: ClassBType,
+ safeToInline: Boolean, atInline: Boolean, atNoInline: Boolean, argInfos: IntMap[ArgInfo] = IntMap.empty) = {
+ val callsite = callGraph.callsites(callsiteMethod)(call)
+ try {
+ assert(callsite.callsiteInstruction == call)
+ assert(callsite.callsiteMethod == callsiteMethod)
+ val callee = callsite.callee.get
+ assert(callee.callee == target)
+ assert(callee.calleeDeclarationClass == calleeDeclClass)
+ assert(callee.safeToInline == safeToInline)
+ assert(callee.annotatedInline == atInline)
+ assert(callee.annotatedNoInline == atNoInline)
+ assert(callsite.argInfos == argInfos)
+ } catch {
+ case e: Throwable => println(callsite); throw e
+ }
+ }
+
@Test
def callGraphStructure(): Unit = {
val code =
@@ -83,70 +115,107 @@ class CallGraphTest {
msgCount += 1
ok exists (m.msg contains _)
}
- val List(cCls, cMod, dCls, testCls) = compile(code, checkMsg).map(c => byteCodeRepository.classNode(c.name).get)
+ val List(cCls, cMod, dCls, testCls) = compile(code, checkMsg)
assert(msgCount == 6, msgCount)
- val List(cf1, cf2, cf3, cf4, cf5, cf6, cf7) = cCls.methods.iterator.asScala.filter(_.name.startsWith("f")).toList.sortBy(_.name)
- val List(df1, df3) = dCls.methods.iterator.asScala.filter(_.name.startsWith("f")).toList.sortBy(_.name)
- val g1 = cMod.methods.iterator.asScala.find(_.name == "g1").get
- val List(t1, t2) = testCls.methods.iterator.asScala.filter(_.name.startsWith("t")).toList.sortBy(_.name)
+ val List(cf1, cf2, cf3, cf4, cf5, cf6, cf7) = findAsmMethods(cCls, _.startsWith("f"))
+ val List(df1, df3) = findAsmMethods(dCls, _.startsWith("f"))
+ val g1 = findAsmMethod(cMod, "g1")
+ val List(t1, t2) = findAsmMethods(testCls, _.startsWith("t"))
val List(cf1Call, cf2Call, cf3Call, cf4Call, cf5Call, cf6Call, cf7Call, cg1Call) = callsInMethod(t1)
val List(df1Call, df2Call, df3Call, df4Call, df5Call, df6Call, df7Call, dg1Call) = callsInMethod(t2)
- def checkCallsite(callsite: callGraph.Callsite,
- call: MethodInsnNode, callsiteMethod: MethodNode, target: MethodNode, calleeDeclClass: ClassBType,
- safeToInline: Boolean, atInline: Boolean, atNoInline: Boolean) = try {
- assert(callsite.callsiteInstruction == call)
- assert(callsite.callsiteMethod == callsiteMethod)
- val callee = callsite.callee.get
- assert(callee.callee == target)
- assert(callee.calleeDeclarationClass == calleeDeclClass)
- assert(callee.safeToInline == safeToInline)
- assert(callee.annotatedInline == atInline)
- assert(callee.annotatedNoInline == atNoInline)
-
- assert(callsite.argInfos == List()) // not defined yet
- } catch {
- case e: Throwable => println(callsite); throw e
- }
-
val cClassBType = classBTypeFromClassNode(cCls)
val cMClassBType = classBTypeFromClassNode(cMod)
val dClassBType = classBTypeFromClassNode(dCls)
- checkCallsite(callGraph.callsites(cf1Call),
- cf1Call, t1, cf1, cClassBType, false, false, false)
- checkCallsite(callGraph.callsites(cf2Call),
- cf2Call, t1, cf2, cClassBType, true, false, false)
- checkCallsite(callGraph.callsites(cf3Call),
- cf3Call, t1, cf3, cClassBType, false, true, false)
- checkCallsite(callGraph.callsites(cf4Call),
- cf4Call, t1, cf4, cClassBType, true, true, false)
- checkCallsite(callGraph.callsites(cf5Call),
- cf5Call, t1, cf5, cClassBType, false, false, true)
- checkCallsite(callGraph.callsites(cf6Call),
- cf6Call, t1, cf6, cClassBType, true, false, true)
- checkCallsite(callGraph.callsites(cf7Call),
- cf7Call, t1, cf7, cClassBType, false, true, true)
- checkCallsite(callGraph.callsites(cg1Call),
- cg1Call, t1, g1, cMClassBType, true, false, false)
-
- checkCallsite(callGraph.callsites(df1Call),
- df1Call, t2, df1, dClassBType, false, true, false)
- checkCallsite(callGraph.callsites(df2Call),
- df2Call, t2, cf2, cClassBType, true, false, false)
- checkCallsite(callGraph.callsites(df3Call),
- df3Call, t2, df3, dClassBType, true, false, false)
- checkCallsite(callGraph.callsites(df4Call),
- df4Call, t2, cf4, cClassBType, true, true, false)
- checkCallsite(callGraph.callsites(df5Call),
- df5Call, t2, cf5, cClassBType, false, false, true)
- checkCallsite(callGraph.callsites(df6Call),
- df6Call, t2, cf6, cClassBType, true, false, true)
- checkCallsite(callGraph.callsites(df7Call),
- df7Call, t2, cf7, cClassBType, false, true, true)
- checkCallsite(callGraph.callsites(dg1Call),
- dg1Call, t2, g1, cMClassBType, true, false, false)
+ checkCallsite(cf1Call, t1, cf1, cClassBType, false, false, false)
+ checkCallsite(cf2Call, t1, cf2, cClassBType, true, false, false)
+ checkCallsite(cf3Call, t1, cf3, cClassBType, false, true, false)
+ checkCallsite(cf4Call, t1, cf4, cClassBType, true, true, false)
+ checkCallsite(cf5Call, t1, cf5, cClassBType, false, false, true)
+ checkCallsite(cf6Call, t1, cf6, cClassBType, true, false, true)
+ checkCallsite(cf7Call, t1, cf7, cClassBType, false, true, true)
+ checkCallsite(cg1Call, t1, g1, cMClassBType, true, false, false)
+
+ checkCallsite(df1Call, t2, df1, dClassBType, false, true, false)
+ checkCallsite(df2Call, t2, cf2, cClassBType, true, false, false)
+ checkCallsite(df3Call, t2, df3, dClassBType, true, false, false)
+ checkCallsite(df4Call, t2, cf4, cClassBType, true, true, false)
+ checkCallsite(df5Call, t2, cf5, cClassBType, false, false, true)
+ checkCallsite(df6Call, t2, cf6, cClassBType, true, false, true)
+ checkCallsite(df7Call, t2, cf7, cClassBType, false, true, true)
+ checkCallsite(dg1Call, t2, g1, cMClassBType, true, false, false)
+ }
+
+ @Test
+ def callerSensitiveNotSafeToInline(): Unit = {
+ val code =
+ """class C {
+ | def m = java.lang.Class.forName("C")
+ |}
+ """.stripMargin
+ val List(c) = compile(code)
+ val m = findAsmMethod(c, "m")
+ val List(fn) = callsInMethod(m)
+ val forNameMeth = byteCodeRepository.methodNode("java/lang/Class", "forName", "(Ljava/lang/String;)Ljava/lang/Class;").get._1
+ val classTp = classBTypeFromInternalName("java/lang/Class")
+ val r = callGraph.callsites(m)(fn)
+ checkCallsite(fn, m, forNameMeth, classTp, safeToInline = false, atInline = false, atNoInline = false)
+ }
+
+ @Test
+ def checkArgInfos(): Unit = {
+ val code =
+ """abstract class C {
+ | def h(f: Int => Int): Int = f(1)
+ | def t1 = h(x => x + 1)
+ | def t2(i: Int, f: Int => Int, z: Int) = h(f) + i - z
+ | def t3(f: Int => Int) = h(x => f(x + 1))
+ |}
+ |trait D {
+ | def iAmASam(x: Int): Int
+ | def selfSamCall = iAmASam(10)
+ |}
+ |""".stripMargin
+ val List(c, d) = compile(code)
+
+ def callIn(m: String) = callGraph.callsites.find(_._1.name == m).get._2.values.head
+ val t1h = callIn("t1")
+ assertEquals(t1h.argInfos.toList, List((1, FunctionLiteral)))
+
+ val t2h = callIn("t2")
+ assertEquals(t2h.argInfos.toList, List((1, ForwardedParam(2))))
+
+ val t3h = callIn("t3")
+ assertEquals(t3h.argInfos.toList, List((1, FunctionLiteral)))
+
+ val selfSamCall = callIn("selfSamCall")
+ assertEquals(selfSamCall.argInfos.toList, List((0,ForwardedParam(0))))
+ }
+
+ @Test
+ def argInfoAfterInlining(): Unit = {
+ val code =
+ """class C {
+ | def foo(f: Int => Int) = f(1) // not inlined
+ | @inline final def bar(g: Int => Int) = foo(g) // forwarded param 1
+ | @inline final def baz = foo(x => x + 1) // literal
+ |
+ | def t1 = bar(x => x + 1) // call to foo should have argInfo literal
+ | def t2(x: Int, f: Int => Int) = x + bar(f) // call to foo should have argInfo forwarded param 2
+ | def t3 = baz // call to foo should have argInfo literal
+ | def someFun: Int => Int = null
+ | def t4(x: Int) = x + bar(someFun) // call to foo has empty argInfo
+ |}
+ """.stripMargin
+
+ compile(code)
+ def callIn(m: String) = callGraph.callsites.find(_._1.name == m).get._2.values.head
+ assertEquals(callIn("t1").argInfos.toList, List((1, FunctionLiteral)))
+ assertEquals(callIn("t2").argInfos.toList, List((1, ForwardedParam(2))))
+ assertEquals(callIn("t3").argInfos.toList, List((1, FunctionLiteral)))
+ assertEquals(callIn("t4").argInfos.toList, Nil)
}
}