diff options
6 files changed, 82 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala index 783c89584e..6aa3a62295 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala @@ -685,24 +685,50 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters { /** * Add: - * + * private static java.util.Map $deserializeLambdaCache$ = null * private static Object $deserializeLambda$(SerializedLambda l) { - * return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, l); + * var cache = $deserializeLambdaCache$ + * if (cache eq null) { + * cache = new java.util.HashMap() + * $deserializeLambdaCache$ = cache + * } + * return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), cache, l); * } - * @param jclass */ - // TODO add a static cache field to the class, and pass that as the second argument to `deserializeLambda`. - // This will make the test at run/lambda-serialization.scala:15 work - def addLambdaDeserialize(jclass: asm.ClassVisitor): Unit = { + def addLambdaDeserialize(clazz: Symbol, jclass: asm.ClassVisitor): Unit = { val cw = jclass import scala.tools.asm.Opcodes._ + + // Need to force creation of BTypes for these as `getCommonSuperClass` is called on + // automatically computing the max stack size (`visitMaxs`) during method writing. + javaUtilHashMapReference + javaUtilMapReference + cw.visitInnerClass("java/lang/invoke/MethodHandles$Lookup", "java/lang/invoke/MethodHandles", "Lookup", ACC_PUBLIC + ACC_FINAL + ACC_STATIC) { + val fv = cw.visitField(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambdaCache$", "Ljava/util/Map;", null, null) + fv.visitEnd() + } + + { val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", "(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", null, null) mv.visitCode() + mv.visitFieldInsn(GETSTATIC, clazz.javaBinaryName.encoded, "$deserializeLambdaCache$", "Ljava/util/Map;") + mv.visitVarInsn(ASTORE, 1) + mv.visitVarInsn(ALOAD, 1) + val l0 = new asm.Label() + mv.visitJumpInsn(IFNONNULL, l0) + mv.visitTypeInsn(NEW, "java/util/HashMap") + mv.visitInsn(DUP) + mv.visitMethodInsn(INVOKESPECIAL, "java/util/HashMap", "<init>", "()V", false) + mv.visitVarInsn(ASTORE, 1) + mv.visitVarInsn(ALOAD, 1) + mv.visitFieldInsn(PUTSTATIC, clazz.javaBinaryName.encoded, "$deserializeLambdaCache$", "Ljava/util/Map;") + mv.visitLabel(l0) + mv.visitFrame(asm.Opcodes.F_APPEND,1, Array("java/util/Map"), 0, null) mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "lookup", "()Ljava/lang/invoke/MethodHandles$Lookup;", false) - mv.visitInsn(asm.Opcodes.ACONST_NULL) + mv.visitVarInsn(ALOAD, 1) mv.visitVarInsn(ALOAD, 0) mv.visitMethodInsn(INVOKESTATIC, "scala/compat/java8/runtime/LambdaDeserializer", "deserializeLambda", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/util/Map;Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", false) mv.visitInsn(ARETURN) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index b2011f8e0c..a2fd22d24c 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -131,7 +131,7 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { && indyLambdaHosts.contains(claszSymbol)) if (shouldAddLambdaDeserialize) - addLambdaDeserialize(cnode) + addLambdaDeserialize(claszSymbol, cnode) addInnerClassesASM(cnode, innerClassBufferASM.toList) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala index 492fe3ae79..00ca096e59 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala @@ -114,6 +114,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) { lazy val jioSerializableReference : ClassBType = classBTypeFromSymbol(JavaSerializableClass) // java/io/Serializable lazy val scalaSerializableReference : ClassBType = classBTypeFromSymbol(SerializableClass) // scala/Serializable lazy val classCastExceptionReference : ClassBType = classBTypeFromSymbol(ClassCastExceptionClass) // java/lang/ClassCastException + lazy val javaUtilMapReference : ClassBType = classBTypeFromSymbol(JavaUtilMap) // java/util/Map + lazy val javaUtilHashMapReference : ClassBType = classBTypeFromSymbol(JavaUtilHashMap) // java/util/HashMap lazy val srBooleanRef : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.BooleanRef]) lazy val srByteRef : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.ByteRef]) @@ -258,6 +260,8 @@ final class CoreBTypesProxy[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: def jioSerializableReference : ClassBType = _coreBTypes.jioSerializableReference def scalaSerializableReference : ClassBType = _coreBTypes.scalaSerializableReference def classCastExceptionReference : ClassBType = _coreBTypes.classCastExceptionReference + def javaUtilMapReference : ClassBType = _coreBTypes.javaUtilMapReference + def javaUtilHashMapReference : ClassBType = _coreBTypes.javaUtilHashMapReference def srBooleanRef : ClassBType = _coreBTypes.srBooleanRef def srByteRef : ClassBType = _coreBTypes.srByteRef diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index 806fc37617..f3dd6a3280 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -369,6 +369,8 @@ trait Definitions extends api.StandardDefinitions { lazy val JavaEnumClass = requiredClass[java.lang.Enum[_]] lazy val RemoteInterfaceClass = requiredClass[java.rmi.Remote] lazy val RemoteExceptionClass = requiredClass[java.rmi.RemoteException] + lazy val JavaUtilMap = requiredClass[java.util.Map[_, _]] + lazy val JavaUtilHashMap = requiredClass[java.util.HashMap[_, _]] lazy val ByNameParamClass = specialPolyClass(tpnme.BYNAME_PARAM_CLASS_NAME, COVARIANT)(_ => AnyTpe) lazy val JavaRepeatedParamClass = specialPolyClass(tpnme.JAVA_REPEATED_PARAM_CLASS_NAME, COVARIANT)(tparam => arrayType(tparam.tpe)) diff --git a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala index 8c03ee7ca3..ea213cadd9 100644 --- a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala +++ b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala @@ -255,6 +255,8 @@ trait JavaUniverseForce { self: runtime.JavaUniverse => definitions.JavaEnumClass definitions.RemoteInterfaceClass definitions.RemoteExceptionClass + definitions.JavaUtilMap + definitions.JavaUtilHashMap definitions.ByNameParamClass definitions.JavaRepeatedParamClass definitions.RepeatedParamClass diff --git a/test/files/run/lambda-serialization-gc.scala b/test/files/run/lambda-serialization-gc.scala new file mode 100644 index 0000000000..8fa0b4b402 --- /dev/null +++ b/test/files/run/lambda-serialization-gc.scala @@ -0,0 +1,40 @@ +import java.io._ + +import java.net.URLClassLoader + +class C { + def serializeDeserialize[T <: AnyRef](obj: T) = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)) + in.readObject.asInstanceOf[T] + } + + serializeDeserialize((c: String) => c.length) +} + +object Test { + def main(args: Array[String]): Unit = { + test() + } + + def test(): Unit = { + val loader = getClass.getClassLoader.asInstanceOf[URLClassLoader] + val loaderCClass = classOf[C] + def deserializedInThrowawayClassloader = { + val throwawayLoader: java.net.URLClassLoader = new java.net.URLClassLoader(loader.getURLs, ClassLoader.getSystemClassLoader) { + val maxMemory = Runtime.getRuntime.maxMemory() + val junk = new Array[Byte]((maxMemory / 2).toInt) + } + val clazz = throwawayLoader.loadClass("C") + assert(clazz != loaderCClass) + clazz.newInstance() + } + (1 to 4) foreach { i => + // This would OOM by the third iteration if we leaked `throwawayLoader` during + // deserialization. + deserializedInThrowawayClassloader + } + } +} |