summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2016-08-04 01:58:33 -0700
committerJason Zaugg <jzaugg@gmail.com>2016-08-08 13:42:36 +1000
commit498a2ce7397b909c0bebf36affeb1ee5a1c03d6a (patch)
treeabcfcf39af3231ade6402d30acca548704711311
parent2b172be8c83c3146d3fd5ab01546c171ab18fa46 (diff)
downloadscala-498a2ce7397b909c0bebf36affeb1ee5a1c03d6a.tar.gz
scala-498a2ce7397b909c0bebf36affeb1ee5a1c03d6a.tar.bz2
scala-498a2ce7397b909c0bebf36affeb1ee5a1c03d6a.zip
SD-193 Lock down lambda deserialization
The old design allowed a forged `SerializedLambda` to be deserialized into a lambda that could call any private method in the host class. This commit passes through the list of all lambda impl methods to the bootstrap method and verifies that you are deserializing one of these. The new test case shows that a forged lambda can no longer call the private method, and that the new encoding is okay with a large number of lambdas in a file. We already have method handle constants in the constant pool to support the invokedynamic through LambdaMetafactory, so the only additional cost will be referring to these in the boostrap args for `LambdaDeserialize`, 2 bytes per lambda. I checked this with an example: https://gist.github.com/retronym/e343d211f7536d06f1fef4b499a0a177 Fixes SD-193
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala3
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala8
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala11
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala16
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala3
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala17
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala7
-rw-r--r--src/library/scala/runtime/LambdaDeserialize.java25
-rw-r--r--src/library/scala/runtime/LambdaDeserializer.scala15
-rw-r--r--test/files/run/lambda-serialization-security.scala47
-rw-r--r--test/files/run/lambda-serialization.scala71
-rw-r--r--test/junit/scala/runtime/LambdaDeserializerTest.java4
12 files changed, 164 insertions, 63 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
index d5c4b5e201..6f9682f434 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
@@ -14,6 +14,7 @@ import scala.reflect.internal.Flags
import scala.tools.asm
import GenBCode._
import BackendReporting._
+import scala.collection.mutable
import scala.tools.asm.Opcodes
import scala.tools.asm.tree.{MethodInsnNode, MethodNode}
import scala.tools.nsc.backend.jvm.BCodeHelpers.{InvokeStyle, TestOp}
@@ -1349,7 +1350,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
val markers = if (addScalaSerializableMarker) classBTypeFromSymbol(definitions.SerializableClass).toASMType :: Nil else Nil
visitInvokeDynamicInsnLMF(bc.jmethod, sam.name.toString, invokedType, samMethodType, implMethodHandle, constrainedType, isSerializable, markers)
if (isSerializable)
- indyLambdaHosts += cnode.name
+ addIndyLambdaImplMethod(cnode.name, implMethodHandle :: Nil)
}
}
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
index 1bff8519ec..d4d532f4df 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
@@ -112,14 +112,6 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
gen(cd.impl)
- val shouldAddLambdaDeserialize = (
- settings.target.value == "jvm-1.8"
- && settings.Ydelambdafy.value == "method"
- && indyLambdaHosts.contains(cnode.name))
-
- if (shouldAddLambdaDeserialize)
- backendUtils.addLambdaDeserialize(cnode)
-
cnode.visitAttribute(thisBType.inlineInfoAttribute.get)
if (AsmUtils.traceClassEnabled && cnode.name.contains(AsmUtils.traceClassPattern))
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
index 7b2686e7a9..0845e440d7 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
@@ -122,7 +122,16 @@ abstract class BTypes {
* inlining: when inlining an indyLambda instruction into a class, we need to make sure the class
* has the method.
*/
- val indyLambdaHosts: mutable.Set[InternalName] = recordPerRunCache(mutable.Set.empty)
+ val indyLambdaImplMethods: mutable.AnyRefMap[InternalName, mutable.LinkedHashSet[asm.Handle]] = recordPerRunCache(mutable.AnyRefMap())
+ def addIndyLambdaImplMethod(hostClass: InternalName, handle: Seq[asm.Handle]): Unit = {
+ indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle
+ }
+ def getIndyLambdaImplMethods(hostClass: InternalName): List[asm.Handle] = {
+ indyLambdaImplMethods.getOrNull(hostClass) match {
+ case null => Nil
+ case xs => xs.toList.distinct
+ }
+ }
/**
* Obtain the BType for a type descriptor or internal name. For class descriptors, the ClassBType
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
index c2010d2828..1dbb18722f 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
@@ -283,7 +283,21 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) {
List(
coreBTypes.jliMethodHandlesLookupRef,
coreBTypes.StringRef,
- coreBTypes.jliMethodTypeRef
+ coreBTypes.jliMethodTypeRef,
+ ArrayBType(jliMethodHandleRef)
+ ),
+ coreBTypes.jliCallSiteRef
+ ).descriptor,
+ /* itf = */ coreBTypes.srLambdaDeserialize.isInterface.get)
+ lazy val lambdaDeserializeAddTargets =
+ new scala.tools.asm.Handle(scala.tools.asm.Opcodes.H_INVOKESTATIC,
+ coreBTypes.srLambdaDeserialize.internalName, "bootstrapAddTargets",
+ MethodBType(
+ List(
+ coreBTypes.jliMethodHandlesLookupRef,
+ coreBTypes.StringRef,
+ coreBTypes.jliMethodTypeRef,
+ ArrayBType(coreBTypes.jliMethodHandleRef)
),
coreBTypes.jliCallSiteRef
).descriptor,
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala
index 584b11d4ed..0a54767f76 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala
@@ -266,6 +266,9 @@ abstract class GenBCode extends BCodeSyncAndTry {
try {
localOptimizations(item.plain)
setInnerClasses(item.plain)
+ val lambdaImplMethods = getIndyLambdaImplMethods(item.plain.name)
+ if (lambdaImplMethods.nonEmpty)
+ backendUtils.addLambdaDeserialize(item.plain, lambdaImplMethods)
setInnerClasses(item.mirror)
setInnerClasses(item.bean)
addToQ3(item)
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
index 83615abc31..d85d85003d 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
@@ -76,7 +76,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
* host a static field in the enclosing class. This allows us to add this method to interfaces
* that define lambdas in default methods.
*/
- def addLambdaDeserialize(classNode: ClassNode): Unit = {
+ def addLambdaDeserialize(classNode: ClassNode, implMethods: List[Handle]): Unit = {
val cw = classNode
// Make sure to reference the ClassBTypes of all types that are used in the code generated
@@ -87,12 +87,13 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
val nilLookupDesc = MethodBType(Nil, jliMethodHandlesLookupRef).descriptor
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectRef).descriptor
+ val addTargetMethodsObjDesc = MethodBType(ObjectRef :: Nil, UNIT).descriptor
{
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
mv.visitCode()
mv.visitVarInsn(ALOAD, 0)
- mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle)
+ mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods: _*)
mv.visitInsn(ARETURN)
mv.visitEnd()
}
@@ -104,16 +105,16 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
* a boolean indicating if the instruction list contains an instantiation of a serializable SAM
* type.
*/
- def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], Boolean) = {
+ def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], List[Handle]) = {
val javaLabelMap = labelMap.asJava
val result = new InsnList
var map = Map.empty[AbstractInsnNode, AbstractInsnNode]
- var hasSerializableClosureInstantiation = false
+ var inlinedTargetHandles = mutable.ListBuffer[Handle]()
for (ins <- methodNode.instructions.iterator.asScala) {
- if (!hasSerializableClosureInstantiation) ins match {
+ ins match {
case callGraph.LambdaMetaFactoryCall(indy, _, _, _) => indy.bsmArgs match {
- case Array(_, _, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 =>
- hasSerializableClosureInstantiation = true
+ case Array(_, targetHandle: Handle, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 =>
+ inlinedTargetHandles += targetHandle
case _ =>
}
case _ =>
@@ -124,7 +125,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
map += ((ins, cloned))
}
}
- (result, map, hasSerializableClosureInstantiation)
+ (result, map, inlinedTargetHandles.toList)
}
def getBoxedUnit: FieldInsnNode = new FieldInsnNode(GETSTATIC, srBoxedUnitRef.internalName, "UNIT", srBoxedUnitRef.descriptor)
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala
index 9c5a1a9f98..a7916f9c24 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala
@@ -277,7 +277,7 @@ class Inliner[BT <: BTypes](val btypes: BT) {
}
case _ => false
}
- val (clonedInstructions, instructionMap, hasSerializableClosureInstantiation) = cloneInstructions(callee, labelsMap, keepLineNumbers = sameSourceFile)
+ val (clonedInstructions, instructionMap, targetHandles) = cloneInstructions(callee, labelsMap, keepLineNumbers = sameSourceFile)
// local vars in the callee are shifted by the number of locals at the callsite
val localVarShift = callsiteMethod.maxLocals
@@ -405,10 +405,7 @@ class Inliner[BT <: BTypes](val btypes: BT) {
callsiteMethod.maxStack = math.max(callsiteMethod.maxStack, math.max(stackHeightAtNullCheck, maxStackOfInlinedCode))
- if (hasSerializableClosureInstantiation && !indyLambdaHosts(callsiteClass.internalName)) {
- indyLambdaHosts += callsiteClass.internalName
- addLambdaDeserialize(byteCodeRepository.classNode(callsiteClass.internalName).get)
- }
+ addIndyLambdaImplMethod(callsiteClass.internalName, targetHandles)
callGraph.addIfMissing(callee, calleeDeclarationClass)
diff --git a/src/library/scala/runtime/LambdaDeserialize.java b/src/library/scala/runtime/LambdaDeserialize.java
index e239debf25..a3df868517 100644
--- a/src/library/scala/runtime/LambdaDeserialize.java
+++ b/src/library/scala/runtime/LambdaDeserialize.java
@@ -2,28 +2,41 @@ package scala.runtime;
import java.lang.invoke.*;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
public final class LambdaDeserialize {
+ public static final MethodType DESERIALIZE_LAMBDA_MT = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", LambdaDeserialize.class.getClassLoader());
+ public static final MethodType ADD_TARGET_METHODS_MT = MethodType.fromMethodDescriptorString("([Ljava/lang/invoke/MethodHandle;)V", LambdaDeserialize.class.getClassLoader());
private MethodHandles.Lookup lookup;
private final HashMap<String, MethodHandle> cache = new HashMap<>();
private final LambdaDeserializer$ l = LambdaDeserializer$.MODULE$;
+ private final HashMap<String, MethodHandle> targetMethodMap;
- private LambdaDeserialize(MethodHandles.Lookup lookup) {
+ private LambdaDeserialize(MethodHandles.Lookup lookup, MethodHandle[] targetMethods) {
this.lookup = lookup;
+ targetMethodMap = new HashMap<>(targetMethods.length);
+ for (MethodHandle targetMethod : targetMethods) {
+ MethodHandleInfo info = lookup.revealDirect(targetMethod);
+ String key = nameAndDescriptorKey(info.getName(), info.getMethodType().toMethodDescriptorString());
+ targetMethodMap.put(key, targetMethod);
+ }
}
public Object deserializeLambda(SerializedLambda serialized) {
- return l.deserializeLambda(lookup, cache, serialized);
+ return l.deserializeLambda(lookup, cache, targetMethodMap, serialized);
}
public static CallSite bootstrap(MethodHandles.Lookup lookup, String invokedName,
- MethodType invokedType) throws Throwable {
- MethodType type = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", lookup.getClass().getClassLoader());
- MethodHandle deserializeLambda = lookup.findVirtual(LambdaDeserialize.class, "deserializeLambda", type);
- MethodHandle exact = deserializeLambda.bindTo(new LambdaDeserialize(lookup)).asType(invokedType);
+ MethodType invokedType, MethodHandle... targetMethods) throws Throwable {
+ MethodHandle deserializeLambda = lookup.findVirtual(LambdaDeserialize.class, "deserializeLambda", DESERIALIZE_LAMBDA_MT);
+ MethodHandle exact = deserializeLambda.bindTo(new LambdaDeserialize(lookup, targetMethods)).asType(invokedType);
return new ConstantCallSite(exact);
}
+ public static String nameAndDescriptorKey(String name, String descriptor) {
+ return name + " " + descriptor;
+ }
}
diff --git a/src/library/scala/runtime/LambdaDeserializer.scala b/src/library/scala/runtime/LambdaDeserializer.scala
index ad7d12ba5d..eb168fe445 100644
--- a/src/library/scala/runtime/LambdaDeserializer.scala
+++ b/src/library/scala/runtime/LambdaDeserializer.scala
@@ -31,10 +31,12 @@ object LambdaDeserializer {
* member of the anonymous class created by `LambdaMetaFactory`.
* @return An instance of the functional interface
*/
- def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
+ def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle],
+ targetMethodMap: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
def slashDot(name: String) = name.replaceAll("/", ".")
val loader = lookup.lookupClass().getClassLoader
val implClass = loader.loadClass(slashDot(serialized.getImplClass))
+ val key = LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName, serialized.getImplMethodSignature)
def makeCallSite: CallSite = {
import serialized._
@@ -69,7 +71,15 @@ object LambdaDeserializer {
// Lookup the implementation method
val implMethod: MethodHandle = try {
- findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
+ if (targetMethodMap != null) {
+ if (targetMethodMap.containsKey(key)) {
+ targetMethodMap.get(key)
+ } else {
+ throw new IllegalArgumentException("Illegal lambda deserialization")
+ }
+ } else {
+ findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
+ }
} catch {
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
}
@@ -91,7 +101,6 @@ object LambdaDeserializer {
)
}
- val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature
val factory: MethodHandle = if (cache == null) {
makeCallSite.getTarget
} else cache.get(key) match {
diff --git a/test/files/run/lambda-serialization-security.scala b/test/files/run/lambda-serialization-security.scala
new file mode 100644
index 0000000000..08e235b1cb
--- /dev/null
+++ b/test/files/run/lambda-serialization-security.scala
@@ -0,0 +1,47 @@
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
+
+trait IntToString extends java.io.Serializable { def apply(i: Int): String }
+
+object Test {
+ def main(args: Array[String]): Unit = {
+ roundTrip()
+ roundTripIndySam()
+ }
+
+ def roundTrip(): Unit = {
+ val c = new Capture("Capture")
+ val lambda = (p: Param) => ("a", p, c)
+ val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
+ val p = new Param
+ assert(reconstituted1.apply(p) == ("a", p, c))
+ val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
+ assert(reconstituted1.getClass == reconstituted2.getClass)
+
+ val reconstituted3 = serializeDeserialize(reconstituted1)
+ assert(reconstituted3.apply(p) == ("a", p, c))
+
+ val specializedLambda = (p: Int) => List(p, c).length
+ assert(serializeDeserialize(specializedLambda).apply(42) == 2)
+ assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
+ }
+
+ // lambda targeting a SAM, not a FunctionN (should behave the same way)
+ def roundTripIndySam(): Unit = {
+ val lambda: IntToString = (x: Int) => "yo!" * x
+ val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
+ val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
+ assert(reconstituted1.apply(2) == "yo!yo!")
+ assert(reconstituted1.getClass == reconstituted2.getClass)
+ }
+
+ def serializeDeserialize[T <: AnyRef](obj: T) = {
+ val buffer = new ByteArrayOutputStream
+ val out = new ObjectOutputStream(buffer)
+ out.writeObject(obj)
+ val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
+ in.readObject.asInstanceOf[T]
+ }
+}
+
+case class Capture(s: String) extends Serializable
+class Param
diff --git a/test/files/run/lambda-serialization.scala b/test/files/run/lambda-serialization.scala
index 08e235b1cb..78b4c5d58b 100644
--- a/test/files/run/lambda-serialization.scala
+++ b/test/files/run/lambda-serialization.scala
@@ -1,37 +1,54 @@
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
+
+import scala.tools.nsc.util
+
+class C extends java.io.Serializable {
+ val fs = List(
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => ()
+ )
+ private def foo(): Unit = {
+ assert(false, "should not be called!!!")
+ }
+}
-trait IntToString extends java.io.Serializable { def apply(i: Int): String }
+trait FakeSam { def apply(): Unit }
object Test {
def main(args: Array[String]): Unit = {
- roundTrip()
- roundTripIndySam()
+ allRealLambdasRoundTrip()
+ fakeLambdaFailsToDeserialize()
}
- def roundTrip(): Unit = {
- val c = new Capture("Capture")
- val lambda = (p: Param) => ("a", p, c)
- val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
- val p = new Param
- assert(reconstituted1.apply(p) == ("a", p, c))
- val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
- assert(reconstituted1.getClass == reconstituted2.getClass)
-
- val reconstituted3 = serializeDeserialize(reconstituted1)
- assert(reconstituted3.apply(p) == ("a", p, c))
-
- val specializedLambda = (p: Int) => List(p, c).length
- assert(serializeDeserialize(specializedLambda).apply(42) == 2)
- assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
+ def allRealLambdasRoundTrip(): Unit = {
+ new C().fs.map(x => serializeDeserialize(x).apply())
}
- // lambda targeting a SAM, not a FunctionN (should behave the same way)
- def roundTripIndySam(): Unit = {
- val lambda: IntToString = (x: Int) => "yo!" * x
- val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
- val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
- assert(reconstituted1.apply(2) == "yo!yo!")
- assert(reconstituted1.getClass == reconstituted2.getClass)
+ def fakeLambdaFailsToDeserialize(): Unit = {
+ val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V",
+ MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C))
+ try {
+ serializeDeserialize(fake).asInstanceOf[FakeSam].apply()
+ assert(false)
+ } catch {
+ case ex: Exception =>
+ val stackTrace = util.stackTraceString(ex)
+ assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace)
+ }
}
def serializeDeserialize[T <: AnyRef](obj: T) = {
@@ -43,5 +60,3 @@ object Test {
}
}
-case class Capture(s: String) extends Serializable
-class Param
diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java
index 069eb4aab6..ba52e979cc 100644
--- a/test/junit/scala/runtime/LambdaDeserializerTest.java
+++ b/test/junit/scala/runtime/LambdaDeserializerTest.java
@@ -97,7 +97,7 @@ public final class LambdaDeserializerTest {
private void checkIllegalAccess(SerializedLambda serialized) {
try {
- LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
+ LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, null, serialized);
throw new AssertionError();
} catch (IllegalArgumentException iae) {
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
@@ -130,7 +130,7 @@ public final class LambdaDeserializerTest {
@SuppressWarnings("unchecked")
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
try {
- return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
+ return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, null, writeReplace(f1));
} catch (Exception e) {
throw new RuntimeException(e);
}