diff options
Diffstat (limited to 'test/junit/scala/runtime/LambdaDeserializerTest.java')
-rw-r--r-- | test/junit/scala/runtime/LambdaDeserializerTest.java | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java new file mode 100644 index 0000000000..4e9c5c8954 --- /dev/null +++ b/test/junit/scala/runtime/LambdaDeserializerTest.java @@ -0,0 +1,240 @@ +package scala.runtime; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.invoke.*; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; + +public final class LambdaDeserializerTest { + private LambdaHost lambdaHost = new LambdaHost(); + + @Test + public void serializationPrivate() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationStatic() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationVirtualMethodReference() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationInterfaceMethodReference() { + F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference(); + I i = new I() { + }; + Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i)); + } + + @Test + public void serializationStaticMethodReference() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationNewInvokeSpecial() { + F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall(); + Assert.assertEquals(f1.apply(), reconstitute(f1).apply()); + } + + @Test + public void uncached() { + F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0<Object> reconstituted1 = reconstitute(f1); + F0<Object> reconstituted2 = reconstitute(f1); + Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cached() { + HashMap<String, MethodHandle> cache = new HashMap<>(); + F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0<Object> reconstituted1 = reconstitute(f1, cache); + F0<Object> reconstituted2 = reconstitute(f1, cache); + Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cachedStatic() { + HashMap<String, MethodHandle> cache = new HashMap<>(); + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + // Check that deserialization of a static lambda always returns the + // same instance. + Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache)); + + // (as is the case with regular invocation.) + Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod()); + } + + @Test + public void implMethodNameChanged() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + } + + @Test + public void implMethodSignatureChanged() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + } + + private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) { + try { + HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed); + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested); + throw new AssertionError(); + } catch (IllegalArgumentException iae) { + if (!iae.getMessage().contains("Illegal lambda deserialization")) { + Assert.fail("Unexpected message: " + iae.getMessage()); + } + } + } + + private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { + Object[] captures = new Object[sl.getCapturedArgCount()]; + for (int i = 0; i < captures.length; i++) { + captures[i] = sl.getCapturedArg(i); + } + return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), + sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, + sl.getInstantiatedMethodType(), captures); + } + + private Class<?> loadClass(String className) { + try { + return Class.forName(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private <A, B> A reconstitute(A f1) { + return reconstitute(f1, null); + } + + @SuppressWarnings("unchecked") + private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) { + try { + return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) { + SerializedLambda serialized = writeReplace(f1); + HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized); + return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized); + } + + private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) { + Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader()); + MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader())); + HashMap<String, MethodHandle> allowed = new HashMap<>(); + allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod); + return allowed; + } + + private Class<?> classForName(String className, ClassLoader classLoader) { + try { + return Class.forName(className, true, classLoader); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner, + String name, MethodType signature) { + try { + switch (kind) { + case MethodHandleInfo.REF_invokeStatic: + return lookup.findStatic(owner, name, signature); + case MethodHandleInfo.REF_newInvokeSpecial: + return lookup.findConstructor(owner, signature); + case MethodHandleInfo.REF_invokeVirtual: + case MethodHandleInfo.REF_invokeInterface: + return lookup.findVirtual(owner, name, signature); + case MethodHandleInfo.REF_invokeSpecial: + return lookup.findSpecial(owner, name, signature, owner); + default: + throw new IllegalArgumentException(); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + + private <A> SerializedLambda writeReplace(A f1) { + try { + Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); + writeReplace.setAccessible(true); + return (SerializedLambda) writeReplace.invoke(f1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} + + +interface F1<A, B> extends Serializable { + B apply(A a); +} + +interface F0<A> extends Serializable { + A apply(); +} + +class LambdaHost { + public F1<Boolean, String> lambdaBackedByPrivateImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + @SuppressWarnings("Convert2MethodRef") + public F1<Boolean, String> lambdaBackedByStaticImplMethod() { + return (b) -> String.valueOf(b); + } + + public F1<Boolean, String> lambdaBackedByStaticMethodReference() { + return String::valueOf; + } + + public F1<Boolean, String> lambdaBackedByVirtualMethodReference() { + return Object::toString; + } + + public F1<I, Object> lambdaBackedByInterfaceMethodReference() { + return I::i; + } + + public F0<Object> lambdaBackedByConstructorCall() { + return String::new; + } + + public static MethodHandles.Lookup lookup() { + return MethodHandles.lookup(); + } +} + +interface I { + default String i() { + return "i"; + } +} |