diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2016-08-04 01:58:33 -0700 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2016-08-08 13:42:36 +1000 |
commit | 498a2ce7397b909c0bebf36affeb1ee5a1c03d6a (patch) | |
tree | abcfcf39af3231ade6402d30acca548704711311 /test | |
parent | 2b172be8c83c3146d3fd5ab01546c171ab18fa46 (diff) | |
download | scala-498a2ce7397b909c0bebf36affeb1ee5a1c03d6a.tar.gz scala-498a2ce7397b909c0bebf36affeb1ee5a1c03d6a.tar.bz2 scala-498a2ce7397b909c0bebf36affeb1ee5a1c03d6a.zip |
SD-193 Lock down lambda deserialization
The old design allowed a forged `SerializedLambda` to be
deserialized into a lambda that could call any private method
in the host class.
This commit passes through the list of all lambda impl methods
to the bootstrap method and verifies that you are deserializing
one of these.
The new test case shows that a forged lambda can no longer call
the private method, and that the new encoding is okay with a large
number of lambdas in a file.
We already have method handle constants in the constant pool
to support the invokedynamic through LambdaMetafactory, so
the only additional cost will be referring to these in
the boostrap args for `LambdaDeserialize`, 2 bytes per lambda.
I checked this with an example:
https://gist.github.com/retronym/e343d211f7536d06f1fef4b499a0a177
Fixes SD-193
Diffstat (limited to 'test')
-rw-r--r-- | test/files/run/lambda-serialization-security.scala | 47 | ||||
-rw-r--r-- | test/files/run/lambda-serialization.scala | 71 | ||||
-rw-r--r-- | test/junit/scala/runtime/LambdaDeserializerTest.java | 4 |
3 files changed, 92 insertions, 30 deletions
diff --git a/test/files/run/lambda-serialization-security.scala b/test/files/run/lambda-serialization-security.scala new file mode 100644 index 0000000000..08e235b1cb --- /dev/null +++ b/test/files/run/lambda-serialization-security.scala @@ -0,0 +1,47 @@ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream} + +trait IntToString extends java.io.Serializable { def apply(i: Int): String } + +object Test { + def main(args: Array[String]): Unit = { + roundTrip() + roundTripIndySam() + } + + def roundTrip(): Unit = { + val c = new Capture("Capture") + val lambda = (p: Param) => ("a", p, c) + val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any] + val p = new Param + assert(reconstituted1.apply(p) == ("a", p, c)) + val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any] + assert(reconstituted1.getClass == reconstituted2.getClass) + + val reconstituted3 = serializeDeserialize(reconstituted1) + assert(reconstituted3.apply(p) == ("a", p, c)) + + val specializedLambda = (p: Int) => List(p, c).length + assert(serializeDeserialize(specializedLambda).apply(42) == 2) + assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2) + } + + // lambda targeting a SAM, not a FunctionN (should behave the same way) + def roundTripIndySam(): Unit = { + val lambda: IntToString = (x: Int) => "yo!" * x + val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString] + val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString] + assert(reconstituted1.apply(2) == "yo!yo!") + assert(reconstituted1.getClass == reconstituted2.getClass) + } + + def serializeDeserialize[T <: AnyRef](obj: T) = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)) + in.readObject.asInstanceOf[T] + } +} + +case class Capture(s: String) extends Serializable +class Param diff --git a/test/files/run/lambda-serialization.scala b/test/files/run/lambda-serialization.scala index 08e235b1cb..78b4c5d58b 100644 --- a/test/files/run/lambda-serialization.scala +++ b/test/files/run/lambda-serialization.scala @@ -1,37 +1,54 @@ -import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import java.lang.invoke.{MethodHandleInfo, SerializedLambda} + +import scala.tools.nsc.util + +class C extends java.io.Serializable { + val fs = List( + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (), + () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => () + ) + private def foo(): Unit = { + assert(false, "should not be called!!!") + } +} -trait IntToString extends java.io.Serializable { def apply(i: Int): String } +trait FakeSam { def apply(): Unit } object Test { def main(args: Array[String]): Unit = { - roundTrip() - roundTripIndySam() + allRealLambdasRoundTrip() + fakeLambdaFailsToDeserialize() } - def roundTrip(): Unit = { - val c = new Capture("Capture") - val lambda = (p: Param) => ("a", p, c) - val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any] - val p = new Param - assert(reconstituted1.apply(p) == ("a", p, c)) - val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any] - assert(reconstituted1.getClass == reconstituted2.getClass) - - val reconstituted3 = serializeDeserialize(reconstituted1) - assert(reconstituted3.apply(p) == ("a", p, c)) - - val specializedLambda = (p: Int) => List(p, c).length - assert(serializeDeserialize(specializedLambda).apply(42) == 2) - assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2) + def allRealLambdasRoundTrip(): Unit = { + new C().fs.map(x => serializeDeserialize(x).apply()) } - // lambda targeting a SAM, not a FunctionN (should behave the same way) - def roundTripIndySam(): Unit = { - val lambda: IntToString = (x: Int) => "yo!" * x - val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString] - val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString] - assert(reconstituted1.apply(2) == "yo!yo!") - assert(reconstituted1.getClass == reconstituted2.getClass) + def fakeLambdaFailsToDeserialize(): Unit = { + val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V", + MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C)) + try { + serializeDeserialize(fake).asInstanceOf[FakeSam].apply() + assert(false) + } catch { + case ex: Exception => + val stackTrace = util.stackTraceString(ex) + assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace) + } } def serializeDeserialize[T <: AnyRef](obj: T) = { @@ -43,5 +60,3 @@ object Test { } } -case class Capture(s: String) extends Serializable -class Param diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java index 069eb4aab6..ba52e979cc 100644 --- a/test/junit/scala/runtime/LambdaDeserializerTest.java +++ b/test/junit/scala/runtime/LambdaDeserializerTest.java @@ -97,7 +97,7 @@ public final class LambdaDeserializerTest { private void checkIllegalAccess(SerializedLambda serialized) { try { - LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized); + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, null, serialized); throw new AssertionError(); } catch (IllegalArgumentException iae) { if (!iae.getMessage().contains("Illegal lambda deserialization")) { @@ -130,7 +130,7 @@ public final class LambdaDeserializerTest { @SuppressWarnings("unchecked") private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) { try { - return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1)); + return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, null, writeReplace(f1)); } catch (Exception e) { throw new RuntimeException(e); } |