summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAdriaan Moors <adriaan@lightbend.com>2016-08-12 16:24:47 -0700
committerGitHub <noreply@github.com>2016-08-12 16:24:47 -0700
commit3e0b2c2b14bdc26a40887af7a375077565f004b3 (patch)
tree9886fbcfc6edc3ec069fdf2994cfc1694e4640c2 /test
parent618d42c747955a43557655bdc0c4281fec5a7923 (diff)
parent131402fd5fe8c064ef5cfffbe568507cbdf37990 (diff)
downloadscala-3e0b2c2b14bdc26a40887af7a375077565f004b3.tar.gz
scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.tar.bz2
scala-3e0b2c2b14bdc26a40887af7a375077565f004b3.zip
Merge pull request #5321 from retronym/topic/lock-down-deserialize
SD-193 Lock down lambda deserialization
Diffstat (limited to 'test')
-rw-r--r--test/files/run/lambda-serialization-security.scala47
-rw-r--r--test/files/run/lambda-serialization.scala71
-rw-r--r--test/junit/scala/runtime/LambdaDeserializerTest.java64
3 files changed, 145 insertions, 37 deletions
diff --git a/test/files/run/lambda-serialization-security.scala b/test/files/run/lambda-serialization-security.scala
new file mode 100644
index 0000000000..08e235b1cb
--- /dev/null
+++ b/test/files/run/lambda-serialization-security.scala
@@ -0,0 +1,47 @@
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
+
+trait IntToString extends java.io.Serializable { def apply(i: Int): String }
+
+object Test {
+ def main(args: Array[String]): Unit = {
+ roundTrip()
+ roundTripIndySam()
+ }
+
+ def roundTrip(): Unit = {
+ val c = new Capture("Capture")
+ val lambda = (p: Param) => ("a", p, c)
+ val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
+ val p = new Param
+ assert(reconstituted1.apply(p) == ("a", p, c))
+ val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
+ assert(reconstituted1.getClass == reconstituted2.getClass)
+
+ val reconstituted3 = serializeDeserialize(reconstituted1)
+ assert(reconstituted3.apply(p) == ("a", p, c))
+
+ val specializedLambda = (p: Int) => List(p, c).length
+ assert(serializeDeserialize(specializedLambda).apply(42) == 2)
+ assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
+ }
+
+ // lambda targeting a SAM, not a FunctionN (should behave the same way)
+ def roundTripIndySam(): Unit = {
+ val lambda: IntToString = (x: Int) => "yo!" * x
+ val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
+ val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
+ assert(reconstituted1.apply(2) == "yo!yo!")
+ assert(reconstituted1.getClass == reconstituted2.getClass)
+ }
+
+ def serializeDeserialize[T <: AnyRef](obj: T) = {
+ val buffer = new ByteArrayOutputStream
+ val out = new ObjectOutputStream(buffer)
+ out.writeObject(obj)
+ val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
+ in.readObject.asInstanceOf[T]
+ }
+}
+
+case class Capture(s: String) extends Serializable
+class Param
diff --git a/test/files/run/lambda-serialization.scala b/test/files/run/lambda-serialization.scala
index 08e235b1cb..78b4c5d58b 100644
--- a/test/files/run/lambda-serialization.scala
+++ b/test/files/run/lambda-serialization.scala
@@ -1,37 +1,54 @@
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
+
+import scala.tools.nsc.util
+
+class C extends java.io.Serializable {
+ val fs = List(
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => ()
+ )
+ private def foo(): Unit = {
+ assert(false, "should not be called!!!")
+ }
+}
-trait IntToString extends java.io.Serializable { def apply(i: Int): String }
+trait FakeSam { def apply(): Unit }
object Test {
def main(args: Array[String]): Unit = {
- roundTrip()
- roundTripIndySam()
+ allRealLambdasRoundTrip()
+ fakeLambdaFailsToDeserialize()
}
- def roundTrip(): Unit = {
- val c = new Capture("Capture")
- val lambda = (p: Param) => ("a", p, c)
- val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
- val p = new Param
- assert(reconstituted1.apply(p) == ("a", p, c))
- val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
- assert(reconstituted1.getClass == reconstituted2.getClass)
-
- val reconstituted3 = serializeDeserialize(reconstituted1)
- assert(reconstituted3.apply(p) == ("a", p, c))
-
- val specializedLambda = (p: Int) => List(p, c).length
- assert(serializeDeserialize(specializedLambda).apply(42) == 2)
- assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
+ def allRealLambdasRoundTrip(): Unit = {
+ new C().fs.map(x => serializeDeserialize(x).apply())
}
- // lambda targeting a SAM, not a FunctionN (should behave the same way)
- def roundTripIndySam(): Unit = {
- val lambda: IntToString = (x: Int) => "yo!" * x
- val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
- val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
- assert(reconstituted1.apply(2) == "yo!yo!")
- assert(reconstituted1.getClass == reconstituted2.getClass)
+ def fakeLambdaFailsToDeserialize(): Unit = {
+ val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V",
+ MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C))
+ try {
+ serializeDeserialize(fake).asInstanceOf[FakeSam].apply()
+ assert(false)
+ } catch {
+ case ex: Exception =>
+ val stackTrace = util.stackTraceString(ex)
+ assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace)
+ }
}
def serializeDeserialize[T <: AnyRef](obj: T) = {
@@ -43,5 +60,3 @@ object Test {
}
}
-case class Capture(s: String) extends Serializable
-class Param
diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java
index 069eb4aab6..3ed1ae1365 100644
--- a/test/junit/scala/runtime/LambdaDeserializerTest.java
+++ b/test/junit/scala/runtime/LambdaDeserializerTest.java
@@ -4,9 +4,7 @@ import org.junit.Assert;
import org.junit.Test;
import java.io.Serializable;
-import java.lang.invoke.MethodHandle;
-import java.lang.invoke.MethodHandles;
-import java.lang.invoke.SerializedLambda;
+import java.lang.invoke.*;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
@@ -85,19 +83,20 @@ public final class LambdaDeserializerTest {
public void implMethodNameChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
- checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
+ checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
}
@Test
public void implMethodSignatureChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
- checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
+ checkIllegalAccess(sl, copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
}
- private void checkIllegalAccess(SerializedLambda serialized) {
+ private void checkIllegalAccess(SerializedLambda allowed, SerializedLambda requested) {
try {
- LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
+ HashMap<String, MethodHandle> allowedMap = createAllowedMap(LambdaHost.lookup(), allowed);
+ LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, allowedMap, requested);
throw new AssertionError();
} catch (IllegalArgumentException iae) {
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
@@ -123,6 +122,7 @@ public final class LambdaDeserializerTest {
throw new RuntimeException(e);
}
}
+
private <A, B> A reconstitute(A f1) {
return reconstitute(f1, null);
}
@@ -130,12 +130,56 @@ public final class LambdaDeserializerTest {
@SuppressWarnings("unchecked")
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
try {
- return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
+ return deserizalizeLambdaCreatingAllowedMap(f1, cache, LambdaHost.lookup());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
+ private <A> A deserizalizeLambdaCreatingAllowedMap(A f1, HashMap<String, MethodHandle> cache, MethodHandles.Lookup lookup) {
+ SerializedLambda serialized = writeReplace(f1);
+ HashMap<String, MethodHandle> allowed = createAllowedMap(lookup, serialized);
+ return (A) LambdaDeserializer.deserializeLambda(lookup, cache, allowed, serialized);
+ }
+
+ private HashMap<String, MethodHandle> createAllowedMap(MethodHandles.Lookup lookup, SerializedLambda serialized) {
+ Class<?> implClass = classForName(serialized.getImplClass().replace("/", "."), lookup.lookupClass().getClassLoader());
+ MethodHandle implMethod = findMember(lookup, serialized.getImplMethodKind(), implClass, serialized.getImplMethodName(), MethodType.fromMethodDescriptorString(serialized.getImplMethodSignature(), lookup.lookupClass().getClassLoader()));
+ HashMap<String, MethodHandle> allowed = new HashMap<>();
+ allowed.put(LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName(), serialized.getImplMethodSignature()), implMethod);
+ return allowed;
+ }
+
+ private Class<?> classForName(String className, ClassLoader classLoader) {
+ try {
+ return Class.forName(className, true, classLoader);
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private MethodHandle findMember(MethodHandles.Lookup lookup, int kind, Class<?> owner,
+ String name, MethodType signature) {
+ try {
+ switch (kind) {
+ case MethodHandleInfo.REF_invokeStatic:
+ return lookup.findStatic(owner, name, signature);
+ case MethodHandleInfo.REF_newInvokeSpecial:
+ return lookup.findConstructor(owner, signature);
+ case MethodHandleInfo.REF_invokeVirtual:
+ case MethodHandleInfo.REF_invokeInterface:
+ return lookup.findVirtual(owner, name, signature);
+ case MethodHandleInfo.REF_invokeSpecial:
+ return lookup.findSpecial(owner, name, signature, owner);
+ default:
+ throw new IllegalArgumentException();
+ }
+ } catch (NoSuchMethodException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+
private <A> SerializedLambda writeReplace(A f1) {
try {
Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
@@ -189,5 +233,7 @@ class LambdaHost {
}
interface I {
- default String i() { return "i"; };
+ default String i() {
+ return "i";
+ }
}