author    Josh Rosen <joshrosen@databricks.com>  2016-04-14 18:09:23 -0700
committer Josh Rosen <joshrosen@databricks.com>  2016-04-14 18:10:05 -0700
commit    adfe8ff8b24497aee1375054f1630750d01b30a2 (patch)
tree      96d3d9ec33970158d311b627913f1bd2e3636b21
parent    219200917a7febdda3290d74a6eb758df3d9d9e6 (diff)
Small changes to get ClosureCleanerSuite to pass.
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala       |   5
-rw-r--r--  core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala | 104
-rw-r--r--  core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala |  79
3 files changed, 156 insertions(+), 32 deletions(-)
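
Context for the change (not part of the patch itself): on Scala 2.12, closures are compiled to Java 8 lambdas via LambdaMetafactory rather than to anonymous $anonfun classes, so the class-name check in ClosureCleaner.isClosure no longer recognizes them. A serializable lambda does, however, carry a compiler-generated writeReplace() method that returns a java.lang.invoke.SerializedLambda, and that is what the new LambdaClosureCleaner keys on. A minimal detection sketch, assuming Scala 2.12 on a Java 8 JVM (names here are illustrative, not from the patch):

    object LambdaDetectionSketch {
      // True iff the closure is a Java 8-style serializable lambda.
      def isJavaLambda(closure: AnyRef): Boolean = {
        try {
          // Serializable lambdas get a synthetic writeReplace() whose
          // result is a java.lang.invoke.SerializedLambda.
          val m = closure.getClass.getDeclaredMethod("writeReplace")
          m.setAccessible(true)
          m.invoke(closure).getClass.getName == "java.lang.invoke.SerializedLambda"
        } catch {
          case _: NoSuchMethodException => false
        }
      }

      def main(args: Array[String]): Unit = {
        val f = (x: Int) => x + 1
        println(f.getClass.getName) // on 2.12: "...$$Lambda$..." rather than "...$anonfun$..."
        println(isJavaLambda(f))    // true on 2.12; false on 2.11, where f is an anonymous class
      }
    }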
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 489688cb08..fae1bd2a76 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -156,7 +156,8 @@ private[spark] object ClosureCleaner extends Logging {
accessedFields: Map[Class[_], Set[String]]): Unit = {
if (!isClosure(func.getClass)) {
- logWarning("Expected a closure; got " + func.getClass.getName)
+ // TODO: pass the other options as well
+ LambdaClosureCleaner.clean(func)
return
}
@@ -289,7 +290,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- private def ensureSerializable(func: AnyRef) {
+ private[util] def ensureSerializable(func: AnyRef) {
try {
if (SparkEnv.get != null) {
SparkEnv.get.closureSerializer.newInstance().serialize(func)
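
Callers are unaffected by the fallback above: closures still reach ClosureCleaner.clean through SparkContext.clean before being serialized and shipped to executors, so Java 8 lambdas now flow into LambdaClosureCleaner.clean automatically. Roughly, paraphrasing the existing SparkContext code (signature simplified for illustration):

    // Paraphrased sketch of the existing call path in SparkContext: every
    // closure passed to an RDD operation is cleaned here before it is
    // serialized, which is why the new non-closure fallback matters.
    def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
      ClosureCleaner.clean(f, checkSerializable)
      f
    }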
diff --git a/core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala
new file mode 100644
index 0000000000..96f3fc81fc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.lang.reflect.Method
+
+import org.apache.xbean.asm5.{ClassVisitor, MethodVisitor}
+import org.apache.xbean.asm5.Opcodes._
+
+import org.apache.spark.internal.Logging
+
+
+private[spark] object LambdaClosureCleaner extends Logging {
+
+ private[util] def clean(closure: AnyRef): Unit = {
+ val writeReplaceMethod: Method = try {
+ closure.getClass.getDeclaredMethod("writeReplace")
+ } catch {
+ case e: java.lang.NoSuchMethodException =>
+ logWarning("Expected a Java lambda; got " + closure.getClass.getName)
+ return
+ }
+
+ writeReplaceMethod.setAccessible(true)
+ // Because we still need to support Java 7, we must use reflection here.
+ val serializedLambda: AnyRef = writeReplaceMethod.invoke(closure)
+ if (serializedLambda.getClass.getName != "java.lang.invoke.SerializedLambda") {
+ logWarning("Closure's writeReplace() method " +
+ s"returned ${serializedLambda.getClass.getName}, not SerializedLambda")
+ return
+ }
+
+ val serializedLambdaClass = Utils.classForName("java.lang.invoke.SerializedLambda")
+
+ val implClassName = serializedLambdaClass
+ .getDeclaredMethod("getImplClass").invoke(serializedLambda).asInstanceOf[String]
+ // TODO: we do not want to unconditionally strip this suffix.
+ val implMethodName = {
+ serializedLambdaClass
+ .getDeclaredMethod("getImplMethodName").invoke(serializedLambda).asInstanceOf[String]
+ .stripSuffix("$adapted")
+ }
+ val implMethodSignature = serializedLambdaClass
+ .getDeclaredMethod("getImplMethodSignature").invoke(serializedLambda).asInstanceOf[String]
+ val capturedArgCount = serializedLambdaClass
+ .getDeclaredMethod("getCapturedArgCount").invoke(serializedLambda).asInstanceOf[Int]
+ val capturedArgs = (0 until capturedArgCount).map { argNum: Int =>
+ serializedLambdaClass
+ .getDeclaredMethod("getCapturedArg", java.lang.Integer.TYPE)
+ .invoke(serializedLambda, argNum.asInstanceOf[Object])
+ }.toSeq
+ assert(capturedArgs.size == capturedArgCount)
+ val implClass = Utils.classForName(implClassName.replaceAllLiterally("/", "."))
+
+ // Fail fast if we detect return statements in closures.
+ // TODO: match the impl method based on its type signature as well, not just its name.
+ ClosureCleaner
+ .getClassReader(implClass)
+ .accept(new LambdaReturnStatementFinder(implMethodName), 0)
+
+ // Check that the closure is serializable. TODO: add a flag to make this check optional.
+ ClosureCleaner.ensureSerializable(closure)
+ capturedArgs.foreach(ClosureCleaner.clean(_))
+
+ // TODO: null fields to render the closure serializable?
+ }
+}
+
+
+private class LambdaReturnStatementFinder(targetMethodName: String) extends ClassVisitor(ASM5) {
+ override def visitMethod(
+ access: Int,
+ name: String,
+ desc: String,
+ sig: String,
+ exceptions: Array[String]): MethodVisitor = {
+ if (name == targetMethodName) {
+ new MethodVisitor(ASM5) {
+ override def visitTypeInsn(op: Int, tp: String) {
+ if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) {
+ throw new ReturnStatementInClosureException
+ }
+ }
+ }
+ } else {
+ new MethodVisitor(ASM5) {}
+ }
+ }
+}
\ No newline at end of file
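
For reference, the NEW scala/runtime/NonLocalReturnControl instruction that LambdaReturnStatementFinder scans for is how scalac compiles a `return` inside a closure: the return belongs to the enclosing method, so it is implemented as a thrown control-flow exception. A cleaned closure runs on an executor where the enclosing stack frame does not exist, hence the fail-fast ReturnStatementInClosureException. A hypothetical closure that would trip the check:

    // Hypothetical example, not from the patch. The `return` belongs to
    // firstEven, so inside the foreach lambda it compiles to
    // `throw new scala.runtime.NonLocalReturnControl(...)` -- the exact
    // NEW instruction the visitor above rejects.
    def firstEven(xs: Seq[Int]): Option[Int] = {
      xs.foreach { x =>
        if (x % 2 == 0) return Some(x)
      }
      None
    }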
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
index 934385fbca..d7682ab8cc 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -141,6 +141,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
}
test("get inner closure classes") {
+ // Skip on Scala 2.12 for now, since we may not need to determine accessed fields anymore.
+ assume(!scala.util.Properties.versionString.startsWith("version 2.12."))
+
val closure1 = () => 1
val closure2 = () => { () => 1 }
val closure3 = (i: Int) => {
@@ -188,18 +191,20 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
assert(outerClasses1.isEmpty)
assert(outerClasses2.isEmpty)
- // These closures do have $outer pointers because they ultimately reference `this`
- // The first $outer pointer refers to the closure that defines this test (see FunSuite#test)
- // The second $outer pointer refers to ClosureCleanerSuite2
- assert(outerClasses3.size === 2)
- assert(outerClasses4.size === 2)
- assert(isClosure(outerClasses3(0)))
- assert(isClosure(outerClasses4(0)))
- assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope
- assert(outerClasses3(1) === this.getClass)
- assert(outerClasses4(1) === this.getClass)
- assert(outerObjects3(1) === this)
- assert(outerObjects4(1) === this)
+ if (!scala.util.Properties.versionString.startsWith("version 2.12.")) {
+ // These closures do have $outer pointers because they ultimately reference `this`
+ // The first $outer pointer refers to the closure that defines this test (see FunSuite#test)
+ // The second $outer pointer refers to ClosureCleanerSuite2
+ assert(outerClasses3.size === 2)
+ assert(outerClasses4.size === 2)
+ assert(isClosure(outerClasses3(0)))
+ assert(isClosure(outerClasses4(0)))
+ assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope
+ assert(outerClasses3(1) === this.getClass)
+ assert(outerClasses4(1) === this.getClass)
+ assert(outerObjects3(1) === this)
+ assert(outerObjects4(1) === this)
+ }
}
test("get outer classes and objects with nesting") {
@@ -231,22 +236,24 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
assert(outerClasses3.size === outerObjects3.size)
// Same as above, this closure only references local variables
assert(outerClasses1.isEmpty)
- // This closure references the "test2" scope because it needs to find the method `y`
- // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
- assert(outerClasses2.size === 3)
- // This closure references the "test2" scope because it needs to find the `localValue`
- // defined outside of this scope
- assert(outerClasses3.size === 3)
- assert(isClosure(outerClasses2(0)))
- assert(isClosure(outerClasses3(0)))
- assert(isClosure(outerClasses2(1)))
- assert(isClosure(outerClasses3(1)))
- assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope
- assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope
- assert(outerClasses2(2) === this.getClass)
- assert(outerClasses3(2) === this.getClass)
- assert(outerObjects2(2) === this)
- assert(outerObjects3(2) === this)
+ if (!scala.util.Properties.versionString.startsWith("version 2.12.")) {
+ // This closure references the "test2" scope because it needs to find the method `y`
+ // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
+ assert(outerClasses2.size === 3)
+ // This closure references the "test2" scope because it needs to find the `localValue`
+ // defined outside of this scope
+ assert(outerClasses3.size === 3)
+ assert(isClosure(outerClasses2(0)))
+ assert(isClosure(outerClasses3(0)))
+ assert(isClosure(outerClasses2(1)))
+ assert(isClosure(outerClasses3(1)))
+ assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope
+ assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope
+ assert(outerClasses2(2) === this.getClass)
+ assert(outerClasses3(2) === this.getClass)
+ assert(outerObjects2(2) === this)
+ assert(outerObjects3(2) === this)
+ }
}
test1()
@@ -254,6 +261,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
}
test("find accessed fields") {
+ // Skip on Scala 2.12 for now, since we may not need to determine accessed fields anymore.
+ assume(!scala.util.Properties.versionString.startsWith("version 2.12."))
+
val localValue = someSerializableValue
val closure1 = () => 1
val closure2 = () => localValue
@@ -292,6 +302,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
}
test("find accessed fields with nesting") {
+ // Skip on Scala 2.12 for now, since we may not need to determine accessed fields anymore.
+ assume(!scala.util.Properties.versionString.startsWith("version 2.12."))
+
val localValue = someSerializableValue
val test1 = () => {
@@ -538,13 +551,19 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
// which is itself not serializable because it has a pointer to the ClosureCleanerSuite2.
// If we do not clean transitively, we will not null out this indirect reference.
verifyCleaning(
- inner2, serializableBefore = false, serializableAfter = false, transitive = false)
+ inner2,
+ serializableBefore = scala.util.Properties.versionString.startsWith("version 2.12."),
+ serializableAfter = scala.util.Properties.versionString.startsWith("version 2.12."),
+ transitive = false)
// If we clean transitively, we will find that method `a` does not actually reference the
// outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out
// the outer closure's parent pointer. This will make `inner2` serializable.
verifyCleaning(
- inner2, serializableBefore = false, serializableAfter = true, transitive = true)
+ inner2,
+ serializableBefore = scala.util.Properties.versionString.startsWith("version 2.12."),
+ serializableAfter = true,
+ transitive = true)
}
// Same as above, but with more levels of nesting
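
The test changes repeat scala.util.Properties.versionString.startsWith("version 2.12.") at every gate; a possible follow-up (sketched here for illustration, not part of this commit) would hoist the check into a single helper:

    // Illustrative only, not part of this commit. Properties.versionString
    // returns strings like "version 2.12.1", so a prefix match identifies
    // the 2.12 series.
    object ScalaVersionSketch {
      val isScala212: Boolean =
        scala.util.Properties.versionString.startsWith("version 2.12.")
    }

    // The suite's gates would then read, e.g.:
    //   assume(!ScalaVersionSketch.isScala212)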