diff options
6 files changed, 68 insertions, 56 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala index 0845e440d7..bff58b426e 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala @@ -124,12 +124,13 @@ abstract class BTypes { */ val indyLambdaImplMethods: mutable.AnyRefMap[InternalName, mutable.LinkedHashSet[asm.Handle]] = recordPerRunCache(mutable.AnyRefMap()) def addIndyLambdaImplMethod(hostClass: InternalName, handle: Seq[asm.Handle]): Unit = { - indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle + if (handle.nonEmpty) + indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle } - def getIndyLambdaImplMethods(hostClass: InternalName): List[asm.Handle] = { + def getIndyLambdaImplMethods(hostClass: InternalName): Iterable[asm.Handle] = { indyLambdaImplMethods.getOrNull(hostClass) match { case null => Nil - case xs => xs.toList.distinct + case xs => xs } } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala index 1dbb18722f..acb950929f 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala @@ -289,19 +289,6 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) { coreBTypes.jliCallSiteRef ).descriptor, /* itf = */ coreBTypes.srLambdaDeserialize.isInterface.get) - lazy val lambdaDeserializeAddTargets = - new scala.tools.asm.Handle(scala.tools.asm.Opcodes.H_INVOKESTATIC, - coreBTypes.srLambdaDeserialize.internalName, "bootstrapAddTargets", - MethodBType( - List( - coreBTypes.jliMethodHandlesLookupRef, - coreBTypes.StringRef, - coreBTypes.jliMethodTypeRef, - ArrayBType(coreBTypes.jliMethodHandleRef) - ), - coreBTypes.jliCallSiteRef - ).descriptor, - /* itf = */ coreBTypes.srLambdaDeserialize.isInterface.get) } /** diff --git a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala index d85d85003d..e25b55e7ab 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala @@ -76,7 +76,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { * host a static field in the enclosing class. This allows us to add this method to interfaces * that define lambdas in default methods. */ - def addLambdaDeserialize(classNode: ClassNode, implMethods: List[Handle]): Unit = { + def addLambdaDeserialize(classNode: ClassNode, implMethods: Iterable[Handle]): Unit = { val cw = classNode // Make sure to reference the ClassBTypes of all types that are used in the code generated @@ -87,13 +87,12 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { val nilLookupDesc = MethodBType(Nil, jliMethodHandlesLookupRef).descriptor val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectRef).descriptor - val addTargetMethodsObjDesc = MethodBType(ObjectRef :: Nil, UNIT).descriptor { val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null) mv.visitCode() mv.visitVarInsn(ALOAD, 0) - mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods: _*) + mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods.toArray: _*) mv.visitInsn(ARETURN) mv.visitEnd() } @@ -102,8 +101,8 @@ class BackendUtils[BT <: BTypes](val btypes: BT) { /** * Clone the instructions in `methodNode` into a new [[InsnList]], mapping labels according to * the `labelMap`. Returns the new instruction list and a map from old to new instructions, and - * a boolean indicating if the instruction list contains an instantiation of a serializable SAM - * type. + * a list of lambda implementation methods references by invokedynamic[LambdaMetafactory] for a + * serializable SAM types. */ def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], List[Handle]) = { val javaLabelMap = labelMap.asJava diff --git a/src/library/scala/runtime/LambdaDeserialize.java b/src/library/scala/runtime/LambdaDeserialize.java index a3df868517..4c5198cc48 100644 --- a/src/library/scala/runtime/LambdaDeserialize.java +++ b/src/library/scala/runtime/LambdaDeserialize.java @@ -2,14 +2,10 @@ package scala.runtime; import java.lang.invoke.*; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; public final class LambdaDeserialize { public static final MethodType DESERIALIZE_LAMBDA_MT = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", LambdaDeserialize.class.getClassLoader()); - public static final MethodType ADD_TARGET_METHODS_MT = MethodType.fromMethodDescriptorString("([Ljava/lang/invoke/MethodHandle;)V", LambdaDeserialize.class.getClassLoader()); private MethodHandles.Lookup lookup; private final HashMap<String, MethodHandle> cache = new HashMap<>(); @@ -37,6 +33,6 @@ public final class LambdaDeserialize { return new ConstantCallSite(exact); } public static String nameAndDescriptorKey(String name, String descriptor) { - return name + " " + descriptor; + return name + descriptor; } } diff --git a/src/library/scala/runtime/LambdaDeserializer.scala b/src/library/scala/runtime/LambdaDeserializer.scala index eb168fe445..e120f0e308 100644 --- a/src/library/scala/runtime/LambdaDeserializer.scala +++ b/src/library/scala/runtime/LambdaDeserializer.scala @@ -33,6 +33,7 @@ object LambdaDeserializer { */ def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], targetMethodMap: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = { + assert(targetMethodMap != null) def slashDot(name: String) = name.replaceAll("/", ".") val loader = lookup.lookupClass().getClassLoader val implClass = loader.loadClass(slashDot(serialized.getImplClass)) @@ -71,14 +72,10 @@ object LambdaDeserializer { // Lookup the implementation method val implMethod: MethodHandle = try { - if (targetMethodMap != null) { - if (targetMethodMap.containsKey(key)) { - targetMethodMap.get(key) - } else { - throw new IllegalArgumentException("Illegal lambda deserialization") - } + if (targetMethodMap.containsKey(key)) { + targetMethodMap.get(key) } else { - findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig) + throw new IllegalArgumentException("Illegal lambda deserialization") } } catch { case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e) @@ -124,18 +121,4 @@ object LambdaDeserializer { // is cleaner if we uniformly add a single marker, so I'm leaving it in place. "java.io.Serializable" } - - private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_], - name: String, signature: MethodType): MethodHandle = { - kind match { - case MethodHandleInfo.REF_invokeStatic => - lookup.findStatic(owner, name, signature) - case MethodHandleInfo.REF_newInvokeSpecial => - lookup.findConstructor(owner, signature) - case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface => - lookup.findVirtual(owner, name, signature) - case MethodHandleInfo.REF_invokeSpecial => - lookup.findSpecial(owner, name, signature, owner) - } - } } diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java index ba52e979cc..3ed1ae1365 100644 --- a/test/junit/scala/runtime/LambdaDeserializerTest.java +++ b/test/junit/scala/runtime/LambdaDeserializerTest.java @@ -4,9 +4,7 @@ import org.junit.Assert; import org.junit.Test; import java.io.Serializable; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.SerializedLambda; +import java.lang.invoke.*; import java.lang.reflect.Method; import java.util.Arrays; import java.util.HashMap; @@ -85,19 +83,20 @@ public final class LambdaDeserializerTest { public void implMethodNameChanged() { F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); SerializedLambda sl = writeReplace(f1); - checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); } @Test public void implMethodSignatureChanged() { F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); SerializedLambda sl = writeReplace(f1); - checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); } - private void checkIllegalAccess(SerializedLambda serialized) { + private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) { try { - LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, null, serialized); + HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed); + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested); throw new AssertionError(); } catch (IllegalArgumentException iae) { if (!iae.getMessage().contains("Illegal lambda deserialization")) { @@ -123,6 +122,7 @@ public final class LambdaDeserializerTest { throw new RuntimeException(e); } } + private <A, B> A reconstitute(A f1) { return reconstitute(f1, null); } @@ -130,12 +130,56 @@ public final class LambdaDeserializerTest { @SuppressWarnings("unchecked") private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) { try { - return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, null, writeReplace(f1)); + return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup()); } catch (Exception e) { throw new RuntimeException(e); } } + private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) { + SerializedLambda serialized = writeReplace(f1); + HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized); + return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized); + } + + private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) { + Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader()); + MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader())); + HashMap<String, MethodHandle> allowed = new HashMap<>(); + allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod); + return allowed; + } + + private Class<?> classForName(String className, ClassLoader classLoader) { + try { + return Class.forName(className, true, classLoader); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner, + String name, MethodType signature) { + try { + switch (kind) { + case MethodHandleInfo.REF_invokeStatic: + return lookup.findStatic(owner, name, signature); + case MethodHandleInfo.REF_newInvokeSpecial: + return lookup.findConstructor(owner, signature); + case MethodHandleInfo.REF_invokeVirtual: + case MethodHandleInfo.REF_invokeInterface: + return lookup.findVirtual(owner, name, signature); + case MethodHandleInfo.REF_invokeSpecial: + return lookup.findSpecial(owner, name, signature, owner); + default: + throw new IllegalArgumentException(); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private <A> SerializedLambda writeReplace(A f1) { try { Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); @@ -189,5 +233,7 @@ class LambdaHost { } interface I { - default String i() { return "i"; }; + default String i() { + return "i"; + } } |