package scala.tools.nsc
package backend.jvm
package analysis

import scala.annotation.switch
import scala.tools.asm.{Opcodes, Handle, Type, Label}
import scala.tools.asm.tree._
import scala.tools.asm.tree.analysis.{Frame, BasicInterpreter, Analyzer, Value}
import scala.tools.nsc.backend.jvm.BTypes._
import scala.tools.nsc.backend.jvm.opt.BytecodeUtils
import scala.tools.nsc.backend.jvm.opt.BytecodeUtils._
import java.lang.invoke.LambdaMetafactory
import scala.collection.mutable
import scala.collection.convert.decorateAsJava._
import scala.collection.convert.decorateAsScala._

/**
 * This component hosts tools and utilities used in the backend that require access to a `BTypes`
 * instance.
 *
 * One example is the AsmAnalyzer class, which runs `computeMaxLocalsMaxStack` on the methodNode to
 * be analyzed. This method in turn lives inside the BTypes assembly because it queries the per-run
 * cache `maxLocalsMaxStackComputed` defined in there.
 */
class BackendUtils[BT <: BTypes](val btypes: BT) {
  import btypes._
  import callGraph.ClosureInstantiation

  /**
   * A wrapper to make ASM's Analyzer a bit easier to use.
   */
  class AsmAnalyzer[V <: Value](methodNode: MethodNode, classInternalName: InternalName, val analyzer: Analyzer[V] = new Analyzer(new BasicInterpreter)) {
    computeMaxLocalsMaxStack(methodNode)
    analyzer.analyze(classInternalName, methodNode)
    def frameAt(instruction: AbstractInsnNode): Frame[V] = analyzer.frameAt(instruction, methodNode)
  }

  /**
   * See the doc comment on package object `analysis` for a discussion on performance.
   */
  object AsmAnalyzer {
    // jvm limit is 65535 for both number of instructions and number of locals
    private def size(method: MethodNode) = method.instructions.size.toLong * method.maxLocals * method.maxLocals

    // with the limits below, analysis should not take more than one second
    private val nullnessSizeLimit    = 5000L * 600L  * 600L    // 5000 insns, 600 locals
    private val basicValueSizeLimit  = 9000L * 1000L * 1000L
    private val sourceValueSizeLimit = 8000L * 950L  * 950L

    def sizeOKForAliasing(method: MethodNode): Boolean = size(method) < nullnessSizeLimit
    def sizeOKForNullness(method: MethodNode): Boolean = size(method) < nullnessSizeLimit
    def sizeOKForBasicValue(method: MethodNode): Boolean = size(method) < basicValueSizeLimit
    def sizeOKForSourceValue(method: MethodNode): Boolean = size(method) < sourceValueSizeLimit
  }

  class ProdConsAnalyzer(val methodNode: MethodNode, classInternalName: InternalName) extends AsmAnalyzer(methodNode, classInternalName, new Analyzer(new InitialProducerSourceInterpreter)) with ProdConsAnalyzerImpl

  class NonLubbingTypeFlowAnalyzer(val methodNode: MethodNode, classInternalName: InternalName) extends AsmAnalyzer(methodNode, classInternalName, new Analyzer(new NonLubbingTypeFlowInterpreter))
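
  // A minimal usage sketch (illustrative only; `methodNode` and `ownerInternalName` are
  // hypothetical names assumed to be in scope). It shows the intended pattern: check the
  // size limit first, then construct the analyzer and query frames per instruction.
  //
  // {{{
  //   if (AsmAnalyzer.sizeOKForBasicValue(methodNode)) {
  //     val a = new AsmAnalyzer(methodNode, ownerInternalName)
  //     val frame = a.frameAt(methodNode.instructions.getFirst)
  //     // frame.getStackSize / frame.getLocal(i) come from asm's Frame API
  //   }
  // }}}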

  /**
   * Add:
   *
   * private static java.util.Map $deserializeLambdaCache$ = null
   *
   * private static Object $deserializeLambda$(SerializedLambda l) {
   *   var cache = $deserializeLambdaCache$
   *   if (cache eq null) {
   *     cache = new java.util.HashMap()
   *     $deserializeLambdaCache$ = cache
   *   }
   *   return scala.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), cache, l);
   * }
   */
  def addLambdaDeserialize(classNode: ClassNode): Unit = {
    val cw = classNode
    import scala.tools.asm.Opcodes._
    import btypes.coreBTypes._

    // Make sure that the ClassBTypes of all types used in the code generated here
    // (e.g. java/util/Map) are initialized. Initializing a ClassBType adds it to the
    // `classBTypeFromInternalName` map. When writing the classfile, the asm ClassWriter computes
    // stack map frames and invokes the `getCommonSuperClass` method. This method expects all
    // ClassBTypes mentioned in the source code to exist in the map.

    val mapDesc = juMapRef.descriptor
    val nilLookupDesc = MethodBType(Nil, jliMethodHandlesLookupRef).descriptor
    val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectRef).descriptor
    val lookupMapSerlamObjDesc = MethodBType(jliMethodHandlesLookupRef :: juMapRef :: jliSerializedLambdaRef :: Nil, ObjectRef).descriptor

    {
      val fv = cw.visitField(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambdaCache$", mapDesc, null, null)
      fv.visitEnd()
    }

    {
      val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
      mv.visitCode()
      mv.visitFieldInsn(GETSTATIC, classNode.name, "$deserializeLambdaCache$", mapDesc)
      mv.visitVarInsn(ASTORE, 1)
      mv.visitVarInsn(ALOAD, 1)
      val l0 = new Label()
      mv.visitJumpInsn(IFNONNULL, l0)
      mv.visitTypeInsn(NEW, juHashMapRef.internalName)
      mv.visitInsn(DUP)
      mv.visitMethodInsn(INVOKESPECIAL, juHashMapRef.internalName, "<init>", "()V", false)
      mv.visitVarInsn(ASTORE, 1)
      mv.visitVarInsn(ALOAD, 1)
      mv.visitFieldInsn(PUTSTATIC, classNode.name, "$deserializeLambdaCache$", mapDesc)
      mv.visitLabel(l0)
      mv.visitFieldInsn(GETSTATIC, srLambdaDeserializerRef.internalName, "MODULE$", srLambdaDeserializerRef.descriptor)
      mv.visitMethodInsn(INVOKESTATIC, jliMethodHandlesRef.internalName, "lookup", nilLookupDesc, false)
      mv.visitVarInsn(ALOAD, 1)
      mv.visitVarInsn(ALOAD, 0)
      mv.visitMethodInsn(INVOKEVIRTUAL, srLambdaDeserializerRef.internalName, "deserializeLambda", lookupMapSerlamObjDesc, false)
      mv.visitInsn(ARETURN)
      mv.visitEnd()
    }
  }

  /**
   * Clone the instructions in `methodNode` into a new [[InsnList]], mapping labels according to
   * the `labelMap`. Returns the new instruction list, a map from old to new instructions, and
   * a boolean indicating if the instruction list contains an instantiation of a serializable SAM
   * type.
   */
  def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode]): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], Boolean) = {
    val javaLabelMap = labelMap.asJava
    val result = new InsnList
    var map = Map.empty[AbstractInsnNode, AbstractInsnNode]
    var hasSerializableClosureInstantiation = false
    for (ins <- methodNode.instructions.iterator.asScala) {
      if (!hasSerializableClosureInstantiation) ins match {
        case callGraph.LambdaMetaFactoryCall(indy, _, _, _) => indy.bsmArgs match {
          case Array(_, _, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 =>
            hasSerializableClosureInstantiation = true
          case _ =>
        }
        case _ =>
      }
      val cloned = ins.clone(javaLabelMap)
      result add cloned
      map += ((ins, cloned))
    }
    (result, map, hasSerializableClosureInstantiation)
  }

  def getBoxedUnit: FieldInsnNode = new FieldInsnNode(Opcodes.GETSTATIC, coreBTypes.srBoxedUnitRef.internalName, "UNIT", coreBTypes.srBoxedUnitRef.descriptor)

  private val anonfunAdaptedName = """.*\$anonfun\$\d+\$adapted""".r
  def hasAdaptedImplMethod(closureInit: ClosureInstantiation): Boolean = {
    BytecodeUtils.isrJFunctionType(Type.getReturnType(closureInit.lambdaMetaFactoryCall.indy.desc).getInternalName) &&
      anonfunAdaptedName.pattern.matcher(closureInit.lambdaMetaFactoryCall.implMethod.getName).matches
  }
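
  // A minimal sketch of how `cloneInstructions` might be invoked (illustrative; a typical
  // caller, such as the inliner, first builds a fresh LabelNode for every label in the
  // method so that jump targets in the clone point to the new labels):
  //
  // {{{
  //   val labelMap: Map[LabelNode, LabelNode] = methodNode.instructions.iterator.asScala.collect {
  //     case l: LabelNode => (l, new LabelNode)
  //   }.toMap
  //   val (clonedInstructions, instructionMap, hasSerializableLambda) =
  //     cloneInstructions(methodNode, labelMap)
  // }}}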

  /**
   * Visit the class node and collect all referenced nested classes.
   */
  def collectNestedClasses(classNode: ClassNode): List[ClassBType] = {
    val innerClasses = mutable.Set.empty[ClassBType]

    def visitInternalName(internalName: InternalName): Unit = if (internalName != null) {
      val t = classBTypeFromParsedClassfile(internalName)
      if (t.isNestedClass.get) innerClasses += t
    }

    // either an internal/Name or [Linternal/Name; -- there are certain references in classfiles
    // that are either an internal name (without the surrounding `L;`) or an array descriptor
    // `[Linternal/Name;`.
    def visitInternalNameOrArrayReference(ref: String): Unit = if (ref != null) {
      val bracket = ref.lastIndexOf('[')
      if (bracket == -1) visitInternalName(ref)
      else if (ref.charAt(bracket + 1) == 'L') visitInternalName(ref.substring(bracket + 2, ref.length - 1))
    }

    // we are only interested in the class references in the descriptor, so we can skip over
    // primitives and the brackets of array descriptors
    def visitDescriptor(desc: String): Unit = (desc.charAt(0): @switch) match {
      case '(' =>
        val internalNames = mutable.ListBuffer.empty[String]
        var i = 1
        while (i < desc.length) {
          if (desc.charAt(i) == 'L') {
            val start = i + 1 // skip the L
            while (desc.charAt(i) != ';') i += 1
            internalNames append desc.substring(start, i)
          }
          // skips over '[', ')', primitives
          i += 1
        }
        internalNames foreach visitInternalName

      case 'L' =>
        visitInternalName(desc.substring(1, desc.length - 1))

      case '[' =>
        visitInternalNameOrArrayReference(desc)

      case _ => // skip over primitive types
    }
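
    // For example (illustrative): for the method descriptor `(Ljava/lang/String;[IJ)Lscala/Option;`
    // the '(' case above visits `java/lang/String` and `scala/Option` (the scan continues past the
    // closing ')' into the return type), while the primitive array `[I` and the primitive `J` are
    // skipped.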

    def visitConstant(const: AnyRef): Unit = const match {
      case t: Type => visitDescriptor(t.getDescriptor)
      case _ =>
    }

    // in principle we could skip references to annotation types, as they only end up as strings
    // in the constant pool, not as class references. however, the java compiler still includes
    // nested annotation classes in the innerClass table, so we do the same. explained in detail
    // in the large comment in class BTypes.
    def visitAnnotation(annot: AnnotationNode): Unit = {
      visitDescriptor(annot.desc)
      if (annot.values != null) annot.values.asScala foreach visitConstant
    }

    def visitAnnotations(annots: java.util.List[_ <: AnnotationNode]) = if (annots != null) annots.asScala foreach visitAnnotation
    def visitAnnotationss(annotss: Array[java.util.List[AnnotationNode]]) = if (annotss != null) annotss foreach visitAnnotations

    def visitHandle(handle: Handle): Unit = {
      visitInternalNameOrArrayReference(handle.getOwner)
      visitDescriptor(handle.getDesc)
    }

    visitInternalName(classNode.name)
    innerClasses ++= classBTypeFromParsedClassfile(classNode.name).info.get.nestedClasses

    visitInternalName(classNode.superName)
    classNode.interfaces.asScala foreach visitInternalName
    visitInternalName(classNode.outerClass)

    visitAnnotations(classNode.visibleAnnotations)
    visitAnnotations(classNode.visibleTypeAnnotations)
    visitAnnotations(classNode.invisibleAnnotations)
    visitAnnotations(classNode.invisibleTypeAnnotations)

    for (f <- classNode.fields.asScala) {
      visitDescriptor(f.desc)
      visitAnnotations(f.visibleAnnotations)
      visitAnnotations(f.visibleTypeAnnotations)
      visitAnnotations(f.invisibleAnnotations)
      visitAnnotations(f.invisibleTypeAnnotations)
    }

    for (m <- classNode.methods.asScala) {
      visitDescriptor(m.desc)

      visitAnnotations(m.visibleAnnotations)
      visitAnnotations(m.visibleTypeAnnotations)
      visitAnnotations(m.invisibleAnnotations)
      visitAnnotations(m.invisibleTypeAnnotations)

      visitAnnotationss(m.visibleParameterAnnotations)
      visitAnnotationss(m.invisibleParameterAnnotations)

      visitAnnotations(m.visibleLocalVariableAnnotations)
      visitAnnotations(m.invisibleLocalVariableAnnotations)

      m.exceptions.asScala foreach visitInternalName
      for (tcb <- m.tryCatchBlocks.asScala) visitInternalName(tcb.`type`)

      val iter = m.instructions.iterator()
      while (iter.hasNext) iter.next() match {
        case ti: TypeInsnNode           => visitInternalNameOrArrayReference(ti.desc)
        case fi: FieldInsnNode          => visitInternalNameOrArrayReference(fi.owner); visitDescriptor(fi.desc)
        case mi: MethodInsnNode         => visitInternalNameOrArrayReference(mi.owner); visitDescriptor(mi.desc)
        case id: InvokeDynamicInsnNode  => visitDescriptor(id.desc); visitHandle(id.bsm); id.bsmArgs foreach visitConstant
        case ci: LdcInsnNode            => visitConstant(ci.cst)
        case ma: MultiANewArrayInsnNode => visitDescriptor(ma.desc)
        case _ =>
      }
    }
    innerClasses.toList
  }
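
  // A minimal usage sketch (illustrative; `classNode` is a hypothetical parsed ClassNode).
  // Every ClassBType in the result is a nested class, e.g. usable when rebuilding the
  // InnerClasses classfile attribute after the optimizer introduces new references.
  //
  // {{{
  //   val nestedClasses: List[ClassBType] = collectNestedClasses(classNode)
  // }}}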

  /**
   * In order to run an Analyzer, the maxLocals / maxStack fields need to be available. The ASM
   * framework only computes these values during bytecode generation.
   *
   * NOTE 1: as explained in the `analysis` package object, the maxStack value used by the Analyzer
   * may be smaller than the correct maxStack value in the classfile (Analyzers only use a single
   * slot for long / double values). The maxStack computed here is correct for running an analyzer,
   * but not for writing in the classfile. We let the ClassWriter recompute the maxes.
   *
   * NOTE 2: the maxStack value computed here may be larger than the smallest correct value
   * that would allow running an analyzer, see `InstructionStackEffect.forAsmAnalysisConservative`.
   *
   * NOTE 3: the implementation doesn't look at instructions that cannot be reached, it computes
   * the max local / stack size in the reachable code. These maxes work just fine for running an
   * Analyzer: its implementation also skips over unreachable code in the same way.
   */
  def computeMaxLocalsMaxStack(method: MethodNode): Unit = {
    import Opcodes._

    if (isAbstractMethod(method) || isNativeMethod(method)) {
      method.maxLocals = 0
      method.maxStack = 0
    } else if (!maxLocalsMaxStackComputed(method)) {
      val size = method.instructions.size

      var maxLocals = (Type.getArgumentsAndReturnSizes(method.desc) >> 2) - (if (isStaticMethod(method)) 1 else 0)
      var maxStack = 0

      // queue of instruction indices where analysis should start
      var queue = new Array[Int](8)
      var top = -1
      def enq(i: Int): Unit = {
        if (top == queue.length - 1) {
          val nq = new Array[Int](queue.length * 2)
          Array.copy(queue, 0, nq, 0, queue.length)
          queue = nq
        }
        top += 1
        queue(top) = i
      }
      def deq(): Int = {
        val r = queue(top)
        top -= 1
        r
      }

      val subroutineRetTargets = new mutable.Stack[AbstractInsnNode]

      // for each instruction in the queue, contains the stack height at this instruction.
      // once an instruction has been treated, contains -1 to prevent re-enqueuing
      val stackHeights = new Array[Int](size)

      def enqInsn(insn: AbstractInsnNode, height: Int): Unit = {
        enqInsnIndex(method.instructions.indexOf(insn), height)
      }

      def enqInsnIndex(insnIndex: Int, height: Int): Unit = {
        if (insnIndex < size && stackHeights(insnIndex) != -1) {
          stackHeights(insnIndex) = height
          enq(insnIndex)
        }
      }

      val tcbIt = method.tryCatchBlocks.iterator()
      while (tcbIt.hasNext) {
        val tcb = tcbIt.next()
        enqInsn(tcb.handler, 1)
        if (maxStack == 0) maxStack = 1
      }

      enq(0)
      while (top != -1) {
        val insnIndex = deq()
        val insn = method.instructions.get(insnIndex)
        val initHeight = stackHeights(insnIndex)
        stackHeights(insnIndex) = -1 // prevent i from being enqueued again

        if (insn.getOpcode == -1) { // frames, labels, line numbers
          enqInsnIndex(insnIndex + 1, initHeight)
        } else {
          val stackGrowth = InstructionStackEffect.maxStackGrowth(insn)
          val heightAfter = initHeight + stackGrowth
          if (heightAfter > maxStack) maxStack = heightAfter

          // update maxLocals
          insn match {
            case v: VarInsnNode =>
              val longSize = if (isSize2LoadOrStore(v.getOpcode)) 1 else 0
              maxLocals = math.max(maxLocals, v.`var` + longSize + 1) // + 1 because local numbers are 0-based

            case i: IincInsnNode =>
              maxLocals = math.max(maxLocals, i.`var` + 1)

            case _ =>
          }

          insn match {
            case j: JumpInsnNode =>
              if (j.getOpcode == JSR) {
                val jsrTargetHeight = heightAfter + 1
                if (jsrTargetHeight > maxStack) maxStack = jsrTargetHeight
                subroutineRetTargets.push(j.getNext)
                enqInsn(j.label, jsrTargetHeight)
              } else {
                enqInsn(j.label, heightAfter)
                val opc = j.getOpcode
                if (opc != GOTO) enqInsnIndex(insnIndex + 1, heightAfter) // jump is conditional, so the successor is also a possible control flow target
              }

            case l: LookupSwitchInsnNode =>
              var j = 0
              while (j < l.labels.size) { enqInsn(l.labels.get(j), heightAfter); j += 1 }
              enqInsn(l.dflt, heightAfter)

            case t: TableSwitchInsnNode =>
              var j = 0
              while (j < t.labels.size) { enqInsn(t.labels.get(j), heightAfter); j += 1 }
              enqInsn(t.dflt, heightAfter)

            case r: VarInsnNode if r.getOpcode == RET =>
              enqInsn(subroutineRetTargets.pop(), heightAfter)

            case _ =>
              val opc = insn.getOpcode
              if (opc != ATHROW && !isReturn(insn)) enqInsnIndex(insnIndex + 1, heightAfter)
          }
        }
      }

      method.maxLocals = maxLocals
      method.maxStack = maxStack

      maxLocalsMaxStackComputed += method
    }
  }
}
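
// Worked example for computeMaxLocalsMaxStack (illustrative): a static method with
// descriptor (I)I whose body is
//
//   ILOAD 0; ICONST_1; IADD; IRETURN
//
// starts with maxLocals = (Type.getArgumentsAndReturnSizes("(I)I") >> 2) - 1 = 1, because
// the argument size reported by asm includes an implicit `this` slot, which is subtracted
// for static methods. The stack heights after the four instructions are 1, 2, 1, 0, so the
// method ends up with maxLocals = 1 and maxStack = 2.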