diff options
author | Adriaan Moors <adriaan@lightbend.com> | 2016-08-12 16:24:47 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-12 16:24:47 -0700 |
commit | 3e0b2c2b14bdc26a40887af7a375077565f004b3 (patch) | |
tree | 9886fbcfc6edc3ec069fdf2994cfc1694e4640c2 /test | |
parent | 618d42c747955a43557655bdc0c4281fec5a7923 (diff) | |
parent | 131402fd5fe8c064ef5cfffbe568507cbdf37990 (diff) | |
download | scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.tar.gz scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.tar.bz2 scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.zip |
Merge pull request #5321 from retronym/topic/lock-down-deserialize
SD-193 Lock down lambda deserialization
Diffstat (limited to 'test')
-rw-r--r-- | test/files/run/lambda-serialization-security.scala | 47 | ||||
-rw-r--r-- | test/files/run/lambda-serialization.scala | 71 | ||||
-rw-r--r-- | test/junit/scala/runtime/LambdaDeserializerTest.java | 64 |
3 files changed, 145 insertions, 37 deletions
diff --git a/test/files/run/lambda-serialization-security.scala b/test/files/run/lambda-serialization-security.scala new file mode 100644 index 0000000000..08e235b1cb --- /dev/null +++ b/test/files/run/lambda-serialization-security.scala @@ -0,0 +1,47 @@ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream} + +trait IntToString extends java.io.Serializable { def apply(i: Int): String } + +object Test { + def main(args: Array[String]): Unit = { + roundTrip() + roundTripIndySam() + } + + def roundTrip(): Unit = { + val c = new Capture("Capture") + val lambda = (p: Param) => ("a", p, c) + val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any] + val p = new Param + assert(reconstituted1.apply(p) == ("a", p, c)) + val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any] + assert(reconstituted1.getClass == reconstituted2.getClass) + + val reconstituted3 = serializeDeserialize(reconstituted1) + assert(reconstituted3.apply(p) == ("a", p, c)) + + val specializedLambda = (p: Int) => List(p, c).length + assert(serializeDeserialize(specializedLambda).apply(42) == 2) + assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2) + } + + // lambda targeting a SAM, not a FunctionN (should behave the same way) + def roundTripIndySam(): Unit = { + val lambda: IntToString = (x: Int) => "yo!" * x + val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString] + val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString] + assert(reconstituted1.apply(2) == "yo!yo!") + assert(reconstituted1.getClass == reconstituted2.getClass) + } + + def serializeDeserialize[T <: AnyRef](obj: T) = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)) + in.readObject.asInstanceOf[T] + } +} + +case class Capture(s: String) extends Serializable +class Param diff --git a/test/files/run/lambda-serialization.scala b/test/files/run/lambda-serialization.scala index 08e235b1cb..78b4c5d58b 100644 --- a/test/files/run/lambda-serialization.scala +++ b/test/files/run/lambda-serialization.scala @@ -1,37 +1,54 @@ -import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import java.lang.invoke.{MethodHandleInfo, SerializedLambda} + +import scala.tools.nsc.util + +class C extends java.io.Serializable { + val fs = List( + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => () + ) + private def foo(): Unit = { + assert(false, "should not be called!!!") + } +} -trait IntToString extends java.io.Serializable { def apply(i: Int): String } +trait FakeSam { def apply(): Unit } object Test { def main(args: Array[String]): Unit = { - roundTrip() - roundTripIndySam() + allRealLambdasRoundTrip() + fakeLambdaFailsToDeserialize() } - def roundTrip(): Unit = { - val c = new Capture("Capture") - val lambda = (p: Param) => ("a", p, c) - val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any] - val p = new Param - assert(reconstituted1.apply(p) == ("a", p, c)) - val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any] - assert(reconstituted1.getClass == reconstituted2.getClass) - - val reconstituted3 = serializeDeserialize(reconstituted1) - assert(reconstituted3.apply(p) == ("a", p, c)) - - val specializedLambda = (p: Int) => List(p, c).length - assert(serializeDeserialize(specializedLambda).apply(42) == 2) - assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2) + def allRealLambdasRoundTrip(): Unit = { + new C().fs.map(x => serializeDeserialize(x).apply()) } - // lambda targeting a SAM, not a FunctionN (should behave the same way) - def roundTripIndySam(): Unit = { - val lambda: IntToString = (x: Int) => "yo!" * x - val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString] - val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString] - assert(reconstituted1.apply(2) == "yo!yo!") - assert(reconstituted1.getClass == reconstituted2.getClass) + def fakeLambdaFailsToDeserialize(): Unit = { + val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V", + MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C)) + try { + serializeDeserialize(fake).asInstanceOf[FakeSam].apply() + assert(false) + } catch { + case ex: Exception => + val stackTrace = util.stackTraceString(ex) + assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace) + } } def serializeDeserialize[T <: AnyRef](obj: T) = { @@ -43,5 +60,3 @@ object Test { } } -case class Capture(s: String) extends Serializable -class Param diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java index 069eb4aab6..3ed1ae1365 100644 --- a/test/junit/scala/runtime/LambdaDeserializerTest.java +++ b/test/junit/scala/runtime/LambdaDeserializerTest.java @@ -4,9 +4,7 @@ import org.junit.Assert; import org.junit.Test; import java.io.Serializable; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.SerializedLambda; +import java.lang.invoke.*; import java.lang.reflect.Method; import java.util.Arrays; import java.util.HashMap; @@ -85,19 +83,20 @@ public final class LambdaDeserializerTest { public void implMethodNameChanged() { F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); SerializedLambda sl = writeReplace(f1); - checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); } @Test public void implMethodSignatureChanged() { F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod(); SerializedLambda sl = writeReplace(f1); - checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); } - private void checkIllegalAccess(SerializedLambda serialized) { + private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) { try { - LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized); + HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed); + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested); throw new AssertionError(); } catch (IllegalArgumentException iae) { if (!iae.getMessage().contains("Illegal lambda deserialization")) { @@ -123,6 +122,7 @@ public final class LambdaDeserializerTest { throw new RuntimeException(e); } } + private <A, B> A reconstitute(A f1) { return reconstitute(f1, null); } @@ -130,12 +130,56 @@ public final class LambdaDeserializerTest { @SuppressWarnings("unchecked") private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) { try { - return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1)); + return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup()); } catch (Exception e) { throw new RuntimeException(e); } } + private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) { + SerializedLambda serialized = writeReplace(f1); + HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized); + return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized); + } + + private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) { + Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader()); + MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader())); + HashMap<String, MethodHandle> allowed = new HashMap<>(); + allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod); + return allowed; + } + + private Class<?> classForName(String className, ClassLoader classLoader) { + try { + return Class.forName(className, true, classLoader); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + + private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner, + String name, MethodType signature) { + try { + switch (kind) { + case MethodHandleInfo.REF_invokeStatic: + return lookup.findStatic(owner, name, signature); + case MethodHandleInfo.REF_newInvokeSpecial: + return lookup.findConstructor(owner, signature); + case MethodHandleInfo.REF_invokeVirtual: + case MethodHandleInfo.REF_invokeInterface: + return lookup.findVirtual(owner, name, signature); + case MethodHandleInfo.REF_invokeSpecial: + return lookup.findSpecial(owner, name, signature, owner); + default: + throw new IllegalArgumentException(); + } + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private <A> SerializedLambda writeReplace(A f1) { try { Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); @@ -189,5 +233,7 @@ class LambdaHost { } interface I { - default String i() { return "i"; }; + default String i() { + return "i"; + } } |