summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2015-05-16 21:02:56 +1000
committerJason Zaugg <jzaugg@gmail.com>2015-05-18 09:07:54 +1000
commit1d8c63277e97c57e12fa9864a2d238d4f54c10f0 (patch)
tree949d2125c97a74bdc0eee5d581684ff93f1f1a77
parentafa2ff9f76123ab982dc5bb2f1110bb58e75c68c (diff)
downloadscala-1d8c63277e97c57e12fa9864a2d238d4f54c10f0.tar.gz
scala-1d8c63277e97c57e12fa9864a2d238d4f54c10f0.tar.bz2
scala-1d8c63277e97c57e12fa9864a2d238d4f54c10f0.zip
[indylambda] Enable caching for lambda deserialization
We add a static field to each class that defines lambdas that will hold a `ju.Map[String, MethodHandle]` to cache references to the constructors of the classes originally created by `LambdaMetafactory`. The cache is initially null, and created on the first deserialization. In case of a race between two threads deserializing the first lambda hosted by a class, the last one to finish will clobber the one-element cache of the first. This lack of strong guarantees mirrors the current policy in `LambdaDeserializer`. We should consider whether to strengthen the combinaed guarantee here. A useful benchmark would be those of the invokedynamic instruction, which allows multiple threads to call the boostrap method in parallel, but guarantees that if that happens, the results of all but one will be discarded: > If several threads simultaneously execute the bootstrap method for > the same dynamic call site, the Java Virtual Machine must choose > one returned call site object and install it visibly to all threads. We could meet this guarantee easily, albeit excessively, by synchronizing `$deserializeLambda$`. But a more fine grained approach is possible and desirable. A test is included that shows we are able to garbage collect classloaders of classes that have hosted lambda deserialization.
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala40
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala2
-rw-r--r--src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala4
-rw-r--r--src/reflect/scala/reflect/internal/Definitions.scala2
-rw-r--r--src/reflect/scala/reflect/runtime/JavaUniverseForce.scala2
-rw-r--r--test/files/run/lambda-serialization-gc.scala40
6 files changed, 82 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala
index 783c89584e..6aa3a62295 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala
@@ -685,24 +685,50 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters {
/**
* Add:
- *
+ * private static java.util.Map $deserializeLambdaCache$ = null
* private static Object $deserializeLambda$(SerializedLambda l) {
- * return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, l);
+ * var cache = $deserializeLambdaCache$
+ * if (cache eq null) {
+ * cache = new java.util.HashMap()
+ * $deserializeLambdaCache$ = cache
+ * }
+ * return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), cache, l);
* }
- * @param jclass
*/
- // TODO add a static cache field to the class, and pass that as the second argument to `deserializeLambda`.
- // This will make the test at run/lambda-serialization.scala:15 work
- def addLambdaDeserialize(jclass: asm.ClassVisitor): Unit = {
+ def addLambdaDeserialize(clazz: Symbol, jclass: asm.ClassVisitor): Unit = {
val cw = jclass
import scala.tools.asm.Opcodes._
+
+ // Need to force creation of BTypes for these as `getCommonSuperClass` is called on
+ // automatically computing the max stack size (`visitMaxs`) during method writing.
+ javaUtilHashMapReference
+ javaUtilMapReference
+
cw.visitInnerClass("java/lang/invoke/MethodHandles$Lookup", "java/lang/invoke/MethodHandles", "Lookup", ACC_PUBLIC + ACC_FINAL + ACC_STATIC)
{
+ val fv = cw.visitField(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambdaCache$", "Ljava/util/Map;", null, null)
+ fv.visitEnd()
+ }
+
+ {
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", "(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", null, null)
mv.visitCode()
+ mv.visitFieldInsn(GETSTATIC, clazz.javaBinaryName.encoded, "$deserializeLambdaCache$", "Ljava/util/Map;")
+ mv.visitVarInsn(ASTORE, 1)
+ mv.visitVarInsn(ALOAD, 1)
+ val l0 = new asm.Label()
+ mv.visitJumpInsn(IFNONNULL, l0)
+ mv.visitTypeInsn(NEW, "java/util/HashMap")
+ mv.visitInsn(DUP)
+ mv.visitMethodInsn(INVOKESPECIAL, "java/util/HashMap", "<init>", "()V", false)
+ mv.visitVarInsn(ASTORE, 1)
+ mv.visitVarInsn(ALOAD, 1)
+ mv.visitFieldInsn(PUTSTATIC, clazz.javaBinaryName.encoded, "$deserializeLambdaCache$", "Ljava/util/Map;")
+ mv.visitLabel(l0)
+ mv.visitFrame(asm.Opcodes.F_APPEND,1, Array("java/util/Map"), 0, null)
mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "lookup", "()Ljava/lang/invoke/MethodHandles$Lookup;", false)
- mv.visitInsn(asm.Opcodes.ACONST_NULL)
+ mv.visitVarInsn(ALOAD, 1)
mv.visitVarInsn(ALOAD, 0)
mv.visitMethodInsn(INVOKESTATIC, "scala/compat/java8/runtime/LambdaDeserializer", "deserializeLambda", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/util/Map;Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", false)
mv.visitInsn(ARETURN)
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
index b2011f8e0c..a2fd22d24c 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala
@@ -131,7 +131,7 @@ abstract class BCodeSkelBuilder extends BCodeHelpers {
&& indyLambdaHosts.contains(claszSymbol))
if (shouldAddLambdaDeserialize)
- addLambdaDeserialize(cnode)
+ addLambdaDeserialize(claszSymbol, cnode)
addInnerClassesASM(cnode, innerClassBufferASM.toList)
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
index 492fe3ae79..00ca096e59 100644
--- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
+++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala
@@ -114,6 +114,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) {
lazy val jioSerializableReference : ClassBType = classBTypeFromSymbol(JavaSerializableClass) // java/io/Serializable
lazy val scalaSerializableReference : ClassBType = classBTypeFromSymbol(SerializableClass) // scala/Serializable
lazy val classCastExceptionReference : ClassBType = classBTypeFromSymbol(ClassCastExceptionClass) // java/lang/ClassCastException
+ lazy val javaUtilMapReference : ClassBType = classBTypeFromSymbol(JavaUtilMap) // java/util/Map
+ lazy val javaUtilHashMapReference : ClassBType = classBTypeFromSymbol(JavaUtilHashMap) // java/util/HashMap
lazy val srBooleanRef : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.BooleanRef])
lazy val srByteRef : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.ByteRef])
@@ -258,6 +260,8 @@ final class CoreBTypesProxy[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes:
def jioSerializableReference : ClassBType = _coreBTypes.jioSerializableReference
def scalaSerializableReference : ClassBType = _coreBTypes.scalaSerializableReference
def classCastExceptionReference : ClassBType = _coreBTypes.classCastExceptionReference
+ def javaUtilMapReference : ClassBType = _coreBTypes.javaUtilMapReference
+ def javaUtilHashMapReference : ClassBType = _coreBTypes.javaUtilHashMapReference
def srBooleanRef : ClassBType = _coreBTypes.srBooleanRef
def srByteRef : ClassBType = _coreBTypes.srByteRef
diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala
index 806fc37617..f3dd6a3280 100644
--- a/src/reflect/scala/reflect/internal/Definitions.scala
+++ b/src/reflect/scala/reflect/internal/Definitions.scala
@@ -369,6 +369,8 @@ trait Definitions extends api.StandardDefinitions {
lazy val JavaEnumClass = requiredClass[java.lang.Enum[_]]
lazy val RemoteInterfaceClass = requiredClass[java.rmi.Remote]
lazy val RemoteExceptionClass = requiredClass[java.rmi.RemoteException]
+ lazy val JavaUtilMap = requiredClass[java.util.Map[_, _]]
+ lazy val JavaUtilHashMap = requiredClass[java.util.HashMap[_, _]]
lazy val ByNameParamClass = specialPolyClass(tpnme.BYNAME_PARAM_CLASS_NAME, COVARIANT)(_ => AnyTpe)
lazy val JavaRepeatedParamClass = specialPolyClass(tpnme.JAVA_REPEATED_PARAM_CLASS_NAME, COVARIANT)(tparam => arrayType(tparam.tpe))
diff --git a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala
index 8c03ee7ca3..ea213cadd9 100644
--- a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala
+++ b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala
@@ -255,6 +255,8 @@ trait JavaUniverseForce { self: runtime.JavaUniverse =>
definitions.JavaEnumClass
definitions.RemoteInterfaceClass
definitions.RemoteExceptionClass
+ definitions.JavaUtilMap
+ definitions.JavaUtilHashMap
definitions.ByNameParamClass
definitions.JavaRepeatedParamClass
definitions.RepeatedParamClass
diff --git a/test/files/run/lambda-serialization-gc.scala b/test/files/run/lambda-serialization-gc.scala
new file mode 100644
index 0000000000..8fa0b4b402
--- /dev/null
+++ b/test/files/run/lambda-serialization-gc.scala
@@ -0,0 +1,40 @@
+import java.io._
+
+import java.net.URLClassLoader
+
+class C {
+ def serializeDeserialize[T <: AnyRef](obj: T) = {
+ val buffer = new ByteArrayOutputStream
+ val out = new ObjectOutputStream(buffer)
+ out.writeObject(obj)
+ val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
+ in.readObject.asInstanceOf[T]
+ }
+
+ serializeDeserialize((c: String) => c.length)
+}
+
+object Test {
+ def main(args: Array[String]): Unit = {
+ test()
+ }
+
+ def test(): Unit = {
+ val loader = getClass.getClassLoader.asInstanceOf[URLClassLoader]
+ val loaderCClass = classOf[C]
+ def deserializedInThrowawayClassloader = {
+ val throwawayLoader: java.net.URLClassLoader = new java.net.URLClassLoader(loader.getURLs, ClassLoader.getSystemClassLoader) {
+ val maxMemory = Runtime.getRuntime.maxMemory()
+ val junk = new Array[Byte]((maxMemory / 2).toInt)
+ }
+ val clazz = throwawayLoader.loadClass("C")
+ assert(clazz != loaderCClass)
+ clazz.newInstance()
+ }
+ (1 to 4) foreach { i =>
+ // This would OOM by the third iteration if we leaked `throwawayLoader` during
+ // deserialization.
+ deserializedInThrowawayClassloader
+ }
+ }
+}