aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala305
-rw-r--r--core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala13
-rw-r--r--core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala571
3 files changed, 831 insertions, 58 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index e3f52f6ff1..4ac0382d80 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -19,17 +19,20 @@ package org.apache.spark.util
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-import scala.collection.mutable.Map
-import scala.collection.mutable.Set
+import scala.collection.mutable.{Map, Set}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
import org.apache.spark.{Logging, SparkEnv, SparkException}
+/**
+ * A cleaner that renders closures serializable if they can be done so safely.
+ */
private[spark] object ClosureCleaner extends Logging {
+
// Get an ASM class reader for a given class from the JAR that loaded it
- private def getClassReader(cls: Class[_]): ClassReader = {
+ private[util] def getClassReader(cls: Class[_]): ClassReader = {
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
val resourceStream = cls.getResourceAsStream(className)
@@ -55,10 +58,14 @@ private[spark] object ClosureCleaner extends Logging {
private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
- if (isClosure(f.getType)) {
- return f.getType :: getOuterClasses(f.get(obj))
- } else {
- return f.getType :: Nil // Stop at the first $outer that is not a closure
+ val outer = f.get(obj)
+ // The outer pointer may be null if we have cleaned this closure before
+ if (outer != null) {
+ if (isClosure(f.getType)) {
+ return f.getType :: getOuterClasses(outer)
+ } else {
+ return f.getType :: Nil // Stop at the first $outer that is not a closure
+ }
}
}
Nil
@@ -68,16 +75,23 @@ private[spark] object ClosureCleaner extends Logging {
private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
- if (isClosure(f.getType)) {
- return f.get(obj) :: getOuterObjects(f.get(obj))
- } else {
- return f.get(obj) :: Nil // Stop at the first $outer that is not a closure
+ val outer = f.get(obj)
+ // The outer pointer may be null if we have cleaned this closure before
+ if (outer != null) {
+ if (isClosure(f.getType)) {
+ return outer :: getOuterObjects(outer)
+ } else {
+ return outer :: Nil // Stop at the first $outer that is not a closure
+ }
}
}
Nil
}
- private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
+ /**
+ * Return a list of classes that represent closures enclosed in the given closure object.
+ */
+ private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
val seen = Set[Class[_]](obj.getClass)
var stack = List[Class[_]](obj.getClass)
while (!stack.isEmpty) {
@@ -90,7 +104,7 @@ private[spark] object ClosureCleaner extends Logging {
stack = cls :: stack
}
}
- return (seen - obj.getClass).toList
+ (seen - obj.getClass).toList
}
private def createNullValue(cls: Class[_]): AnyRef = {
@@ -101,21 +115,124 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- def clean(func: AnyRef, checkSerializable: Boolean = true) {
+ /**
+ * Clean the given closure in place.
+ *
+ * More specifically, this renders the given closure serializable as long as it does not
+ * explicitly reference unserializable objects.
+ *
+ * @param closure the closure to clean
+ * @param checkSerializable whether to verify that the closure is serializable after cleaning
+ * @param cleanTransitively whether to clean enclosing closures transitively
+ */
+ def clean(
+ closure: AnyRef,
+ checkSerializable: Boolean = true,
+ cleanTransitively: Boolean = true): Unit = {
+ clean(closure, checkSerializable, cleanTransitively, Map.empty)
+ }
+
+ /**
+ * Helper method to clean the given closure in place.
+ *
+ * The mechanism is to traverse the hierarchy of enclosing closures and null out any
+ * references along the way that are not actually used by the starting closure, but are
+ * nevertheless included in the compiled anonymous classes. Note that it is unsafe to
+ * simply mutate the enclosing closures in place, as other code paths may depend on them.
+ * Instead, we clone each enclosing closure and set the parent pointers accordingly.
+ *
+ * By default, closures are cleaned transitively. This means we detect whether enclosing
+ * objects are actually referenced by the starting one, either directly or transitively,
+ * and, if not, sever these closures from the hierarchy. In other words, in addition to
+ * nulling out unused field references, we also null out any parent pointers that refer
+ * to enclosing objects not actually needed by the starting closure. We determine
+ * transitivity by tracing through the tree of all methods ultimately invoked by the
+ * inner closure and record all the fields referenced in the process.
+ *
+ * For instance, transitive cleaning is necessary in the following scenario:
+ *
+ * class SomethingNotSerializable {
+ * def someValue = 1
+ * def scope(name: String)(body: => Unit) = body
+ * def someMethod(): Unit = scope("one") {
+ * def x = someValue
+ * def y = 2
+ * scope("two") { println(y + 1) }
+ * }
+ * }
+ *
+ * In this example, scope "two" is not serializable because it references scope "one", which
+ * references SomethingNotSerializable. Note that, however, the body of scope "two" does not
+ * actually depend on SomethingNotSerializable. This means we can safely null out the parent
+ * pointer of a cloned scope "one" and set it the parent of scope "two", such that scope "two"
+ * no longer references SomethingNotSerializable transitively.
+ *
+ * @param func the starting closure to clean
+ * @param checkSerializable whether to verify that the closure is serializable after cleaning
+ * @param cleanTransitively whether to clean enclosing closures transitively
+ * @param accessedFields a map from a class to a set of its fields that are accessed by
+ * the starting closure
+ */
+ private def clean(
+ func: AnyRef,
+ checkSerializable: Boolean,
+ cleanTransitively: Boolean,
+ accessedFields: Map[Class[_], Set[String]]): Unit = {
+
+ // TODO: clean all inner closures first. This requires us to find the inner objects.
// TODO: cache outerClasses / innerClasses / accessedFields
+
+ if (func == null) {
+ return
+ }
+
+ logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++")
+
+ // A list of classes that represents closures enclosed in the given one
+ val innerClasses = getInnerClosureClasses(func)
+
+ // A list of enclosing objects and their respective classes, from innermost to outermost
+ // An outer object at a given index is of type outer class at the same index
val outerClasses = getOuterClasses(func)
- val innerClasses = getInnerClasses(func)
val outerObjects = getOuterObjects(func)
- val accessedFields = Map[Class[_], Set[String]]()
-
+ // For logging purposes only
+ val declaredFields = func.getClass.getDeclaredFields
+ val declaredMethods = func.getClass.getDeclaredMethods
+
+ logDebug(" + declared fields: " + declaredFields.size)
+ declaredFields.foreach { f => logDebug(" " + f) }
+ logDebug(" + declared methods: " + declaredMethods.size)
+ declaredMethods.foreach { m => logDebug(" " + m) }
+ logDebug(" + inner classes: " + innerClasses.size)
+ innerClasses.foreach { c => logDebug(" " + c.getName) }
+ logDebug(" + outer classes: " + outerClasses.size)
+ outerClasses.foreach { c => logDebug(" " + c.getName) }
+ logDebug(" + outer objects: " + outerObjects.size)
+ outerObjects.foreach { o => logDebug(" " + o) }
+
+ // Fail fast if we detect return statements in closures
getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
-
- for (cls <- outerClasses)
- accessedFields(cls) = Set[String]()
- for (cls <- func.getClass :: innerClasses)
- getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
- // logInfo("accessedFields: " + accessedFields)
+
+ // If accessed fields is not populated yet, we assume that
+ // the closure we are trying to clean is the starting one
+ if (accessedFields.isEmpty) {
+ logDebug(s" + populating accessed fields because this is the starting closure")
+ // Initialize accessed fields with the outer classes first
+ // This step is needed to associate the fields to the correct classes later
+ for (cls <- outerClasses) {
+ accessedFields(cls) = Set[String]()
+ }
+ // Populate accessed fields by visiting all fields and methods accessed by this and
+ // all of its inner closures. If transitive cleaning is enabled, this may recursively
+ // visits methods that belong to other classes in search of transitively referenced fields.
+ for (cls <- func.getClass :: innerClasses) {
+ getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0)
+ }
+ }
+
+ logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
+ accessedFields.foreach { f => logDebug(" " + f) }
val inInterpreter = {
try {
@@ -126,34 +243,68 @@ private[spark] object ClosureCleaner extends Logging {
}
}
+ // List of outer (class, object) pairs, ordered from outermost to innermost
+ // Note that all outer objects but the outermost one (first one in this list) must be closures
var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
- var outer: AnyRef = null
+ var parent: AnyRef = null
if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
// The closure is ultimately nested inside a class; keep the object of that
// class without cloning it since we don't want to clone the user's objects.
- outer = outerPairs.head._2
+ // Note that we still need to keep around the outermost object itself because
+ // we need it to clone its child closure later (see below).
+ logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}")
+ parent = outerPairs.head._2 // e.g. SparkContext
outerPairs = outerPairs.tail
+ } else if (outerPairs.size > 0) {
+ logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}")
+ } else {
+ logDebug(" + there are no enclosing objects!")
}
+
// Clone the closure objects themselves, nulling out any fields that are not
// used in the closure we're working on or any of its inner closures.
for ((cls, obj) <- outerPairs) {
- outer = instantiateClass(cls, outer, inInterpreter)
+ logDebug(s" + cloning the object $obj of class ${cls.getName}")
+ // We null out these unused references by cloning each object and then filling in all
+ // required fields from the original object. We need the parent here because the Java
+ // language specification requires the first constructor parameter of any closure to be
+ // its enclosing object.
+ val clone = instantiateClass(cls, parent, inInterpreter)
for (fieldName <- accessedFields(cls)) {
val field = cls.getDeclaredField(fieldName)
field.setAccessible(true)
val value = field.get(obj)
- // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
- field.set(outer, value)
+ field.set(clone, value)
+ }
+ // If transitive cleaning is enabled, we recursively clean any enclosing closure using
+ // the already populated accessed fields map of the starting closure
+ if (cleanTransitively && isClosure(clone.getClass)) {
+ logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})")
+ // No need to check serializable here for the outer closures because we're
+ // only interested in the serializability of the starting closure
+ clean(clone, checkSerializable = false, cleanTransitively, accessedFields)
}
+ parent = clone
}
- if (outer != null) {
- // logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
+ // Update the parent pointer ($outer) of this closure
+ if (parent != null) {
val field = func.getClass.getDeclaredField("$outer")
field.setAccessible(true)
- field.set(func, outer)
+ // If the starting closure doesn't actually need our enclosing object, then just null it out
+ if (accessedFields.contains(func.getClass) &&
+ !accessedFields(func.getClass).contains("$outer")) {
+ logDebug(s" + the starting closure doesn't actually need $parent, so we null it out")
+ field.set(func, null)
+ } else {
+ // Update this closure's parent pointer to point to our enclosing object,
+ // which could either be a cloned closure or the original user object
+ field.set(func, parent)
+ }
}
-
+
+ logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++")
+
if (checkSerializable) {
ensureSerializable(func)
}
@@ -167,15 +318,17 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
- // logInfo("Creating a " + cls + " with outer = " + outer)
+ private def instantiateClass(
+ cls: Class[_],
+ enclosingObject: AnyRef,
+ inInterpreter: Boolean): AnyRef = {
if (!inInterpreter) {
// This is a bona fide closure class, whose constructor has no effects
// other than to set its fields, so use its constructor
val cons = cls.getConstructors()(0)
val params = cons.getParameterTypes.map(createNullValue).toArray
- if (outer != null) {
- params(0) = outer // First param is always outer object
+ if (enclosingObject != null) {
+ params(0) = enclosingObject // First param is always enclosing object
}
return cons.newInstance(params: _*).asInstanceOf[AnyRef]
} else {
@@ -184,19 +337,17 @@ private[spark] object ClosureCleaner extends Logging {
val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
val obj = newCtor.newInstance().asInstanceOf[AnyRef]
- if (outer != null) {
- // logInfo("3: Setting $outer on " + cls + " to " + outer);
+ if (enclosingObject != null) {
val field = cls.getDeclaredField("$outer")
field.setAccessible(true)
- field.set(obj, outer)
+ field.set(obj, enclosingObject)
}
obj
}
}
}
-private[spark]
-class ReturnStatementFinder extends ClassVisitor(ASM4) {
+private class ReturnStatementFinder extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
if (name.contains("apply")) {
@@ -213,26 +364,65 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) {
}
}
-private[spark]
-class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
- override def visitMethod(access: Int, name: String, desc: String,
- sig: String, exceptions: Array[String]): MethodVisitor = {
+/** Helper class to identify a method. */
+private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String)
+
+/**
+ * Find the fields accessed by a given class.
+ *
+ * The resulting fields are stored in the mutable map passed in through the constructor.
+ * This map is assumed to have its keys already populated with the classes of interest.
+ *
+ * @param fields the mutable map that stores the fields to return
+ * @param findTransitively if true, find fields indirectly referenced through method calls
+ * @param specificMethod if not empty, visit only this specific method
+ * @param visitedMethods a set of visited methods to avoid cycles
+ */
+private[util] class FieldAccessFinder(
+ fields: Map[Class[_], Set[String]],
+ findTransitively: Boolean,
+ specificMethod: Option[MethodIdentifier[_]] = None,
+ visitedMethods: Set[MethodIdentifier[_]] = Set.empty)
+ extends ClassVisitor(ASM4) {
+
+ override def visitMethod(
+ access: Int,
+ name: String,
+ desc: String,
+ sig: String,
+ exceptions: Array[String]): MethodVisitor = {
+
+ // If we are told to visit only a certain method and this is not the one, ignore it
+ if (specificMethod.isDefined &&
+ (specificMethod.get.name != name || specificMethod.get.desc != desc)) {
+ return null
+ }
+
new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
- for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
- output(cl) += name
+ for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+ fields(cl) += name
}
}
}
- override def visitMethodInsn(op: Int, owner: String, name: String,
- desc: String) {
- // Check for calls a getter method for a variable in an interpreter wrapper object.
- // This means that the corresponding field will be accessed, so we should save it.
- if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
- for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
- output(cl) += name
+ override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+ for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+ // Check for calls a getter method for a variable in an interpreter wrapper object.
+ // This means that the corresponding field will be accessed, so we should save it.
+ if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
+ fields(cl) += name
+ }
+ // Optionally visit other methods to find fields that are transitively referenced
+ if (findTransitively) {
+ val m = MethodIdentifier(cl, name, desc)
+ if (!visitedMethods.contains(m)) {
+ // Keep track of visited methods to avoid potential infinite cycles
+ visitedMethods += m
+ ClosureCleaner.getClassReader(cl).accept(
+ new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+ }
}
}
}
@@ -240,9 +430,14 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor
}
}
-private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
+private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
var myName: String = null
+ // TODO: Recursively find inner closures that we indirectly reference, e.g.
+ // val closure1 = () = { () => 1 }
+ // val closure2 = () => { (1 to 5).map(closure1) }
+ // The second closure technically has two inner closures, but this finder only finds one
+
override def visit(version: Int, access: Int, name: String, sig: String,
superName: String, interfaces: Array[String]) {
myName = name
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index c47162779b..ff1bfe0774 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,7 +50,7 @@ class ClosureCleanerSuite extends FunSuite {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}
-
+
test("toplevel return statements in closures are identified at cleaning time") {
val ex = intercept[SparkException] {
TestObjectWithBogusReturns.run()
@@ -61,13 +61,20 @@ class ClosureCleanerSuite extends FunSuite {
test("return statements from named functions nested in closures don't raise exceptions") {
val result = TestObjectWithNestedReturns.run()
- assert(result == 1)
+ assert(result === 1)
}
}
// A non-serializable class we create in closures to make sure that we aren't
// keeping references to unneeded variables from our outer closures.
-class NonSerializable {}
+class NonSerializable(val id: Int = -1) {
+ override def equals(other: Any): Boolean = {
+ other match {
+ case o: NonSerializable => id == o.id
+ case _ => false
+ }
+ }
+}
object TestObject {
def run(): Int = {
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
new file mode 100644
index 0000000000..59456790e8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.NotSerializableException
+
+import scala.collection.mutable
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester}
+
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.serializer.SerializerInstance
+
+/**
+ * Another test suite for the closure cleaner that is finer-grained.
+ * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}.
+ */
+class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {
+
+ // Start a SparkContext so that the closure serializer is accessible
+ // We do not actually use this explicitly otherwise
+ private var sc: SparkContext = null
+ private var closureSerializer: SerializerInstance = null
+
+ override def beforeAll(): Unit = {
+ sc = new SparkContext("local", "test")
+ closureSerializer = sc.env.closureSerializer.newInstance()
+ }
+
+ override def afterAll(): Unit = {
+ sc.stop()
+ sc = null
+ closureSerializer = null
+ }
+
+ // Some fields and methods to reference in inner closures later
+ private val someSerializableValue = 1
+ private val someNonSerializableValue = new NonSerializable
+ private def someSerializableMethod() = 1
+ private def someNonSerializableMethod() = new NonSerializable
+
+ /** Assert that the given closure is serializable (or not). */
+ private def assertSerializable(closure: AnyRef, serializable: Boolean): Unit = {
+ if (serializable) {
+ closureSerializer.serialize(closure)
+ } else {
+ intercept[NotSerializableException] {
+ closureSerializer.serialize(closure)
+ }
+ }
+ }
+
+ /**
+ * Helper method for testing whether closure cleaning works as expected.
+ * This cleans the given closure twice, with and without transitive cleaning.
+ *
+ * @param closure closure to test cleaning with
+ * @param serializableBefore if true, verify that the closure is serializable
+ * before cleaning, otherwise assert that it is not
+ * @param serializableAfter if true, assert that the closure is serializable
+ * after cleaning otherwise assert that it is not
+ */
+ private def verifyCleaning(
+ closure: AnyRef,
+ serializableBefore: Boolean,
+ serializableAfter: Boolean): Unit = {
+ verifyCleaning(closure, serializableBefore, serializableAfter, transitive = true)
+ verifyCleaning(closure, serializableBefore, serializableAfter, transitive = false)
+ }
+
+ /** Helper method for testing whether closure cleaning works as expected. */
+ private def verifyCleaning(
+ closure: AnyRef,
+ serializableBefore: Boolean,
+ serializableAfter: Boolean,
+ transitive: Boolean): Unit = {
+ assertSerializable(closure, serializableBefore)
+ // If the resulting closure is not serializable even after
+ // cleaning, we expect ClosureCleaner to throw a SparkException
+ if (serializableAfter) {
+ ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+ } else {
+ intercept[SparkException] {
+ ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+ }
+ }
+ assertSerializable(closure, serializableAfter)
+ }
+
+ /**
+ * Return the fields accessed by the given closure by class.
+ * This also optionally finds the fields transitively referenced through methods invocations.
+ */
+ private def findAccessedFields(
+ closure: AnyRef,
+ outerClasses: Seq[Class[_]],
+ findTransitively: Boolean): Map[Class[_], Set[String]] = {
+ val fields = new mutable.HashMap[Class[_], mutable.Set[String]]
+ outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] }
+ ClosureCleaner.getClassReader(closure.getClass)
+ .accept(new FieldAccessFinder(fields, findTransitively), 0)
+ fields.mapValues(_.toSet).toMap
+ }
+
+ // Accessors for private methods
+ private val _isClosure = PrivateMethod[Boolean]('isClosure)
+ private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses)
+ private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses)
+ private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects)
+
+ private def isClosure(obj: AnyRef): Boolean = {
+ ClosureCleaner invokePrivate _isClosure(obj)
+ }
+
+ private def getInnerClosureClasses(closure: AnyRef): List[Class[_]] = {
+ ClosureCleaner invokePrivate _getInnerClosureClasses(closure)
+ }
+
+ private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
+ ClosureCleaner invokePrivate _getOuterClasses(closure)
+ }
+
+ private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
+ ClosureCleaner invokePrivate _getOuterObjects(closure)
+ }
+
+ test("get inner closure classes") {
+ val closure1 = () => 1
+ val closure2 = () => { () => 1 }
+ val closure3 = (i: Int) => {
+ (1 to i).map { x => x + 1 }.filter { x => x > 5 }
+ }
+ val closure4 = (j: Int) => {
+ (1 to j).flatMap { x =>
+ (1 to x).flatMap { y =>
+ (1 to y).map { z => z + 1 }
+ }
+ }
+ }
+ val inner1 = getInnerClosureClasses(closure1)
+ val inner2 = getInnerClosureClasses(closure2)
+ val inner3 = getInnerClosureClasses(closure3)
+ val inner4 = getInnerClosureClasses(closure4)
+ assert(inner1.isEmpty)
+ assert(inner2.size === 1)
+ assert(inner3.size === 2)
+ assert(inner4.size === 3)
+ assert(inner2.forall(isClosure))
+ assert(inner3.forall(isClosure))
+ assert(inner4.forall(isClosure))
+ }
+
+ test("get outer classes and objects") {
+ val localValue = someSerializableValue
+ val closure1 = () => 1
+ val closure2 = () => localValue
+ val closure3 = () => someSerializableValue
+ val closure4 = () => someSerializableMethod()
+ val outerClasses1 = getOuterClasses(closure1)
+ val outerClasses2 = getOuterClasses(closure2)
+ val outerClasses3 = getOuterClasses(closure3)
+ val outerClasses4 = getOuterClasses(closure4)
+ val outerObjects1 = getOuterObjects(closure1)
+ val outerObjects2 = getOuterObjects(closure2)
+ val outerObjects3 = getOuterObjects(closure3)
+ val outerObjects4 = getOuterObjects(closure4)
+
+ // The classes and objects should have the same size
+ assert(outerClasses1.size === outerObjects1.size)
+ assert(outerClasses2.size === outerObjects2.size)
+ assert(outerClasses3.size === outerObjects3.size)
+ assert(outerClasses4.size === outerObjects4.size)
+
+ // These do not have $outer pointers because they reference only local variables
+ assert(outerClasses1.isEmpty)
+ assert(outerClasses2.isEmpty)
+
+ // These closures do have $outer pointers because they ultimately reference `this`
+ // The first $outer pointer refers to the closure defines this test (see FunSuite#test)
+ // The second $outer pointer refers to ClosureCleanerSuite2
+ assert(outerClasses3.size === 2)
+ assert(outerClasses4.size === 2)
+ assert(isClosure(outerClasses3(0)))
+ assert(isClosure(outerClasses4(0)))
+ assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope
+ assert(outerClasses3(1) === this.getClass)
+ assert(outerClasses4(1) === this.getClass)
+ assert(outerObjects3(1) === this)
+ assert(outerObjects4(1) === this)
+ }
+
+ test("get outer classes and objects with nesting") {
+ val localValue = someSerializableValue
+
+ val test1 = () => {
+ val x = 1
+ val closure1 = () => 1
+ val closure2 = () => x
+ val outerClasses1 = getOuterClasses(closure1)
+ val outerClasses2 = getOuterClasses(closure2)
+ val outerObjects1 = getOuterObjects(closure1)
+ val outerObjects2 = getOuterObjects(closure2)
+ assert(outerClasses1.size === outerObjects1.size)
+ assert(outerClasses2.size === outerObjects2.size)
+ // These inner closures only reference local variables, and so do not have $outer pointers
+ assert(outerClasses1.isEmpty)
+ assert(outerClasses2.isEmpty)
+ }
+
+ val test2 = () => {
+ def y = 1
+ val closure1 = () => 1
+ val closure2 = () => y
+ val closure3 = () => localValue
+ val outerClasses1 = getOuterClasses(closure1)
+ val outerClasses2 = getOuterClasses(closure2)
+ val outerClasses3 = getOuterClasses(closure3)
+ val outerObjects1 = getOuterObjects(closure1)
+ val outerObjects2 = getOuterObjects(closure2)
+ val outerObjects3 = getOuterObjects(closure3)
+ assert(outerClasses1.size === outerObjects1.size)
+ assert(outerClasses2.size === outerObjects2.size)
+ assert(outerClasses3.size === outerObjects3.size)
+ // Same as above, this closure only references local variables
+ assert(outerClasses1.isEmpty)
+ // This closure references the "test2" scope because it needs to find the method `y`
+ // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
+ assert(outerClasses2.size === 3)
+ // This closure references the "test2" scope because it needs to find the `localValue`
+ // defined outside of this scope
+ assert(outerClasses3.size === 3)
+ assert(isClosure(outerClasses2(0)))
+ assert(isClosure(outerClasses3(0)))
+ assert(isClosure(outerClasses2(1)))
+ assert(isClosure(outerClasses3(1)))
+ assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope
+ assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope
+ assert(outerClasses2(2) === this.getClass)
+ assert(outerClasses3(2) === this.getClass)
+ assert(outerObjects2(2) === this)
+ assert(outerObjects3(2) === this)
+ }
+
+ test1()
+ test2()
+ }
+
+ test("find accessed fields") {
+ val localValue = someSerializableValue
+ val closure1 = () => 1
+ val closure2 = () => localValue
+ val closure3 = () => someSerializableValue
+ val outerClasses1 = getOuterClasses(closure1)
+ val outerClasses2 = getOuterClasses(closure2)
+ val outerClasses3 = getOuterClasses(closure3)
+
+ val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
+ val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
+ val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false)
+ assert(fields1.isEmpty)
+ assert(fields2.isEmpty)
+ assert(fields3.size === 2)
+ // This corresponds to the "FunSuite#test" closure. This is empty because the
+ // `someSerializableValue` belongs to its parent (i.e. ClosureCleanerSuite2).
+ assert(fields3(outerClasses3(0)).isEmpty)
+ // This corresponds to the ClosureCleanerSuite2. This is also empty, however,
+ // because accessing a `ClosureCleanerSuite2#someSerializableValue` actually involves a
+ // method call. Since we do not find fields transitively, we will not recursively trace
+ // through the fields referenced by this method.
+ assert(fields3(outerClasses3(1)).isEmpty)
+
+ val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true)
+ val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true)
+ val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true)
+ assert(fields1t.isEmpty)
+ assert(fields2t.isEmpty)
+ assert(fields3t.size === 2)
+ // Because we find fields transitively now, we are able to detect that we need the
+ // $outer pointer to get the field from the ClosureCleanerSuite2
+ assert(fields3t(outerClasses3(0)).size === 1)
+ assert(fields3t(outerClasses3(0)).head === "$outer")
+ assert(fields3t(outerClasses3(1)).size === 1)
+ assert(fields3t(outerClasses3(1)).head.contains("someSerializableValue"))
+ }
+
+ test("find accessed fields with nesting") {
+ val localValue = someSerializableValue
+
+ val test1 = () => {
+ def a = localValue + 1
+ val closure1 = () => 1
+ val closure2 = () => a
+ val closure3 = () => localValue
+ val closure4 = () => someSerializableValue
+ val outerClasses1 = getOuterClasses(closure1)
+ val outerClasses2 = getOuterClasses(closure2)
+ val outerClasses3 = getOuterClasses(closure3)
+ val outerClasses4 = getOuterClasses(closure4)
+
+ // First, find only fields accessed directly, not transitively, by these closures
+ val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
+ val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
+ val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false)
+ val fields4 = findAccessedFields(closure4, outerClasses4, findTransitively = false)
+ assert(fields1.isEmpty)
+ // Note that the size here represents the number of outer classes, not the number of fields
+ // "test1" < parameter of "FunSuite#test" < ClosureCleanerSuite2
+ assert(fields2.size === 3)
+ // Since we do not find fields transitively here, we do not look into what `def a` references
+ assert(fields2(outerClasses2(0)).isEmpty) // This corresponds to the "test1" scope
+ assert(fields2(outerClasses2(1)).isEmpty) // This corresponds to the "FunSuite#test" scope
+ assert(fields2(outerClasses2(2)).isEmpty) // This corresponds to the ClosureCleanerSuite2
+ assert(fields3.size === 3)
+ // Note that `localValue` is a field of the "test1" scope because `def a` references it,
+ // but NOT a field of the "FunSuite#test" scope because it is only a local variable there
+ assert(fields3(outerClasses3(0)).size === 1)
+ assert(fields3(outerClasses3(0)).head.contains("localValue"))
+ assert(fields3(outerClasses3(1)).isEmpty)
+ assert(fields3(outerClasses3(2)).isEmpty)
+ assert(fields4.size === 3)
+ // Because `val someSerializableValue` is an instance variable, even an explicit reference
+ // here actually involves a method call to access the underlying value of the variable.
+ // Because we are not finding fields transitively here, we do not consider the fields
+ // accessed by this "method" (i.e. the val's accessor).
+ assert(fields4(outerClasses4(0)).isEmpty)
+ assert(fields4(outerClasses4(1)).isEmpty)
+ assert(fields4(outerClasses4(2)).isEmpty)
+
+ // Now do the same, but find fields that the closures transitively reference
+ val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true)
+ val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true)
+ val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true)
+ val fields4t = findAccessedFields(closure4, outerClasses4, findTransitively = true)
+ assert(fields1t.isEmpty)
+ assert(fields2t.size === 3)
+ assert(fields2t(outerClasses2(0)).size === 1) // `def a` references `localValue`
+ assert(fields2t(outerClasses2(0)).head.contains("localValue"))
+ assert(fields2t(outerClasses2(1)).isEmpty)
+ assert(fields2t(outerClasses2(2)).isEmpty)
+ assert(fields3t.size === 3)
+ assert(fields3t(outerClasses3(0)).size === 1) // as before
+ assert(fields3t(outerClasses3(0)).head.contains("localValue"))
+ assert(fields3t(outerClasses3(1)).isEmpty)
+ assert(fields3t(outerClasses3(2)).isEmpty)
+ assert(fields4t.size === 3)
+ // Through a series of method calls, we are able to detect that we ultimately access
+ // ClosureCleanerSuite2's field `someSerializableValue`. Along the way, we also accessed
+ // a few $outer parent pointers to get to the outermost object.
+ assert(fields4t(outerClasses4(0)) === Set("$outer"))
+ assert(fields4t(outerClasses4(1)) === Set("$outer"))
+ assert(fields4t(outerClasses4(2)).size === 1)
+ assert(fields4t(outerClasses4(2)).head.contains("someSerializableValue"))
+ }
+
+ test1()
+ }
+
+ test("clean basic serializable closures") {
+ val localValue = someSerializableValue
+ val closure1 = () => 1
+ val closure2 = () => Array[String]("a", "b", "c")
+ val closure3 = (s: String, arr: Array[Long]) => s + arr.mkString(", ")
+ val closure4 = () => localValue
+ val closure5 = () => new NonSerializable(5) // we're just serializing the class information
+ val closure1r = closure1()
+ val closure2r = closure2()
+ val closure3r = closure3("g", Array(1, 5, 8))
+ val closure4r = closure4()
+ val closure5r = closure5()
+
+ verifyCleaning(closure1, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure2, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure3, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure4, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure5, serializableBefore = true, serializableAfter = true)
+
+ // Verify that closures can still be invoked and the result still the same
+ assert(closure1() === closure1r)
+ assert(closure2() === closure2r)
+ assert(closure3("g", Array(1, 5, 8)) === closure3r)
+ assert(closure4() === closure4r)
+ assert(closure5() === closure5r)
+ }
+
+ test("clean basic non-serializable closures") {
+ val closure1 = () => this // ClosureCleanerSuite2 is not serializable
+ val closure5 = () => someSerializableValue
+ val closure3 = () => someSerializableMethod()
+ val closure4 = () => someNonSerializableValue
+ val closure2 = () => someNonSerializableMethod()
+
+ // These are not cleanable because they ultimately reference the ClosureCleanerSuite2
+ verifyCleaning(closure1, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure2, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure3, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure4, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure5, serializableBefore = false, serializableAfter = false)
+ }
+
+ test("clean basic nested serializable closures") {
+ val localValue = someSerializableValue
+ val closure1 = (i: Int) => {
+ (1 to i).map { x => x + localValue } // 1 level of nesting
+ }
+ val closure2 = (j: Int) => {
+ (1 to j).flatMap { x =>
+ (1 to x).map { y => y + localValue } // 2 levels
+ }
+ }
+ val closure3 = (k: Int, l: Int, m: Int) => {
+ (1 to k).flatMap(closure2) ++ // 4 levels
+ (1 to l).flatMap(closure1) ++ // 3 levels
+ (1 to m).map { x => x + 1 } // 2 levels
+ }
+ val closure1r = closure1(1)
+ val closure2r = closure2(2)
+ val closure3r = closure3(3, 4, 5)
+
+ verifyCleaning(closure1, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure2, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure3, serializableBefore = true, serializableAfter = true)
+
+ // Verify that closures can still be invoked and the result still the same
+ assert(closure1(1) === closure1r)
+ assert(closure2(2) === closure2r)
+ assert(closure3(3, 4, 5) === closure3r)
+ }
+
+ test("clean basic nested non-serializable closures") {
+ def localSerializableMethod(): Int = someSerializableValue
+ val localNonSerializableValue = someNonSerializableValue
+ // These closures ultimately reference the ClosureCleanerSuite2
+ // Note that even accessing `val` that is an instance variable involves a method call
+ val closure1 = (i: Int) => { (1 to i).map { x => x + someSerializableValue } }
+ val closure2 = (j: Int) => { (1 to j).map { x => x + someSerializableMethod() } }
+ val closure4 = (k: Int) => { (1 to k).map { x => x + localSerializableMethod() } }
+ // This closure references a local non-serializable value
+ val closure3 = (l: Int) => { (1 to l).map { x => localNonSerializableValue } }
+ // This is non-serializable no matter how many levels we nest it
+ val closure5 = (m: Int) => {
+ (1 to m).foreach { x =>
+ (1 to x).foreach { y =>
+ (1 to y).foreach { z =>
+ someSerializableValue
+ }
+ }
+ }
+ }
+
+ verifyCleaning(closure1, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure2, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure3, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure4, serializableBefore = false, serializableAfter = false)
+ verifyCleaning(closure5, serializableBefore = false, serializableAfter = false)
+ }
+
+ test("clean complicated nested serializable closures") {
+ val localValue = someSerializableValue
+
+ // Here we assume that if the outer closure is serializable,
+ // then all inner closures must also be serializable
+
+ // Reference local fields from all levels
+ val closure1 = (i: Int) => {
+ val a = 1
+ (1 to i).flatMap { x =>
+ val b = a + 1
+ (1 to x).map { y =>
+ y + a + b + localValue
+ }
+ }
+ }
+
+ // Reference local fields and methods from all levels within the outermost closure
+ val closure2 = (i: Int) => {
+ val a1 = 1
+ def a2 = 2
+ (1 to i).flatMap { x =>
+ val b1 = a1 + 1
+ def b2 = a2 + 1
+ (1 to x).map { y =>
+ // If this references a method outside the outermost closure, then it will try to pull
+ // in the ClosureCleanerSuite2. This is why `localValue` here must be a local `val`.
+ y + a1 + a2 + b1 + b2 + localValue
+ }
+ }
+ }
+
+ val closure1r = closure1(1)
+ val closure2r = closure2(2)
+ verifyCleaning(closure1, serializableBefore = true, serializableAfter = true)
+ verifyCleaning(closure2, serializableBefore = true, serializableAfter = true)
+ assert(closure1(1) == closure1r)
+ assert(closure2(2) == closure2r)
+ }
+
+ test("clean complicated nested non-serializable closures") {
+ val localValue = someSerializableValue
+
+ // Note that we are not interested in cleaning the outer closures here (they are not cleanable)
+ // The only reason why they exist is to nest the inner closures
+
+ val test1 = () => {
+ val a = localValue
+ val b = sc
+ val inner1 = (x: Int) => x + a + b.hashCode()
+ val inner2 = (x: Int) => x + a
+
+ // This closure explicitly references a non-serializable field
+ // There is no way to clean it
+ verifyCleaning(inner1, serializableBefore = false, serializableAfter = false)
+
+ // This closure is serializable to begin with since it does not need a pointer to
+ // the outer closure (it only references local variables)
+ verifyCleaning(inner2, serializableBefore = true, serializableAfter = true)
+ }
+
+ // Same as above, but the `val a` becomes `def a`
+ // The difference here is that all inner closures now have pointers to the outer closure
+ val test2 = () => {
+ def a = localValue
+ val b = sc
+ val inner1 = (x: Int) => x + a + b.hashCode()
+ val inner2 = (x: Int) => x + a
+
+ // As before, this closure is neither serializable nor cleanable
+ verifyCleaning(inner1, serializableBefore = false, serializableAfter = false)
+
+ // This closure is no longer serializable because it now has a pointer to the outer closure,
+ // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2.
+ // If we do not clean transitively, we will not null out this indirect reference.
+ verifyCleaning(
+ inner2, serializableBefore = false, serializableAfter = false, transitive = false)
+
+ // If we clean transitively, we will find that method `a` does not actually reference the
+ // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out
+ // the outer closure's parent pointer. This will make `inner2` serializable.
+ verifyCleaning(
+ inner2, serializableBefore = false, serializableAfter = true, transitive = true)
+ }
+
+ // Same as above, but with more levels of nesting
+ val test3 = () => { () => test1() }
+ val test4 = () => { () => test2() }
+ val test5 = () => { () => { () => test3() } }
+ val test6 = () => { () => { () => test4() } }
+
+ test1()
+ test2()
+ test3()()
+ test4()()
+ test5()()()
+ test6()()()
+ }
+
+}