diff options
Diffstat (limited to 'test/junit/scala/runtime')
-rw-r--r-- | test/junit/scala/runtime/LambdaDeserializerTest.java | 240 | ||||
-rw-r--r-- | test/junit/scala/runtime/ScalaRunTimeTest.scala | 65 | ||||
-rw-r--r-- | test/junit/scala/runtime/ZippedTest.scala | 68 |
3 files changed, 312 insertions, 61 deletions
diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java new file mode 100644 index 0000000000..4e9c5c8954 --- /dev/null +++ b/test/junit/scala/runtime/LambdaDeserializerTest.java @@ -0,0 +1,240 @@ +package scala.runtime; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.invoke.*; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; + +public final class LambdaDeserializerTest { + private LambdaHost lambdaHost = new LambdaHost(); + + @Test + public void serializationPrivate() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationStatic() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationVirtualMethodReference() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationInterfaceMethodReference() { + F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference(); + I i = new I() { + }; + Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i)); + } + + @Test + public void serializationStaticMethodReference() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationNewInvokeSpecial() { + F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall(); + Assert.assertEquals(f1.apply(), reconstitute(f1).apply()); + } + + @Test + public void uncached() { + F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0<Object> reconstituted1 = reconstitute(f1); + F0<Object> reconstituted2 = reconstitute(f1); + Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cached() { + HashMap<String, MethodHandle> cache = new HashMap<>(); + F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0<Object> reconstituted1 = reconstitute(f1, cache); + F0<Object> reconstituted2 = reconstitute(f1, cache); + Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cachedStatic() { + HashMap<String, MethodHandle> cache = new HashMap<>(); + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + // Check that deserialization of a static lambda always returns the + // same instance. + Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache)); + + // (as is the case with regular invocation.) + Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod()); + } + + @Test + public void implMethodNameChanged() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + } + + @Test + public void implMethodSignatureChanged() { + F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + } + + private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) { + try { + HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed); + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested); + throw new AssertionError(); + } catch (IllegalArgumentException iae) { + if (!iae.getMessage().contains("Illegal lambda deserialization")) { + Assert.fail("Unexpected message: " + iae.getMessage()); + } + } + } + + private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { + Object[] captures = new Object[sl.getCapturedArgCount()]; + for (int i = 0; i < captures.length; i++) { + captures[i] = sl.getCapturedArg(i); + } + return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), + sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, + sl.getInstantiatedMethodType(), captures); + } + + private Class<?> loadClass(String className) { + try { + return Class.forName(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private <A, B> A reconstitute(A f1) { + return reconstitute(f1, null); + } + + @SuppressWarnings("unchecked") + private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) { + try { + return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup()); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) { + SerializedLambda serialized = writeReplace(f1); + HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized); + return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized); + } + + private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) { + Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader()); + MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader())); + HashMap<String, MethodHandle> allowed = new HashMap<>(); + allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod); + return allowed; + } + + private Class<?> classForName(String className, ClassLoader classLoader) { + try { + return Class.forName(className, true, classLoader); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner, + String name, MethodType signature) { + try { + switch (kind) { + case MethodHandleInfo.REF_invokeStatic: + return lookup.findStatic(owner, name, signature); + case MethodHandleInfo.REF_newInvokeSpecial: + return lookup.findConstructor(owner, signature); + case MethodHandleInfo.REF_invokeVirtual: + case MethodHandleInfo.REF_invokeInterface: + return lookup.findVirtual(owner, name, signature); + case MethodHandleInfo.REF_invokeSpecial: + return lookup.findSpecial(owner, name, signature, owner); + default: + throw new IllegalArgumentException(); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + + private <A> SerializedLambda writeReplace(A f1) { + try { + Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); + writeReplace.setAccessible(true); + return (SerializedLambda) writeReplace.invoke(f1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} + + +interface F1<A, B> extends Serializable { + B apply(A a); +} + +interface F0<A> extends Serializable { + A apply(); +} + +class LambdaHost { + public F1<Boolean, String> lambdaBackedByPrivateImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + @SuppressWarnings("Convert2MethodRef") + public F1<Boolean, String> lambdaBackedByStaticImplMethod() { + return (b) -> String.valueOf(b); + } + + public F1<Boolean, String> lambdaBackedByStaticMethodReference() { + return String::valueOf; + } + + public F1<Boolean, String> lambdaBackedByVirtualMethodReference() { + return Object::toString; + } + + public F1<I, Object> lambdaBackedByInterfaceMethodReference() { + return I::i; + } + + public F0<Object> lambdaBackedByConstructorCall() { + return String::new; + } + + public static MethodHandles.Lookup lookup() { + return MethodHandles.lookup(); + } +} + +interface I { + default String i() { + return "i"; + } +} diff --git a/test/junit/scala/runtime/ScalaRunTimeTest.scala b/test/junit/scala/runtime/ScalaRunTimeTest.scala index e28deae786..ba3bf0b703 100644 --- a/test/junit/scala/runtime/ScalaRunTimeTest.scala +++ b/test/junit/scala/runtime/ScalaRunTimeTest.scala @@ -5,70 +5,10 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -/** Tests for the private class DefaultPromise */ +/** Tests for the runtime object ScalaRunTime */ @RunWith(classOf[JUnit4]) class ScalaRunTimeTest { @Test - def testIsTuple() { - import ScalaRunTime.isTuple - def check(v: Any) = { - assertTrue(v.toString, isTuple(v)) - } - - val s = "" - check(Tuple1(s)) - check((s, s)) - check((s, s, s)) - check((s, s, s, s)) - check((s, s, s, s, s)) - check((s, s, s, s, s, s)) - check((s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - check((s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s)) - - // some specialized variants will have mangled classnames - check(Tuple1(0)) - check((0, 0)) - check((0, 0, 0)) - check((0, 0, 0, 0)) - check((0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - check((0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) - - case class C() - val c = new C() - assertFalse(c.toString, isTuple(c)) - } - - @Test def testStringOf() { import ScalaRunTime.stringOf import scala.collection._ @@ -109,14 +49,17 @@ class ScalaRunTimeTest { val tuple1 = Tuple1(0) assertEquals("(0,)", stringOf(tuple1)) assertEquals("(0,)", stringOf(tuple1, 0)) + assertEquals("(Array(0),)", stringOf(Tuple1(Array(0)))) val tuple2 = Tuple2(0, 1) assertEquals("(0,1)", stringOf(tuple2)) assertEquals("(0,1)", stringOf(tuple2, 0)) + assertEquals("(Array(0),1)", stringOf((Array(0), 1))) val tuple3 = Tuple3(0, 1, 2) assertEquals("(0,1,2)", stringOf(tuple3)) assertEquals("(0,1,2)", stringOf(tuple3, 0)) + assertEquals("(Array(0),1,2)", stringOf((Array(0), 1, 2))) val x = new Object { override def toString(): String = "this is the stringOf string" diff --git a/test/junit/scala/runtime/ZippedTest.scala b/test/junit/scala/runtime/ZippedTest.scala new file mode 100644 index 0000000000..d3ce4945aa --- /dev/null +++ b/test/junit/scala/runtime/ZippedTest.scala @@ -0,0 +1,68 @@ + +package scala.runtime + +import scala.language.postfixOps + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +/** Tests Tuple?Zipped */ +@RunWith(classOf[JUnit4]) +class ZippedTest { + @Test + def crossZipped() { + + val xs1 = List.range(1, 100) + val xs2 = xs1.view + val xs3 = xs1 take 10 + val ss1 = Stream from 1 + val ss2 = ss1.view + val ss3 = ss1 take 10 + val as1 = 1 to 100 toArray + val as2 = as1.view + val as3 = as1 take 10 + + def xss1 = List[Seq[Int]](xs1, xs2, xs3, ss1, ss2, ss3, as1, as2, as3) + def xss2 = List[Seq[Int]](xs1, xs2, xs3, ss3, as1, as2, as3) // no infinities + def xss3 = List[Seq[Int]](xs2, xs3, ss3, as1) // representative sampling + + for (cc1 <- xss1 ; cc2 <- xss2) { + val sum1 = (cc1, cc2).zipped map { case (x, y) => x + y } sum + val sum2 = (cc1, cc2).zipped map (_ + _) sum + + assert(sum1 == sum2) + } + + for (cc1 <- xss1 ; cc2 <- xss2 ; cc3 <- xss3) { + val sum1 = (cc1, cc2, cc3).zipped map { case (x, y, z) => x + y + z } sum + val sum2 = (cc1, cc2, cc3).zipped map (_ + _ + _) sum + + assert(sum1 == sum2) + } + + assert((ss1, ss1).zipped exists ((x, y) => true)) + assert((ss1, ss1, ss1).zipped exists ((x, y, z) => true)) + + assert(!(ss1, ss2, 1 to 3).zipped.exists(_ + _ + _ > 100000)) + assert((1 to 3, ss1, ss2).zipped.forall(_ + _ + _ > 0)) + assert((ss1, 1 to 3, ss2).zipped.map(_ + _ + _).size == 3) + } + + @Test + def test_si9379() { + class Boom { + private var i = -1 + def inc = { + i += 1 + if (i > 1000) throw new NoSuchElementException("Boom! Too many elements!") + i + } + } + val b = new Boom + val s = Stream.continually(b.inc) + // zipped.toString must allow s to short-circuit evaluation + assertTrue((s, s).zipped.toString contains s.toString) + } +} |