aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala32
-rw-r--r--core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala54
2 files changed, 28 insertions, 58 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 43626b4ef4..ebead830c6 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging {
cls.getName.contains("$anonfun$")
}
- // Get a list of the classes of the outer objects of a given closure object, obj;
+ // Get a list of the outer objects and their classes of a given closure object, obj;
// the outer objects are defined as any closures that obj is nested within, plus
// possibly the class that the outermost closure is in, if any. We stop searching
// for outer objects beyond that because cloning the user's object is probably
// not a good idea (whereas we can clone closure objects just fine since we
// understand how all their fields are used).
- private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+ private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
val outer = f.get(obj)
// The outer pointer may be null if we have cleaned this closure before
if (outer != null) {
if (isClosure(f.getType)) {
- return f.getType :: getOuterClasses(outer)
+ val recurRet = getOuterClassesAndObjects(outer)
+ return (f.getType :: recurRet._1, outer :: recurRet._2)
} else {
- return f.getType :: Nil // Stop at the first $outer that is not a closure
+ return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure
}
}
}
- Nil
+ (Nil, Nil)
}
-
- // Get a list of the outer objects for a given closure object.
- private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
- for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
- f.setAccessible(true)
- val outer = f.get(obj)
- // The outer pointer may be null if we have cleaned this closure before
- if (outer != null) {
- if (isClosure(f.getType)) {
- return outer :: getOuterObjects(outer)
- } else {
- return outer :: Nil // Stop at the first $outer that is not a closure
- }
- }
- }
- Nil
- }
-
/**
* Return a list of classes that represent closures enclosed in the given closure object.
*/
@@ -205,8 +188,7 @@ private[spark] object ClosureCleaner extends Logging {
// A list of enclosing objects and their respective classes, from innermost to outermost
// An outer object at a given index is of type outer class at the same index
- val outerClasses = getOuterClasses(func)
- val outerObjects = getOuterObjects(func)
+ val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
// For logging purposes only
val declaredFields = func.getClass.getDeclaredFields
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
index 3147c93776..a829b09902 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
// Accessors for private methods
private val _isClosure = PrivateMethod[Boolean]('isClosure)
private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses)
- private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses)
- private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects)
+ private val _getOuterClassesAndObjects =
+ PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects)
private def isClosure(obj: AnyRef): Boolean = {
ClosureCleaner invokePrivate _isClosure(obj)
@@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
ClosureCleaner invokePrivate _getInnerClosureClasses(closure)
}
- private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
- ClosureCleaner invokePrivate _getOuterClasses(closure)
- }
-
- private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
- ClosureCleaner invokePrivate _getOuterObjects(closure)
+ private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = {
+ ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure)
}
test("get inner closure classes") {
@@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure2 = () => localValue
val closure3 = () => someSerializableValue
val closure4 = () => someSerializableMethod()
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerClasses4 = getOuterClasses(closure4)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
- val outerObjects3 = getOuterObjects(closure3)
- val outerObjects4 = getOuterObjects(closure4)
+
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3)
+ val (outerClasses4, outerObjects4) = getOuterClassesAndObjects(closure4)
// The classes and objects should have the same size
assert(outerClasses1.size === outerObjects1.size)
@@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val x = 1
val closure1 = () => 1
val closure2 = () => x
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
// These inner closures only reference local variables, and so do not have $outer pointers
@@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure1 = () => 1
val closure2 = () => y
val closure3 = () => localValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
- val outerObjects3 = getOuterObjects(closure3)
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
assert(outerClasses3.size === outerObjects3.size)
@@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure1 = () => 1
val closure2 = () => localValue
val closure3 = () => someSerializableValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
+ val (outerClasses1, _) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, _) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, _) = getOuterClassesAndObjects(closure3)
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
@@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure2 = () => a
val closure3 = () => localValue
val closure4 = () => someSerializableValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerClasses4 = getOuterClasses(closure4)
+ val (outerClasses1, _) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, _) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, _) = getOuterClassesAndObjects(closure3)
+ val (outerClasses4, _) = getOuterClassesAndObjects(closure4)
// First, find only fields accessed directly, not transitively, by these closures
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)