diff options
author | Adriaan Moors <adriaan.moors@typesafe.com> | 2015-05-21 16:09:29 -0700 |
---|---|---|
committer | Adriaan Moors <adriaan.moors@typesafe.com> | 2015-05-21 16:09:29 -0700 |
commit | 97543db610bcccdd0243091ba3bebc23061102e8 (patch) | |
tree | e69b83f3885794f1b34e08960124cb198ab15544 | |
parent | 0c360f6d285fdea1fb240248b6c857cc942fbede (diff) | |
parent | 1d8c63277e97c57e12fa9864a2d238d4f54c10f0 (diff) | |
download | scala-97543db610bcccdd0243091ba3bebc23061102e8.tar.gz scala-97543db610bcccdd0243091ba3bebc23061102e8.tar.bz2 scala-97543db610bcccdd0243091ba3bebc23061102e8.zip |
Merge pull request #4501 from retronym/topic/indylambda-serialization
[indylambda] Support lambda {de}serialization
11 files changed, 179 insertions, 22 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index 8ebe27e61b..40ba0c010b 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -33,7 +33,6 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { * Functionality to build the body of ASM MethodNode, except for `synchronized` and `try` expressions. */ abstract class PlainBodyBuilder(cunit: CompilationUnit) extends PlainSkelBuilder(cunit) { - import icodes.TestOp import icodes.opcodes.InvokeStyle @@ -1287,38 +1286,42 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol) { val isStaticMethod = lambdaTarget.hasFlag(Flags.STATIC) + def asmType(sym: Symbol) = classBTypeFromSymbol(sym).toASMType - val targetHandle = + val implMethodHandle = new asm.Handle(if (lambdaTarget.hasFlag(Flags.STATIC)) asm.Opcodes.H_INVOKESTATIC else asm.Opcodes.H_INVOKEVIRTUAL, classBTypeFromSymbol(lambdaTarget.owner).internalName, lambdaTarget.name.toString, asmMethodType(lambdaTarget).descriptor) - val receiver = if (isStaticMethod) None else Some(lambdaTarget.owner) + val receiver = if (isStaticMethod) Nil else lambdaTarget.owner :: Nil val (capturedParams, lambdaParams) = lambdaTarget.paramss.head.splitAt(lambdaTarget.paramss.head.length - arity) // Requires https://github.com/scala/scala-java8-compat on the runtime classpath - val returnUnit = lambdaTarget.info.resultType.typeSymbol == UnitClass - val functionalInterfaceDesc: String = classBTypeFromSymbol(functionalInterface).descriptor - val desc = (receiver.toList ::: capturedParams).map(sym => toTypeKind(sym.info)).mkString(("("), "", ")") + functionalInterfaceDesc + val invokedType = asm.Type.getMethodDescriptor(asmType(functionalInterface), (receiver ::: capturedParams).map(sym => toTypeKind(sym.info).toASMType): _*) - // TODO specialization val constrainedType = new MethodBType(lambdaParams.map(p => toTypeKind(p.tpe)), toTypeKind(lambdaTarget.tpe.resultType)).toASMType - val abstractMethod = functionalInterface.info.decls.find(_.isDeferred).getOrElse(functionalInterface.info.member(nme.apply)) - val methodName = abstractMethod.name.toString - val applyN = { - val mt = asmMethodType(abstractMethod) - mt.toASMType - } - - bc.jmethod.visitInvokeDynamicInsn(methodName, desc, lambdaMetaFactoryBootstrapHandle, - // boostrap args - applyN, targetHandle, constrainedType + val sam = functionalInterface.info.decls.find(_.isDeferred).getOrElse(functionalInterface.info.member(nme.apply)) + val samName = sam.name.toString + val samMethodType = asmMethodType(sam).toASMType + + val flags = 3 // TODO 2.12.x Replace with LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS + + val ScalaSerializable = classBTypeFromSymbol(definitions.SerializableClass).toASMType + bc.jmethod.visitInvokeDynamicInsn(samName, invokedType, lambdaMetaFactoryBootstrapHandle, + /* samMethodType = */ samMethodType, + /* implMethod = */ implMethodHandle, + /* instantiatedMethodType = */ constrainedType, + /* flags = */ flags.asInstanceOf[AnyRef], + /* markerInterfaceCount = */ 1.asInstanceOf[AnyRef], + /* markerInterfaces[0] = */ ScalaSerializable, + /* bridgeCount = */ 0.asInstanceOf[AnyRef] ) + indyLambdaHosts += this.claszSymbol } } - val lambdaMetaFactoryBootstrapHandle = + lazy val lambdaMetaFactoryBootstrapHandle = new asm.Handle(asm.Opcodes.H_INVOKESTATIC, - "java/lang/invoke/LambdaMetafactory", "metafactory", - "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;") + definitions.LambdaMetaFactory.fullName('/'), sn.AltMetafactory.toString, + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;") } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala index 18468f5ae3..6aa3a62295 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala @@ -682,6 +682,59 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters { new java.lang.Long(id) ).visitEnd() } + + /** + * Add: + * private static java.util.Map $deserializeLambdaCache$ = null + * private static Object $deserializeLambda$(SerializedLambda l) { + * var cache = $deserializeLambdaCache$ + * if (cache eq null) { + * cache = new java.util.HashMap() + * $deserializeLambdaCache$ = cache + * } + * return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), cache, l); + * } + */ + def addLambdaDeserialize(clazz: Symbol, jclass: asm.ClassVisitor): Unit = { + val cw = jclass + import scala.tools.asm.Opcodes._ + + // Need to force creation of BTypes for these as `getCommonSuperClass` is called on + // automatically computing the max stack size (`visitMaxs`) during method writing. + javaUtilHashMapReference + javaUtilMapReference + + cw.visitInnerClass("java/lang/invoke/MethodHandles$Lookup", "java/lang/invoke/MethodHandles", "Lookup", ACC_PUBLIC + ACC_FINAL + ACC_STATIC) + + { + val fv = cw.visitField(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambdaCache$", "Ljava/util/Map;", null, null) + fv.visitEnd() + } + + { + val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", "(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", null, null) + mv.visitCode() + mv.visitFieldInsn(GETSTATIC, clazz.javaBinaryName.encoded, "$deserializeLambdaCache$", "Ljava/util/Map;") + mv.visitVarInsn(ASTORE, 1) + mv.visitVarInsn(ALOAD, 1) + val l0 = new asm.Label() + mv.visitJumpInsn(IFNONNULL, l0) + mv.visitTypeInsn(NEW, "java/util/HashMap") + mv.visitInsn(DUP) + mv.visitMethodInsn(INVOKESPECIAL, "java/util/HashMap", "<init>", "()V", false) + mv.visitVarInsn(ASTORE, 1) + mv.visitVarInsn(ALOAD, 1) + mv.visitFieldInsn(PUTSTATIC, clazz.javaBinaryName.encoded, "$deserializeLambdaCache$", "Ljava/util/Map;") + mv.visitLabel(l0) + mv.visitFrame(asm.Opcodes.F_APPEND,1, Array("java/util/Map"), 0, null) + mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "lookup", "()Ljava/lang/invoke/MethodHandles$Lookup;", false) + mv.visitVarInsn(ALOAD, 1) + mv.visitVarInsn(ALOAD, 0) + mv.visitMethodInsn(INVOKESTATIC, "scala/compat/java8/runtime/LambdaDeserializer", "deserializeLambda", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/util/Map;Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", false) + mv.visitInsn(ARETURN) + mv.visitEnd() + } + } } // end of trait BCClassGen /* functionality for building plain and mirror classes */ diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index 2a06c62e37..a2fd22d24c 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -68,6 +68,8 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { var isCZStaticModule = false var isCZRemote = false + protected val indyLambdaHosts = collection.mutable.Set[Symbol]() + /* ---------------- idiomatic way to ask questions to typer ---------------- */ def paramTKs(app: Apply): List[BType] = { @@ -121,6 +123,16 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { innerClassBufferASM ++= classBType.info.get.nestedClasses gen(cd.impl) + + + val shouldAddLambdaDeserialize = ( + settings.target.value == "jvm-1.8" + && settings.Ydelambdafy.value == "method" + && indyLambdaHosts.contains(claszSymbol)) + + if (shouldAddLambdaDeserialize) + addLambdaDeserialize(claszSymbol, cnode) + addInnerClassesASM(cnode, innerClassBufferASM.toList) cnode.visitAttribute(classBType.inlineInfoAttribute.get) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala index 492fe3ae79..00ca096e59 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/CoreBTypes.scala @@ -114,6 +114,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: BTFS) { lazy val jioSerializableReference : ClassBType = classBTypeFromSymbol(JavaSerializableClass) // java/io/Serializable lazy val scalaSerializableReference : ClassBType = classBTypeFromSymbol(SerializableClass) // scala/Serializable lazy val classCastExceptionReference : ClassBType = classBTypeFromSymbol(ClassCastExceptionClass) // java/lang/ClassCastException + lazy val javaUtilMapReference : ClassBType = classBTypeFromSymbol(JavaUtilMap) // java/util/Map + lazy val javaUtilHashMapReference : ClassBType = classBTypeFromSymbol(JavaUtilHashMap) // java/util/HashMap lazy val srBooleanRef : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.BooleanRef]) lazy val srByteRef : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.ByteRef]) @@ -258,6 +260,8 @@ final class CoreBTypesProxy[BTFS <: BTypesFromSymbols[_ <: Global]](val bTypes: def jioSerializableReference : ClassBType = _coreBTypes.jioSerializableReference def scalaSerializableReference : ClassBType = _coreBTypes.scalaSerializableReference def classCastExceptionReference : ClassBType = _coreBTypes.classCastExceptionReference + def javaUtilMapReference : ClassBType = _coreBTypes.javaUtilMapReference + def javaUtilHashMapReference : ClassBType = _coreBTypes.javaUtilHashMapReference def srBooleanRef : ClassBType = _coreBTypes.srBooleanRef def srByteRef : ClassBType = _coreBTypes.srByteRef diff --git a/src/compiler/scala/tools/nsc/transform/Delambdafy.scala b/src/compiler/scala/tools/nsc/transform/Delambdafy.scala index 17fad78972..55ab73028e 100644 --- a/src/compiler/scala/tools/nsc/transform/Delambdafy.scala +++ b/src/compiler/scala/tools/nsc/transform/Delambdafy.scala @@ -146,6 +146,9 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre val isStatic = target.hasFlag(STATIC) def createBoxingBridgeMethod(functionParamTypes: List[Type], functionResultType: Type): Tree = { + // Note: we bail out of this method and return EmptyTree if we find there is no adaptation required. + // If we need to improve performance, we could check the types first before creating the + // method and parameter symbols. val methSym = oldClass.newMethod(target.name.append("$adapted").toTermName, target.pos, target.flags | FINAL | ARTIFACT) var neededAdaptation = false def boxedType(tpe: Type): Type = { diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index 73ffb267a9..f3dd6a3280 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -369,6 +369,8 @@ trait Definitions extends api.StandardDefinitions { lazy val JavaEnumClass = requiredClass[java.lang.Enum[_]] lazy val RemoteInterfaceClass = requiredClass[java.rmi.Remote] lazy val RemoteExceptionClass = requiredClass[java.rmi.RemoteException] + lazy val JavaUtilMap = requiredClass[java.util.Map[_, _]] + lazy val JavaUtilHashMap = requiredClass[java.util.HashMap[_, _]] lazy val ByNameParamClass = specialPolyClass(tpnme.BYNAME_PARAM_CLASS_NAME, COVARIANT)(_ => AnyTpe) lazy val JavaRepeatedParamClass = specialPolyClass(tpnme.JAVA_REPEATED_PARAM_CLASS_NAME, COVARIANT)(tparam => arrayType(tparam.tpe)) @@ -514,6 +516,7 @@ trait Definitions extends api.StandardDefinitions { lazy val ScalaSignatureAnnotation = requiredClass[scala.reflect.ScalaSignature] lazy val ScalaLongSignatureAnnotation = requiredClass[scala.reflect.ScalaLongSignature] + lazy val LambdaMetaFactory = getClassIfDefined("java.lang.invoke.LambdaMetafactory") lazy val MethodHandle = getClassIfDefined("java.lang.invoke.MethodHandle") // Option classes diff --git a/src/reflect/scala/reflect/internal/StdNames.scala b/src/reflect/scala/reflect/internal/StdNames.scala index c0562b0679..63e2ca0dbe 100644 --- a/src/reflect/scala/reflect/internal/StdNames.scala +++ b/src/reflect/scala/reflect/internal/StdNames.scala @@ -1167,6 +1167,8 @@ trait StdNames { final val Invoke: TermName = newTermName("invoke") final val InvokeExact: TermName = newTermName("invokeExact") + final val AltMetafactory: TermName = newTermName("altMetafactory") + val Boxed = immutable.Map[TypeName, TypeName]( tpnme.Boolean -> BoxedBoolean, tpnme.Byte -> BoxedByte, diff --git a/src/reflect/scala/reflect/internal/transform/PostErasure.scala b/src/reflect/scala/reflect/internal/transform/PostErasure.scala index 466c6133b2..dd4f044818 100644 --- a/src/reflect/scala/reflect/internal/transform/PostErasure.scala +++ b/src/reflect/scala/reflect/internal/transform/PostErasure.scala @@ -9,8 +9,7 @@ trait PostErasure { object elimErasedValueType extends TypeMap { def apply(tp: Type) = tp match { case ConstantType(Constant(tp: Type)) => ConstantType(Constant(apply(tp))) - case ErasedValueType(_, underlying) => - underlying + case ErasedValueType(_, underlying) => underlying case _ => mapOver(tp) } } diff --git a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala index 1c0aa7cf6d..ea213cadd9 100644 --- a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala +++ b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala @@ -255,6 +255,8 @@ trait JavaUniverseForce { self: runtime.JavaUniverse => definitions.JavaEnumClass definitions.RemoteInterfaceClass definitions.RemoteExceptionClass + definitions.JavaUtilMap + definitions.JavaUtilHashMap definitions.ByNameParamClass definitions.JavaRepeatedParamClass definitions.RepeatedParamClass @@ -310,6 +312,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse => definitions.QuasiquoteClass_api_unapply definitions.ScalaSignatureAnnotation definitions.ScalaLongSignatureAnnotation + definitions.LambdaMetaFactory definitions.MethodHandle definitions.OptionClass definitions.OptionModule diff --git a/test/files/run/lambda-serialization-gc.scala b/test/files/run/lambda-serialization-gc.scala new file mode 100644 index 0000000000..8fa0b4b402 --- /dev/null +++ b/test/files/run/lambda-serialization-gc.scala @@ -0,0 +1,40 @@ +import java.io._ + +import java.net.URLClassLoader + +class C { + def serializeDeserialize[T <: AnyRef](obj: T) = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)) + in.readObject.asInstanceOf[T] + } + + serializeDeserialize((c: String) => c.length) +} + +object Test { + def main(args: Array[String]): Unit = { + test() + } + + def test(): Unit = { + val loader = getClass.getClassLoader.asInstanceOf[URLClassLoader] + val loaderCClass = classOf[C] + def deserializedInThrowawayClassloader = { + val throwawayLoader: java.net.URLClassLoader = new java.net.URLClassLoader(loader.getURLs, ClassLoader.getSystemClassLoader) { + val maxMemory = Runtime.getRuntime.maxMemory() + val junk = new Array[Byte]((maxMemory / 2).toInt) + } + val clazz = throwawayLoader.loadClass("C") + assert(clazz != loaderCClass) + clazz.newInstance() + } + (1 to 4) foreach { i => + // This would OOM by the third iteration if we leaked `throwawayLoader` during + // deserialization. + deserializedInThrowawayClassloader + } + } +} diff --git a/test/files/run/lambda-serialization.scala b/test/files/run/lambda-serialization.scala new file mode 100644 index 0000000000..46b26d7c5e --- /dev/null +++ b/test/files/run/lambda-serialization.scala @@ -0,0 +1,35 @@ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, ByteArrayOutputStream} + +object Test { + def main(args: Array[String]): Unit = { + roundTrip + } + + def roundTrip(): Unit = { + val c = new Capture("Capture") + val lambda = (p: Param) => ("a", p, c) + val reconstituted1 = serializeDeserialize(lambda).asInstanceOf[Object => Any] + val p = new Param + assert(reconstituted1.apply(p) == ("a", p, c)) + val reconstituted2 = serializeDeserialize(lambda).asInstanceOf[Object => Any] + assert(reconstituted1.getClass == reconstituted2.getClass) + + val reconstituted3 = serializeDeserialize(reconstituted1) + assert(reconstituted3.apply(p) == ("a", p, c)) + + val specializedLambda = (p: Int) => List(p, c).length + assert(serializeDeserialize(specializedLambda).apply(42) == 2) + assert(serializeDeserialize(serializeDeserialize(specializedLambda)).apply(42) == 2) + } + + def serializeDeserialize[T <: AnyRef](obj: T) = { + val buffer = new ByteArrayOutputStream + val out = new ObjectOutputStream(buffer) + out.writeObject(obj) + val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)) + in.readObject.asInstanceOf[T] + } +} + +case class Capture(s: String) extends Serializable +class Param |