package scala.runtime; import org.junit.Assert; import org.junit.Test; import java.io.Serializable; import java.lang.invoke.*; import java.lang.reflect.Method; import java.util.Arrays; import java.util.HashMap; public final class LambdaDeserializerTest { private LambdaHost lambdaHost = new LambdaHost(); @Test public void serializationPrivate() { F1 f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); } @Test public void serializationStatic() { F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); } @Test public void serializationVirtualMethodReference() { F1 f1 = lambdaHost.lambdaBackedByVirtualMethodReference(); Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); } @Test public void serializationInterfaceMethodReference() { F1 f1 = lambdaHost.lambdaBackedByInterfaceMethodReference(); I i = new I() { }; Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i)); } @Test public void serializationStaticMethodReference() { F1 f1 = lambdaHost.lambdaBackedByStaticMethodReference(); Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); } @Test public void serializationNewInvokeSpecial() { F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); Assert.assertEquals(f1.apply(), reconstitute(f1).apply()); } @Test public void uncached() { F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); F0 reconstituted1 = reconstitute(f1); F0 reconstituted2 = reconstitute(f1); Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass()); } @Test public void cached() { HashMap cache = new HashMap<>(); F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); F0 reconstituted1 = reconstitute(f1, cache); F0 reconstituted2 = reconstitute(f1, cache); Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass()); } @Test public void cachedStatic() { HashMap cache = new HashMap<>(); F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); // Check that deserialization of a static lambda always returns the // same instance. Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache)); // (as is the case with regular invocation.) Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod()); } @Test public void implMethodNameChanged() { F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); SerializedLambda sl = writeReplace(f1); checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); } @Test public void implMethodSignatureChanged() { F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); SerializedLambda sl = writeReplace(f1); checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); } private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) { try { HashMap allowedMap = createAllowedMap(LambdaHost.lookup(), allowed); LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested); throw new AssertionError(); } catch (IllegalArgumentException iae) { if (!iae.getMessage().contains("Illegal lambda deserialization")) { Assert.fail("Unexpected message: " + iae.getMessage()); } } } private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { Object[] captures = new Object[sl.getCapturedArgCount()]; for (int i = 0; i < captures.length; i++) { captures[i] = sl.getCapturedArg(i); } return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, sl.getInstantiatedMethodType(), captures); } private Class loadClass(String className) { try { return Class.forName(className.replace('/', '.')); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } } private A reconstitute(A f1) { return reconstitute(f1, null); } @SuppressWarnings("unchecked") private A reconstitute(A f1, java.util.HashMap cache) { try { return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup()); } catch (Exception e) { throw new RuntimeException(e); } } @SuppressWarnings("unchecked") private A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap cache, MethodHandles.Lookup lookup) { SerializedLambda serialized = writeReplace(f1); HashMap allowed = createAllowedMap(lookup, serialized); return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized); } private HashMap createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) { Class implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader()); MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader())); HashMap allowed = new HashMap<>(); allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod); return allowed; } private Class classForName(String className, ClassLoader classLoader) { try { return Class.forName(className, true, classLoader); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } } private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class owner, String name, MethodType signature) { try { switch (kind) { case MethodHandleInfo.REF_invokeStatic: return lookup.findStatic(owner, name, signature); case MethodHandleInfo.REF_newInvokeSpecial: return lookup.findConstructor(owner, signature); case MethodHandleInfo.REF_invokeVirtual: case MethodHandleInfo.REF_invokeInterface: return lookup.findVirtual(owner, name, signature); case MethodHandleInfo.REF_invokeSpecial: return lookup.findSpecial(owner, name, signature, owner); default: throw new IllegalArgumentException(); } } catch (NoSuchMethodException | IllegalAccessException e) { throw new RuntimeException(e); } } private SerializedLambda writeReplace(A f1) { try { Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); writeReplace.setAccessible(true); return (SerializedLambda) writeReplace.invoke(f1); } catch (Exception e) { throw new RuntimeException(e); } } } interface F1 extends Serializable { B apply(A a); } interface F0 extends Serializable { A apply(); } class LambdaHost { public F1 lambdaBackedByPrivateImplMethod() { int local = 42; return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); } @SuppressWarnings("Convert2MethodRef") public F1 lambdaBackedByStaticImplMethod() { return (b) -> String.valueOf(b); } public F1 lambdaBackedByStaticMethodReference() { return String::valueOf; } public F1 lambdaBackedByVirtualMethodReference() { return Object::toString; } public F1 lambdaBackedByInterfaceMethodReference() { return I::i; } public F0 lambdaBackedByConstructorCall() { return String::new; } public static MethodHandles.Lookup lookup() { return MethodHandles.lookup(); } } interface I { default String i() { return "i"; } }