summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala7
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala13
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala9
-rw-r--r--src/library/scala/runtime/LambdaDeserialize.java6
-rw-r--r--src/library/scala/runtime/LambdaDeserializer.scala25
-rw-r--r--test/junit/scala/runtime/LambdaDeserializerTest.java64
6 files changed, 68 insertions, 56 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
index 0845e440d7..bff58b426e 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
@@ -124,12 +124,13 @@ abstract class BTypes {
*/
val indyLambdaImplMethods: mutable.AnyRefMap[InternalName, mutable.LinkedHashSet[asm.Handle]] = recordPerRunCache(mutable.AnyRefMap())
def addIndyLambdaImplMethod(hostClass: InternalName, handle: Seq[asm.Handle]): Unit = {
- indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle
+ if (handle.nonEmpty)
+ indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle
}
- def getIndyLambdaImplMethods(hostClass: InternalName): List[asm.Handle] = {
+ def getIndyLambdaImplMethods(hostClass: InternalName): Iterable[asm.Handle] = {
indyLambdaImplMethods.getOrNull(hostClass) match {
case null => Nil
- case xs => xs.toList.distinct
+ case xs => xs
}
}
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
index 1dbb18722f..acb950929f 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
@@ -289,19 +289,6 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) {
coreBTypes.jliCallSiteRef
).descriptor,
/* itf = */ coreBTypes.srLambdaDeserialize.isInterface.get)
- lazy val lambdaDeserializeAddTargets =
- new scala.tools.asm.Handle(scala.tools.asm.Opcodes.H_INVOKESTATIC,
- coreBTypes.srLambdaDeserialize.internalName, "bootstrapAddTargets",
- MethodBType(
- List(
- coreBTypes.jliMethodHandlesLookupRef,
- coreBTypes.StringRef,
- coreBTypes.jliMethodTypeRef,
- ArrayBType(coreBTypes.jliMethodHandleRef)
- ),
- coreBTypes.jliCallSiteRef
- ).descriptor,
- /* itf = */ coreBTypes.srLambdaDeserialize.isInterface.get)
}
/**
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
index d85d85003d..e25b55e7ab 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
@@ -76,7 +76,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
* host a static field in the enclosing class. This allows us to add this method to interfaces
* that define lambdas in default methods.
*/
- def addLambdaDeserialize(classNode: ClassNode, implMethods: List[Handle]): Unit = {
+ def addLambdaDeserialize(classNode: ClassNode, implMethods: Iterable[Handle]): Unit = {
val cw = classNode
// Make sure to reference the ClassBTypes of all types that are used in the code generated
@@ -87,13 +87,12 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
val nilLookupDesc = MethodBType(Nil, jliMethodHandlesLookupRef).descriptor
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectRef).descriptor
- val addTargetMethodsObjDesc = MethodBType(ObjectRef :: Nil, UNIT).descriptor
{
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
mv.visitCode()
mv.visitVarInsn(ALOAD, 0)
- mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods: _*)
+ mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods.toArray: _*)
mv.visitInsn(ARETURN)
mv.visitEnd()
}
@@ -102,8 +101,8 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
/**
* Clone the instructions in `methodNode` into a new [[InsnList]], mapping labels according to
* the `labelMap`. Returns the new instruction list and a map from old to new instructions, and
- * a boolean indicating if the instruction list contains an instantiation of a serializable SAM
- * type.
+ * a list of lambda implementation methods references by invokedynamic[LambdaMetafactory] for a
+ * serializable SAM types.
*/
def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], List[Handle]) = {
val javaLabelMap = labelMap.asJava
diff --git a/src/library/scala/runtime/LambdaDeserialize.java b/src/library/scala/runtime/LambdaDeserialize.java
index a3df868517..4c5198cc48 100644
--- a/src/library/scala/runtime/LambdaDeserialize.java
+++ b/src/library/scala/runtime/LambdaDeserialize.java
@@ -2,14 +2,10 @@ package scala.runtime;
import java.lang.invoke.*;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
import java.util.HashMap;
public final class LambdaDeserialize {
public static final MethodType DESERIALIZE_LAMBDA_MT = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", LambdaDeserialize.class.getClassLoader());
- public static final MethodType ADD_TARGET_METHODS_MT = MethodType.fromMethodDescriptorString("([Ljava/lang/invoke/MethodHandle;)V", LambdaDeserialize.class.getClassLoader());
private MethodHandles.Lookup lookup;
private final HashMap<String, MethodHandle> cache = new HashMap<>();
@@ -37,6 +33,6 @@ public final class LambdaDeserialize {
return new ConstantCallSite(exact);
}
public static String nameAndDescriptorKey(String name, String descriptor) {
- return name + " " + descriptor;
+ return name + descriptor;
}
}
diff --git a/src/library/scala/runtime/LambdaDeserializer.scala b/src/library/scala/runtime/LambdaDeserializer.scala
index eb168fe445..e120f0e308 100644
--- a/src/library/scala/runtime/LambdaDeserializer.scala
+++ b/src/library/scala/runtime/LambdaDeserializer.scala
@@ -33,6 +33,7 @@ object LambdaDeserializer {
*/
def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle],
targetMethodMap: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
+ assert(targetMethodMap != null)
def slashDot(name: String) = name.replaceAll("/", ".")
val loader = lookup.lookupClass().getClassLoader
val implClass = loader.loadClass(slashDot(serialized.getImplClass))
@@ -71,14 +72,10 @@ object LambdaDeserializer {
// Lookup the implementation method
val implMethod: MethodHandle = try {
- if (targetMethodMap != null) {
- if (targetMethodMap.containsKey(key)) {
- targetMethodMap.get(key)
- } else {
- throw new IllegalArgumentException("Illegal lambda deserialization")
- }
+ if (targetMethodMap.containsKey(key)) {
+ targetMethodMap.get(key)
} else {
- findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
+ throw new IllegalArgumentException("Illegal lambda deserialization")
}
} catch {
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
@@ -124,18 +121,4 @@ object LambdaDeserializer {
// is cleaner if we uniformly add a single marker, so I'm leaving it in place.
"java.io.Serializable"
}
-
- private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_],
- name: String, signature: MethodType): MethodHandle = {
- kind match {
- case MethodHandleInfo.REF_invokeStatic =>
- lookup.findStatic(owner, name, signature)
- case MethodHandleInfo.REF_newInvokeSpecial =>
- lookup.findConstructor(owner, signature)
- case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface =>
- lookup.findVirtual(owner, name, signature)
- case MethodHandleInfo.REF_invokeSpecial =>
- lookup.findSpecial(owner, name, signature, owner)
- }
- }
}
diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java
index ba52e979cc..3ed1ae1365 100644
--- a/test/junit/scala/runtime/LambdaDeserializerTest.java
+++ b/test/junit/scala/runtime/LambdaDeserializerTest.java
@@ -4,9 +4,7 @@ import org.junit.Assert;
import org.junit.Test;
import java.io.Serializable;
-import java.lang.invoke.MethodHandle;
-import java.lang.invoke.MethodHandles;
-import java.lang.invoke.SerializedLambda;
+import java.lang.invoke.*;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
@@ -85,19 +83,20 @@ public final class LambdaDeserializerTest {
public void implMethodNameChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
- checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
+ checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
}
@Test
public void implMethodSignatureChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
- checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
+ checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
}
- private void checkIllegalAccess(SerializedLambda serialized) {
+ private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) {
try {
- LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, null, serialized);
+ HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed);
+ LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested);
throw new AssertionError();
} catch (IllegalArgumentException iae) {
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
@@ -123,6 +122,7 @@ public final class LambdaDeserializerTest {
throw new RuntimeException(e);
}
}
+
private <A, B> A reconstitute(A f1) {
return reconstitute(f1, null);
}
@@ -130,12 +130,56 @@ public final class LambdaDeserializerTest {
@SuppressWarnings("unchecked")
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
try {
- return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, null, writeReplace(f1));
+ return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
+ private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) {
+ SerializedLambda serialized = writeReplace(f1);
+ HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized);
+ return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized);
+ }
+
+ private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) {
+ Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader());
+ MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader()));
+ HashMap<String, MethodHandle> allowed = new HashMap<>();
+ allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod);
+ return allowed;
+ }
+
+ private Class<?> classForName(String className, ClassLoader classLoader) {
+ try {
+ return Class.forName(className, true, classLoader);
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner,
+ String name, MethodType signature) {
+ try {
+ switch (kind) {
+ case MethodHandleInfo.REF_invokeStatic:
+ return lookup.findStatic(owner, name, signature);
+ case MethodHandleInfo.REF_newInvokeSpecial:
+ return lookup.findConstructor(owner, signature);
+ case MethodHandleInfo.REF_invokeVirtual:
+ case MethodHandleInfo.REF_invokeInterface:
+ return lookup.findVirtual(owner, name, signature);
+ case MethodHandleInfo.REF_invokeSpecial:
+ return lookup.findSpecial(owner, name, signature, owner);
+ default:
+ throw new IllegalArgumentException();
+ }
+ } catch (NoSuchMethodException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+
private <A> SerializedLambda writeReplace(A f1) {
try {
Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
@@ -189,5 +233,7 @@ class LambdaHost {
}
interface I {
- default String i() { return "i"; };
+ default String i() {
+ return "i";
+ }
}