summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala3
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala8
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala11
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala16
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala3
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala17
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala7
-rw-r--r--src/library/scala/runtime/LambdaDeserialize.java25
-rw-r--r--src/library/scala/runtime/LambdaDeserializer.scala15
-rw-r--r--test/files/run/lambda-serialization-security.scala47
-rw-r--r--test/files/run/lambda-serialization.scala71
-rw-r--r--test/junit/scala/runtime/LambdaDeserializerTest.java4
12 files changed, 164 insertions, 63 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
index d5c4b5e201..6f9682f434 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala
@@ -14,6 +14,7 @@ import scala.reflect.internal.Flags
import scala.tools.asm
import GenBCode._
import BackendReporting._
+import scala.collection.mutable
import scala.tools.asm.Opcodes
import scala.tools.asm.tree.{MethodInsnNode, MethodNode}
import scala.tools.nsc.backend.jvm.BCodeHelpers.{InvokeStyle, TestOp}
@@ -1349,7 +1350,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
val markers = if (addScalaSerializableMarker) classBTypeFromSymbol(definitions.SerializableClass).toASMType :: Nil else Nil
visitInvokeDynamicInsnLMF(bc.jmethod, sam.name.toString, invokedType, samMethodType, implMethodHandle, constrainedType, isSerializable, markers)
if (isSerializable)
- indyLambdaHosts += cnode.name
+ addIndyLambdaImplMethod(cnode.name, implMethodHandle :: Nil)
}
}
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
index 1bff8519ec..d4d532f4df 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
@@ -112,14 +112,6 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
gen(cd.impl)
- val shouldAddLambdaDeserialize = (
- settings.target.value == "jvm-1.8"
- && settings.Ydelambdafy.value == "method"
- && indyLambdaHosts.contains(cnode.name))
-
- if (shouldAddLambdaDeserialize)
- backendUtils.addLambdaDeserialize(cnode)
-
cnode.visitAttribute(thisBType.inlineInfoAttribute.get)
if (AsmUtils.traceClassEnabled && cnode.name.contains(AsmUtils.traceClassPattern))
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
index 7b2686e7a9..0845e440d7 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala
@@ -122,7 +122,16 @@ abstract class BTypes {
* inlining: when inlining an indyLambda instruction into a class, we need to make sure the class
* has the method.
*/
- val indyLambdaHosts: mutable.Set[InternalName] = recordPerRunCache(mutable.Set.empty)
+ val indyLambdaImplMethods: mutable.AnyRefMap[InternalName, mutable.LinkedHashSet[asm.Handle]] = recordPerRunCache(mutable.AnyRefMap())
+ def addIndyLambdaImplMethod(hostClass: InternalName, handle: Seq[asm.Handle]): Unit = {
+ indyLambdaImplMethods.getOrElseUpdate(hostClass, mutable.LinkedHashSet()) ++= handle
+ }
+ def getIndyLambdaImplMethods(hostClass: InternalName): List[asm.Handle] = {
+ indyLambdaImplMethods.getOrNull(hostClass) match {
+ case null => Nil
+ case xs => xs.toList.distinct
+ }
+ }
/**
* Obtain the BType for a type descriptor or internal name. For class descriptors, the ClassBType
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
index c2010d2828..1dbb18722f 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
@@ -283,7 +283,21 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) {
List(
coreBTypes.jliMethodHandlesLookupRef,
coreBTypes.StringRef,
- coreBTypes.jliMethodTypeRef
+ coreBTypes.jliMethodTypeRef,
+ ArrayBType(jliMethodHandleRef)
+ ),
+ coreBTypes.jliCallSiteRef
+ ).descriptor,
+ /* itf = */ coreBTypes.srLambdaDeserialize.isInterface.get)
+ lazy val lambdaDeserializeAddTargets =
+ new scala.tools.asm.Handle(scala.tools.asm.Opcodes.H_INVOKESTATIC,
+ coreBTypes.srLambdaDeserialize.internalName, "bootstrapAddTargets",
+ MethodBType(
+ List(
+ coreBTypes.jliMethodHandlesLookupRef,
+ coreBTypes.StringRef,
+ coreBTypes.jliMethodTypeRef,
+ ArrayBType(coreBTypes.jliMethodHandleRef)
),
coreBTypes.jliCallSiteRef
).descriptor,
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala
index 584b11d4ed..0a54767f76 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala
@@ -266,6 +266,9 @@ abstract class GenBCode extends BCodeSyncAndTry {
try {
localOptimizations(item.plain)
setInnerClasses(item.plain)
+ val lambdaImplMethods = getIndyLambdaImplMethods(item.plain.name)
+ if (lambdaImplMethods.nonEmpty)
+ backendUtils.addLambdaDeserialize(item.plain, lambdaImplMethods)
setInnerClasses(item.mirror)
setInnerClasses(item.bean)
addToQ3(item)
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
index 83615abc31..d85d85003d 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/analysis/BackendUtils.scala
@@ -76,7 +76,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
* host a static field in the enclosing class. This allows us to add this method to interfaces
* that define lambdas in default methods.
*/
- def addLambdaDeserialize(classNode: ClassNode): Unit = {
+ def addLambdaDeserialize(classNode: ClassNode, implMethods: List[Handle]): Unit = {
val cw = classNode
// Make sure to reference the ClassBTypes of all types that are used in the code generated
@@ -87,12 +87,13 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
val nilLookupDesc = MethodBType(Nil, jliMethodHandlesLookupRef).descriptor
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectRef).descriptor
+ val addTargetMethodsObjDesc = MethodBType(ObjectRef :: Nil, UNIT).descriptor
{
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
mv.visitCode()
mv.visitVarInsn(ALOAD, 0)
- mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle)
+ mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, implMethods: _*)
mv.visitInsn(ARETURN)
mv.visitEnd()
}
@@ -104,16 +105,16 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
* a boolean indicating if the instruction list contains an instantiation of a serializable SAM
* type.
*/
- def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], Boolean) = {
+ def cloneInstructions(methodNode: MethodNode, labelMap: Map[LabelNode, LabelNode], keepLineNumbers: Boolean): (InsnList, Map[AbstractInsnNode, AbstractInsnNode], List[Handle]) = {
val javaLabelMap = labelMap.asJava
val result = new InsnList
var map = Map.empty[AbstractInsnNode, AbstractInsnNode]
- var hasSerializableClosureInstantiation = false
+ var inlinedTargetHandles = mutable.ListBuffer[Handle]()
for (ins <- methodNode.instructions.iterator.asScala) {
- if (!hasSerializableClosureInstantiation) ins match {
+ ins match {
case callGraph.LambdaMetaFactoryCall(indy, _, _, _) => indy.bsmArgs match {
- case Array(_, _, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 =>
- hasSerializableClosureInstantiation = true
+ case Array(_, targetHandle: Handle, _, flags: Integer, xs@_*) if (flags.intValue & LambdaMetafactory.FLAG_SERIALIZABLE) != 0 =>
+ inlinedTargetHandles += targetHandle
case _ =>
}
case _ =>
@@ -124,7 +125,7 @@ class BackendUtils[BT <: BTypes](val btypes: BT) {
map += ((ins, cloned))
}
}
- (result, map, hasSerializableClosureInstantiation)
+ (result, map, inlinedTargetHandles.toList)
}
def getBoxedUnit: FieldInsnNode = new FieldInsnNode(GETSTATIC, srBoxedUnitRef.internalName, "UNIT", srBoxedUnitRef.descriptor)
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala
index 9c5a1a9f98..a7916f9c24 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala
@@ -277,7 +277,7 @@ class Inliner[BT <: BTypes](val btypes: BT) {
}
case _ => false
}
- val (clonedInstructions, instructionMap, hasSerializableClosureInstantiation) = cloneInstructions(callee, labelsMap, keepLineNumbers = sameSourceFile)
+ val (clonedInstructions, instructionMap, targetHandles) = cloneInstructions(callee, labelsMap, keepLineNumbers = sameSourceFile)
// local vars in the callee are shifted by the number of locals at the callsite
val localVarShift = callsiteMethod.maxLocals
@@ -405,10 +405,7 @@ class Inliner[BT <: BTypes](val btypes: BT) {
callsiteMethod.maxStack = math.max(callsiteMethod.maxStack, math.max(stackHeightAtNullCheck, maxStackOfInlinedCode))
- if (hasSerializableClosureInstantiation && !indyLambdaHosts(callsiteClass.internalName)) {
- indyLambdaHosts += callsiteClass.internalName
- addLambdaDeserialize(byteCodeRepository.classNode(callsiteClass.internalName).get)
- }
+ addIndyLambdaImplMethod(callsiteClass.internalName, targetHandles)
callGraph.addIfMissing(callee, calleeDeclarationClass)
diff --git a/src/library/scala/runtime/LambdaDeserialize.java b/src/library/scala/runtime/LambdaDeserialize.java
index e239debf25..a3df868517 100644
--- a/src/library/scala/runtime/LambdaDeserialize.java
+++ b/src/library/scala/runtime/LambdaDeserialize.java
@@ -2,28 +2,41 @@ package scala.runtime;
import java.lang.invoke.*;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
public final class LambdaDeserialize {
+ public static final MethodType DESERIALIZE_LAMBDA_MT = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", LambdaDeserialize.class.getClassLoader());
+ public static final MethodType ADD_TARGET_METHODS_MT = MethodType.fromMethodDescriptorString("([Ljava/lang/invoke/MethodHandle;)V", LambdaDeserialize.class.getClassLoader());
private MethodHandles.Lookup lookup;
private final HashMap<String, MethodHandle> cache = new HashMap<>();
private final LambdaDeserializer$ l = LambdaDeserializer$.MODULE$;
+ private final HashMap<String, MethodHandle> targetMethodMap;
- private LambdaDeserialize(MethodHandles.Lookup lookup) {
+ private LambdaDeserialize(MethodHandles.Lookup lookup, MethodHandle[] targetMethods) {
this.lookup = lookup;
+ targetMethodMap = new HashMap<>(targetMethods.length);
+ for (MethodHandle targetMethod : targetMethods) {
+ MethodHandleInfo info = lookup.revealDirect(targetMethod);
+ String key = nameAndDescriptorKey(info.getName(), info.getMethodType().toMethodDescriptorString());
+ targetMethodMap.put(key, targetMethod);
+ }
}
public Object deserializeLambda(SerializedLambda serialized) {
- return l.deserializeLambda(lookup, cache, serialized);
+ return l.deserializeLambda(lookup, cache, targetMethodMap, serialized);
}
public static CallSite bootstrap(MethodHandles.Lookup lookup, String invokedName,
- MethodType invokedType) throws Throwable {
- MethodType type = MethodType.fromMethodDescriptorString("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", lookup.getClass().getClassLoader());
- MethodHandle deserializeLambda = lookup.findVirtual(LambdaDeserialize.class, "deserializeLambda", type);
- MethodHandle exact = deserializeLambda.bindTo(new LambdaDeserialize(lookup)).asType(invokedType);
+ MethodType invokedType, MethodHandle... targetMethods) throws Throwable {
+ MethodHandle deserializeLambda = lookup.findVirtual(LambdaDeserialize.class, "deserializeLambda", DESERIALIZE_LAMBDA_MT);
+ MethodHandle exact = deserializeLambda.bindTo(new LambdaDeserialize(lookup, targetMethods)).asType(invokedType);
return new ConstantCallSite(exact);
}
+ public static String nameAndDescriptorKey(String name, String descriptor) {
+ return name + " " + descriptor;
+ }
}
diff --git a/src/library/scala/runtime/LambdaDeserializer.scala b/src/library/scala/runtime/LambdaDeserializer.scala
index ad7d12ba5d..eb168fe445 100644
--- a/src/library/scala/runtime/LambdaDeserializer.scala
+++ b/src/library/scala/runtime/LambdaDeserializer.scala
@@ -31,10 +31,12 @@ object LambdaDeserializer {
* member of the anonymous class created by `LambdaMetaFactory`.
* @return An instance of the functional interface
*/
- def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
+ def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle],
+ targetMethodMap: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
def slashDot(name: String) = name.replaceAll("/", ".")
val loader = lookup.lookupClass().getClassLoader
val implClass = loader.loadClass(slashDot(serialized.getImplClass))
+ val key = LambdaDeserialize.nameAndDescriptorKey(serialized.getImplMethodName, serialized.getImplMethodSignature)
def makeCallSite: CallSite = {
import serialized._
@@ -69,7 +71,15 @@ object LambdaDeserializer {
// Lookup the implementation method
val implMethod: MethodHandle = try {
- findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
+ if (targetMethodMap != null) {
+ if (targetMethodMap.containsKey(key)) {
+ targetMethodMap.get(key)
+ } else {
+ throw new IllegalArgumentException("Illegal lambda deserialization")
+ }
+ } else {
+ findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
+ }
} catch {
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
}
@@ -91,7 +101,6 @@ object LambdaDeserializer {
)
}
- val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature
val factory: MethodHandle = if (cache == null) {
makeCallSite.getTarget
} else cache.get(key) match {
diff --git a/test/files/run/lambda-serialization-security.scala b/test/files/run/lambda-serialization-security.scala
new file mode 100644
index 0000000000..08e235b1cb
--- /dev/null
+++ b/test/files/run/lambda-serialization-security.scala
@@ -0,0 +1,47 @@
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
+
+trait IntToString extends java.io.Serializable { def apply(i: Int): String }
+
+object Test {
+ def main(args: Array[String]): Unit = {
+ roundTrip()
+ roundTripIndySam()
+ }
+
+ def roundTrip(): Unit = {
+ val c = new Capture("Capture")
+ val lambda = (p: Param) => ("a", p, c)
+ val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
+ val p = new Param
+ assert(reconstituted1.apply(p) == ("a", p, c))
+ val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
+ assert(reconstituted1.getClass == reconstituted2.getClass)
+
+ val reconstituted3 = serializeDeserialize(reconstituted1)
+ assert(reconstituted3.apply(p) == ("a", p, c))
+
+ val specializedLambda = (p: Int) => List(p, c).length
+ assert(serializeDeserialize(specializedLambda).apply(42) == 2)
+ assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
+ }
+
+ // lambda targeting a SAM, not a FunctionN (should behave the same way)
+ def roundTripIndySam(): Unit = {
+ val lambda: IntToString = (x: Int) => "yo!" * x
+ val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
+ val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
+ assert(reconstituted1.apply(2) == "yo!yo!")
+ assert(reconstituted1.getClass == reconstituted2.getClass)
+ }
+
+ def serializeDeserialize[T <: AnyRef](obj: T) = {
+ val buffer = new ByteArrayOutputStream
+ val out = new ObjectOutputStream(buffer)
+ out.writeObject(obj)
+ val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
+ in.readObject.asInstanceOf[T]
+ }
+}
+
+case class Capture(s: String) extends Serializable
+class Param
diff --git a/test/files/run/lambda-serialization.scala b/test/files/run/lambda-serialization.scala
index 08e235b1cb..78b4c5d58b 100644
--- a/test/files/run/lambda-serialization.scala
+++ b/test/files/run/lambda-serialization.scala
@@ -1,37 +1,54 @@
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
+
+import scala.tools.nsc.util
+
+class C extends java.io.Serializable {
+ val fs = List(
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
+ () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => ()
+ )
+ private def foo(): Unit = {
+ assert(false, "should not be called!!!")
+ }
+}
-trait IntToString extends java.io.Serializable { def apply(i: Int): String }
+trait FakeSam { def apply(): Unit }
object Test {
def main(args: Array[String]): Unit = {
- roundTrip()
- roundTripIndySam()
+ allRealLambdasRoundTrip()
+ fakeLambdaFailsToDeserialize()
}
- def roundTrip(): Unit = {
- val c = new Capture("Capture")
- val lambda = (p: Param) => ("a", p, c)
- val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
- val p = new Param
- assert(reconstituted1.apply(p) == ("a", p, c))
- val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any]
- assert(reconstituted1.getClass == reconstituted2.getClass)
-
- val reconstituted3 = serializeDeserialize(reconstituted1)
- assert(reconstituted3.apply(p) == ("a", p, c))
-
- val specializedLambda = (p: Int) => List(p, c).length
- assert(serializeDeserialize(specializedLambda).apply(42) == 2)
- assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2)
+ def allRealLambdasRoundTrip(): Unit = {
+ new C().fs.map(x => serializeDeserialize(x).apply())
}
- // lambda targeting a SAM, not a FunctionN (should behave the same way)
- def roundTripIndySam(): Unit = {
- val lambda: IntToString = (x: Int) => "yo!" * x
- val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[IntToString]
- val reconstituted2 = serializeDeserialize(reconstituted1).asInstanceOf[IntToString]
- assert(reconstituted1.apply(2) == "yo!yo!")
- assert(reconstituted1.getClass == reconstituted2.getClass)
+ def fakeLambdaFailsToDeserialize(): Unit = {
+ val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V",
+ MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C))
+ try {
+ serializeDeserialize(fake).asInstanceOf[FakeSam].apply()
+ assert(false)
+ } catch {
+ case ex: Exception =>
+ val stackTrace = util.stackTraceString(ex)
+ assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace)
+ }
}
def serializeDeserialize[T <: AnyRef](obj: T) = {
@@ -43,5 +60,3 @@ object Test {
}
}
-case class Capture(s: String) extends Serializable
-class Param
diff --git a/test/junit/scala/runtime/LambdaDeserializerTest.java b/test/junit/scala/runtime/LambdaDeserializerTest.java
index 069eb4aab6..ba52e979cc 100644
--- a/test/junit/scala/runtime/LambdaDeserializerTest.java
+++ b/test/junit/scala/runtime/LambdaDeserializerTest.java
@@ -97,7 +97,7 @@ public final class LambdaDeserializerTest {
private void checkIllegalAccess(SerializedLambda serialized) {
try {
- LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
+ LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, null, serialized);
throw new AssertionError();
} catch (IllegalArgumentException iae) {
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
@@ -130,7 +130,7 @@ public final class LambdaDeserializerTest {
@SuppressWarnings("unchecked")
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
try {
- return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
+ return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, null, writeReplace(f1));
} catch (Exception e) {
throw new RuntimeException(e);
}