diff options
author | Adriaan Moors <adriaan@lightbend.com> | 2016-08-12 16:24:47 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-12 16:24:47 -0700 |
commit | 3e0b2c2b14bdc26a40887af7a375077565f004b3 (patch) | |
tree | 9886fbcfc6edc3ec069fdf2994cfc1694e4640c2 /src | |
parent | 618d42c747955a43557655bdc0c4281fec5a7923 (diff) | |
parent | 131402fd5fe8c064ef5cfffbe568507cbdf37990 (diff) | |
download | scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.tar.gz scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.tar.bz2 scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.zip |
Merge pull request #5321 from retronym/topic/lock-down-deserialize
SD-193 Lock down lambda deserialization
Diffstat (limited to 'src')
9 files changed, 55 insertions, 50 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index d5c4b5e201..6f9682f434 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -14,6 +14,7 @@ import scala.reflect.internal.Flags import scala.tools.asm import GenBCode._ import BackendReporting._ +import scala.collection.mutable import scala.tools.asm.Opcodes import scala.tools.asm.tree.{MethodInsnNode, MethodNode} import scala.tools.nsc.backend.jvm.BCodeHelpers.{InvokeStyle, TestOp} @@ -1349,7 +1350,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { val markers = if (addScalaSerializableMarker) classBTypeFromSymbol(definitions.SerializableClass).toASMType :: Nil else Nil visitInvokeDynamicInsnLMF(bc.jmethod, sam.name.toString, invokedType, samMethodType, implMethodHandle, constrainedType, isSerializable, markers) if (isSerializable) - indyLambdaHosts += cnode.name + addIndyLambdaImplMethod(cnode.name, implMethodHandle :: Nil) } } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index 1bff8519ec..d4d532f4df 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -112,14 +112,6 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { gen(cd.impl) - val shouldAddLambdaDeserialize = ( - settings.target.value == "jvm-1.8" - && settings.Ydelambdafy.value == "method" - && indyLambdaHosts.contains(cnode.name)) - - if (shouldAddLambdaDeserialize) - backendUtils.addLambdaDeserialize(cnode) - cnode.visitAttribute(thisBType.inlineInfoAttribute.get) if (AsmUtils.traceClassEnabled && cnode.name.contains(AsmUtils.traceClassPattern)) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala index e04e73304f..573dabcafb 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala @@ -122,7 +122,17 @@ abstract class BTypes { * inlining: when inlining an indyLambda instruction into a class, we need to make sure the class * has the method. */ - val indyLambdaHosts: mutable.Set[InternalName] = recordPerRunCache(mutable.Set.empty) + val indyLambdaImplMethods: mutable.AnyRefMap[InternalName, mutable.LinkedHashSet[asm.Handle]] = recordPerRunCache(mutable.AnyRefMap()) + def addIndyLambdaImplMethod(hostClass: InternalName, handle: Seq[asm.Handle]): Unit = { + if (handle.nonEmpty) + indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle + } + def getIndyLambdaImplMethods(hostClass: InternalName): Iterable[asm.Handle] = { + indyLambdaImplMethods.getOrNull(hostClass) match { + case null => Nil + case xs => xs + } + } /** * Obtain the BType for a type descriptor or internal name. For class descriptors, the ClassBType diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala index c2010d2828..acb950929f 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala @@ -283,7 +283,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) { List( coreBTypes.jliMethodHandlesLookupRef, coreBTypes.StringRef, - coreBTypes.jliMethodTypeRef + coreBTypes.jliMethodTypeRef, + ArrayBType(jliMethodHandleRef) ), coreBTypes.jliCallSiteRef ).descriptor, diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala index 584b11d4ed..0a54767f76 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala @@ -266,6 +266,9 @@ abstract class GenBCode extends BCodeSyncAndTry { try { localOptimizations(item.plain) setInnerClasses(item.plain) + val lambdaImplMethods = getIndyLambdaImplMethods(item.plain.name) + if (lambdaImplMethods.nonEmpty) + backendUtils.addLambdaDeserialize(item.plain, lambdaImplMethods) setInnerClasses(item.mirror) setInnerClasses(item.bean) addToQ3(item) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala index 83615abc31..e25b55e7ab 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala @@ -76,7 +76,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { * host a static field in the enclosing class. This allows us to add this method to interfaces * that define lambdas in default methods. */ - def addLambdaDeserialize(classNode: ClassNode): Unit = { + def addLambdaDeserialize(classNode: ClassNode, implMethods: Iterable[Handle]): Unit = { val cw = classNode // Make sure to reference the ClassBTypes of all types that are used in the code generated @@ -92,7 +92,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null) mv.visitCode() mv.visitVarInsn(ALOAD, 0) - mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle) + mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods.toArray: _*) mv.visitInsn(ARETURN) mv.visitEnd() } @@ -101,19 +101,19 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { /** * Clone the instructions in `methodNode` into a new [[InsnList]], mapping labels according to * the `labelMap`. Returns the new instruction list and a map from old to new instructions, and - * a boolean indicating if the instruction list contains an instantiation of a serializable SAM - * type. + * a list of lambda implementation methods references by invokedynamic[LambdaMetafactory] for a + * serializable SAM types. */ - def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], Boolean) = { + def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], List[Handle]) = { val javaLabelMap = labelMap.asJava val result = new InsnList var map = Map.empty[AbstractInsnNode, AbstractInsnNode] - var hasSerializableClosureInstantiation = false + var inlinedTargetHandles = mutable.ListBuffer[Handle]() for (ins <- methodNode.instructions.iterator.asScala) { - if (!hasSerializableClosureInstantiation) ins match { + ins match { case callGraph.LambdaMetaFactoryCall(indy, _, _, _) => indy.bsmArgs match { - case Array(_, _, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 => - hasSerializableClosureInstantiation = true + case Array(_, targetHandle: Handle, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 => + inlinedTargetHandles += targetHandle case _ => } case _ => @@ -124,7 +124,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { map += ((ins, cloned)) } } - (result, map, hasSerializableClosureInstantiation) + (result, map, inlinedTargetHandles.toList) } def getBoxedUnit: FieldInsnNode = new FieldInsnNode(GETSTATIC, srBoxedUnitRef.internalName, "UNIT", srBoxedUnitRef.descriptor) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala index 50dd65c56c..b7523bbf06 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala @@ -277,7 +277,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { } case _ => false } - val (clonedInstructions, instructionMap, hasSerializableClosureInstantiation) = cloneInstructions(callee, labelsMap, keepLineNumbers = sameSourceFile) + val (clonedInstructions, instructionMap, targetHandles) = cloneInstructions(callee, labelsMap, keepLineNumbers = sameSourceFile) // local vars in the callee are shifted by the number of locals at the callsite val localVarShift = callsiteMethod.maxLocals @@ -405,10 +405,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { callsiteMethod.maxStack = math.max(callsiteMethod.maxStack, math.max(stackHeightAtNullCheck, maxStackOfInlinedCode)) - if (hasSerializableClosureInstantiation && !indyLambdaHosts(callsiteClass.internalName)) { - indyLambdaHosts += callsiteClass.internalName - addLambdaDeserialize(byteCodeRepository.classNode(callsiteClass.internalName).get) - } + addIndyLambdaImplMethod(callsiteClass.internalName, targetHandles) callGraph.addIfMissing(callee, calleeDeclarationClass) diff --git a/src/library/scala/runtime/LambdaDeserialize.java b/src/library/scala/runtime/LambdaDeserialize.java index e239debf25..4c5198cc48 100644 --- a/src/library/scala/runtime/LambdaDeserialize.java +++ b/src/library/scala/runtime/LambdaDeserialize.java @@ -2,28 +2,37 @@ package scala.runtime; import java.lang.invoke.*; -import java.util.Arrays; import java.util.HashMap; public final class LambdaDeserialize { + public static final MethodType DESERIALIZE_LAMBDA_MT = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", LambdaDeserialize.class.getClassLoader()); private MethodHandles.Lookup lookup; private final HashMap<String, MethodHandle> cache = new HashMap<>(); private final LambdaDeserializer$ l = LambdaDeserializer$.MODULE$; + private final HashMap<String, MethodHandle> targetMethodMap; - private LambdaDeserialize(MethodHandles.Lookup lookup) { + private LambdaDeserialize(MethodHandles.Lookup lookup, MethodHandle[] targetMethods) { this.lookup = lookup; + targetMethodMap = new HashMap<>(targetMethods.length); + for (MethodHandle targetMethod : targetMethods) { + MethodHandleInfo info = lookup.revealDirect(targetMethod); + String key = nameAndDescriptorKey(info.getName(), info.getMethodType().toMethodDescriptorString()); + targetMethodMap.put(key, targetMethod); + } } public Object deserializeLambda(SerializedLambda serialized) { - return l.deserializeLambda(lookup, cache, serialized); + return l.deserializeLambda(lookup, cache, targetMethodMap, serialized); } public static CallSite bootstrap(MethodHandles.Lookup lookup, String invokedName, - MethodType invokedType) throws Throwable { - MethodType type = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", lookup.getClass().getClassLoader()); - MethodHandle deserializeLambda = lookup.findVirtual(LambdaDeserialize.class, "deserializeLambda", type); - MethodHandle exact = deserializeLambda.bindTo(new LambdaDeserialize(lookup)).asType(invokedType); + MethodType invokedType, MethodHandle... targetMethods) throws Throwable { + MethodHandle deserializeLambda = lookup.findVirtual(LambdaDeserialize.class, "deserializeLambda", DESERIALIZE_LAMBDA_MT); + MethodHandle exact = deserializeLambda.bindTo(new LambdaDeserialize(lookup, targetMethods)).asType(invokedType); return new ConstantCallSite(exact); } + public static String nameAndDescriptorKey(String name, String descriptor) { + return name + descriptor; + } } diff --git a/src/library/scala/runtime/LambdaDeserializer.scala b/src/library/scala/runtime/LambdaDeserializer.scala index a6e08e6e61..25f41fd049 100644 --- a/src/library/scala/runtime/LambdaDeserializer.scala +++ b/src/library/scala/runtime/LambdaDeserializer.scala @@ -31,10 +31,13 @@ object LambdaDeserializer { * member of the anonymous class created by `LambdaMetaFactory`. * @return An instance of the functional interface */ - def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = { + def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], + targetMethodMap: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = { + assert(targetMethodMap != null) def slashDot(name: String) = name.replaceAll("/", ".") val loader = lookup.lookupClass().getClassLoader val implClass = loader.loadClass(slashDot(serialized.getImplClass)) + val key = LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName, serialized.getImplMethodSignature) def makeCallSite: CallSite = { import serialized._ @@ -69,7 +72,11 @@ object LambdaDeserializer { // Lookup the implementation method val implMethod: MethodHandle = try { - findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig) + if (targetMethodMap.containsKey(key)) { + targetMethodMap.get(key) + } else { + throw new IllegalArgumentException("Illegal lambda deserialization") + } } catch { case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e) } @@ -91,7 +98,6 @@ object LambdaDeserializer { ) } - val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature val factory: MethodHandle = if (cache == null) { makeCallSite.getTarget } else cache.synchronized{ @@ -117,18 +123,4 @@ object LambdaDeserializer { // is cleaner if we uniformly add a single marker, so I'm leaving it in place. "java.io.Serializable" } - - private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_], - name: String, signature: MethodType): MethodHandle = { - kind match { - case MethodHandleInfo.REF_invokeStatic => - lookup.findStatic(owner, name, signature) - case MethodHandleInfo.REF_newInvokeSpecial => - lookup.findConstructor(owner, signature) - case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface => - lookup.findVirtual(owner, name, signature) - case MethodHandleInfo.REF_invokeSpecial => - lookup.findSpecial(owner, name, signature, owner) - } - } } |