summaryrefslogblamecommitdiff
path: root/test/junit/scala/runtime/LambdaDeserializerTest.java
blob: 4e9c5c895488071412dc2e00b4e0bd22d81022c4 (plain) (tree)
1
2
3
4
5
6
7





                            
                          













































































                                                                                     
                                                                                                                      





                                                                             
                                                                                                                                            

     
                                                                                           
             

                                                                                                      
























                                                                                                                                               
 






                                                                                        
                                                                                        




                                          
                                  











































                                                                                                                                                                                                                                                   




















































                                                                                             


                        
 
package scala.runtime;

import org.junit.Assert;
import org.junit.Test;

import java.io.Serializable;
import java.lang.invoke.*;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;

public final class LambdaDeserializerTest {
    private LambdaHost lambdaHost = new LambdaHost();

    @Test
    public void serializationPrivate() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationStatic() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationVirtualMethodReference() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationInterfaceMethodReference() {
        F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference();
        I i = new I() {
        };
        Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i));
    }

    @Test
    public void serializationStaticMethodReference() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference();
        Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
    }

    @Test
    public void serializationNewInvokeSpecial() {
        F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
        Assert.assertEquals(f1.apply(), reconstitute(f1).apply());
    }

    @Test
    public void uncached() {
        F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
        F0<Object> reconstituted1 = reconstitute(f1);
        F0<Object> reconstituted2 = reconstitute(f1);
        Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass());
    }

    @Test
    public void cached() {
        HashMap<String, MethodHandle> cache = new HashMap<>();
        F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
        F0<Object> reconstituted1 = reconstitute(f1, cache);
        F0<Object> reconstituted2 = reconstitute(f1, cache);
        Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass());
    }

    @Test
    public void cachedStatic() {
        HashMap<String, MethodHandle> cache = new HashMap<>();
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        // Check that deserialization of a static lambda always returns the
        // same instance.
        Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache));

        // (as is the case with regular invocation.)
        Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod());
    }

    @Test
    public void implMethodNameChanged() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        SerializedLambda sl = writeReplace(f1);
        checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
    }

    @Test
    public void implMethodSignatureChanged() {
        F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
        SerializedLambda sl = writeReplace(f1);
        checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
    }

    private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) {
        try {
            HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed);
            LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested);
            throw new AssertionError();
        } catch (IllegalArgumentException iae) {
            if (!iae.getMessage().contains("Illegal lambda deserialization")) {
                Assert.fail("Unexpected message: " + iae.getMessage());
            }
        }
    }

    private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
        Object[] captures = new Object[sl.getCapturedArgCount()];
        for (int i = 0; i < captures.length; i++) {
            captures[i] = sl.getCapturedArg(i);
        }
        return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
                sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
                sl.getInstantiatedMethodType(), captures);
    }

    private Class<?> loadClass(String className) {
        try {
            return Class.forName(className.replace('/', '.'));
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    private <A, B> A reconstitute(A f1) {
        return reconstitute(f1, null);
    }

    @SuppressWarnings("unchecked")
    private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
        try {
            return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @SuppressWarnings("unchecked")
    private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) {
        SerializedLambda serialized = writeReplace(f1);
        HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized);
        return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized);
    }

    private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) {
        Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader());
        MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader()));
        HashMap<String, MethodHandle> allowed = new HashMap<>();
        allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod);
        return allowed;
    }

    private Class<?> classForName(String className, ClassLoader classLoader) {
        try {
            return Class.forName(className, true, classLoader);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner,
                                    String name, MethodType signature) {
        try {
            switch (kind) {
                case MethodHandleInfo.REF_invokeStatic:
                    return lookup.findStatic(owner, name, signature);
                case MethodHandleInfo.REF_newInvokeSpecial:
                    return lookup.findConstructor(owner, signature);
                case MethodHandleInfo.REF_invokeVirtual:
                case MethodHandleInfo.REF_invokeInterface:
                    return lookup.findVirtual(owner, name, signature);
                case MethodHandleInfo.REF_invokeSpecial:
                    return lookup.findSpecial(owner, name, signature, owner);
                default:
                    throw new IllegalArgumentException();
            }
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }


    private <A> SerializedLambda writeReplace(A f1) {
        try {
            Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(f1);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}


interface F1<A, B> extends Serializable {
    B apply(A a);
}

interface F0<A> extends Serializable {
    A apply();
}

class LambdaHost {
    public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
        int local = 42;
        return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
    }

    @SuppressWarnings("Convert2MethodRef")
    public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
        return (b) -> String.valueOf(b);
    }

    public F1<Boolean, String> lambdaBackedByStaticMethodReference() {
        return String::valueOf;
    }

    public F1<Boolean, String> lambdaBackedByVirtualMethodReference() {
        return Object::toString;
    }

    public F1<I, Object> lambdaBackedByInterfaceMethodReference() {
        return I::i;
    }

    public F0<Object> lambdaBackedByConstructorCall() {
        return String::new;
    }

    public static MethodHandles.Lookup lookup() {
        return MethodHandles.lookup();
    }
}

interface I {
    default String i() {
        return "i";
    }
}