diff options
author | Josh Rosen <joshrosen@databricks.com> | 2016-04-14 18:09:23 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2016-04-14 18:10:05 -0700 |
commit | adfe8ff8b24497aee1375054f1630750d01b30a2 (patch) | |
tree | 96d3d9ec33970158d311b627913f1bd2e3636b21 | |
parent | 219200917a7febdda3290d74a6eb758df3d9d9e6 (diff) | |
download | spark-adfe8ff8b24497aee1375054f1630750d01b30a2.tar.gz spark-adfe8ff8b24497aee1375054f1630750d01b30a2.tar.bz2 spark-adfe8ff8b24497aee1375054f1630750d01b30a2.zip |
Small changes to get ClosureCleanerSuite to pass.
3 files changed, 156 insertions, 32 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 489688cb08..fae1bd2a76 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -156,7 +156,8 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + // TODO: pass the other options as well + LambdaClosureCleaner.clean(func) return } @@ -289,7 +290,7 @@ private[spark] object ClosureCleaner extends Logging { } } - private def ensureSerializable(func: AnyRef) { + private[util] def ensureSerializable(func: AnyRef) { try { if (SparkEnv.get != null) { SparkEnv.get.closureSerializer.newInstance().serialize(func) diff --git a/core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala new file mode 100644 index 0000000000..96f3fc81fc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/LambdaClosureCleaner.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.reflect.Method + +import org.apache.xbean.asm5.{ClassVisitor, MethodVisitor} +import org.apache.xbean.asm5.Opcodes._ + +import org.apache.spark.internal.Logging + + +private[spark] object LambdaClosureCleaner extends Logging { + + private[util] def clean(closure: AnyRef): Unit = { + val writeReplaceMethod: Method = try { + closure.getClass.getDeclaredMethod("writeReplace") + } catch { + case e: java.lang.NoSuchMethodException => + logWarning("Expected a Java lambda; got " + closure.getClass.getName) + return + } + + writeReplaceMethod.setAccessible(true) + // Because we still need to support Java 7, we must use reflection here. + val serializedLambda: AnyRef = writeReplaceMethod.invoke(closure) + if (serializedLambda.getClass.getName != "java.lang.invoke.SerializedLambda") { + logWarning("Closure's writeReplace() method " + + s"returned ${serializedLambda.getClass.getName}, not SerializedLambda") + return + } + + val serializedLambdaClass = Utils.classForName("java.lang.invoke.SerializedLambda") + + val implClassName = serializedLambdaClass + .getDeclaredMethod("getImplClass").invoke(serializedLambda).asInstanceOf[String] + // TODO: we do not want to unconditionally strip this suffix. + val implMethodName = { + serializedLambdaClass + .getDeclaredMethod("getImplMethodName").invoke(serializedLambda).asInstanceOf[String] + .stripSuffix("$adapted") + } + val implMethodSignature = serializedLambdaClass + .getDeclaredMethod("getImplMethodSignature").invoke(serializedLambda).asInstanceOf[String] + val capturedArgCount = serializedLambdaClass + .getDeclaredMethod("getCapturedArgCount").invoke(serializedLambda).asInstanceOf[Int] + val capturedArgs = (0 until capturedArgCount).map { argNum: Int => + serializedLambdaClass + .getDeclaredMethod("getCapturedArg", java.lang.Integer.TYPE) + .invoke(serializedLambda, argNum.asInstanceOf[Object]) + }.toSeq + assert(capturedArgs.size == capturedArgCount) + val implClass = Utils.classForName(implClassName.replaceAllLiterally("/", ".")) + + // Fail fast if we detect return statements in closures. + // TODO: match the impl method based on its type signature as well, not just its name. + ClosureCleaner + .getClassReader(implClass) + .accept(new LambdaReturnStatementFinder(implMethodName), 0) + + // Check serializable TODO: add flag + ClosureCleaner.ensureSerializable(closure) + capturedArgs.foreach(ClosureCleaner.clean(_)) + + // TODO: null fields to render the closure serializable? + } +} + + +private class LambdaReturnStatementFinder(targetMethodName: String) extends ClassVisitor(ASM5) { + override def visitMethod( + access: Int, + name: String, + desc: String, + sig: String, + exceptions: Array[String]): MethodVisitor = { + if (name == targetMethodName) { + new MethodVisitor(ASM5) { + override def visitTypeInsn(op: Int, tp: String) { + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { + throw new ReturnStatementInClosureException + } + } + } + } else { + new MethodVisitor(ASM5) {} + } + } +}
\ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 934385fbca..d7682ab8cc 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -141,6 +141,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("get inner closure classes") { + // Skip on Scala 2.12 for now, since we may not need to determine accessed fields anymore. + assume(!scala.util.Properties.versionString.startsWith("version 2.12.")) + val closure1 = () => 1 val closure2 = () => { () => 1 } val closure3 = (i: Int) => { @@ -188,18 +191,20 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri assert(outerClasses1.isEmpty) assert(outerClasses2.isEmpty) - // These closures do have $outer pointers because they ultimately reference `this` - // The first $outer pointer refers to the closure defines this test (see FunSuite#test) - // The second $outer pointer refers to ClosureCleanerSuite2 - assert(outerClasses3.size === 2) - assert(outerClasses4.size === 2) - assert(isClosure(outerClasses3(0))) - assert(isClosure(outerClasses4(0))) - assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope - assert(outerClasses3(1) === this.getClass) - assert(outerClasses4(1) === this.getClass) - assert(outerObjects3(1) === this) - assert(outerObjects4(1) === this) + if (!scala.util.Properties.versionString.startsWith("version 2.12.")) { + // These closures do have $outer pointers because they ultimately reference `this` + // The first $outer pointer refers to the closure defines this test (see FunSuite#test) + // The second $outer pointer refers to ClosureCleanerSuite2 + assert(outerClasses3.size === 2) + assert(outerClasses4.size === 2) + assert(isClosure(outerClasses3(0))) + assert(isClosure(outerClasses4(0))) + assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope + assert(outerClasses3(1) === this.getClass) + assert(outerClasses4(1) === this.getClass) + assert(outerObjects3(1) === this) + assert(outerObjects4(1) === this) + } } test("get outer classes and objects with nesting") { @@ -231,22 +236,24 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri assert(outerClasses3.size === outerObjects3.size) // Same as above, this closure only references local variables assert(outerClasses1.isEmpty) - // This closure references the "test2" scope because it needs to find the method `y` - // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2 - assert(outerClasses2.size === 3) - // This closure references the "test2" scope because it needs to find the `localValue` - // defined outside of this scope - assert(outerClasses3.size === 3) - assert(isClosure(outerClasses2(0))) - assert(isClosure(outerClasses3(0))) - assert(isClosure(outerClasses2(1))) - assert(isClosure(outerClasses3(1))) - assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope - assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope - assert(outerClasses2(2) === this.getClass) - assert(outerClasses3(2) === this.getClass) - assert(outerObjects2(2) === this) - assert(outerObjects3(2) === this) + if (!scala.util.Properties.versionString.startsWith("version 2.12.")) { + // This closure references the "test2" scope because it needs to find the method `y` + // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2 + assert(outerClasses2.size === 3) + // This closure references the "test2" scope because it needs to find the `localValue` + // defined outside of this scope + assert(outerClasses3.size === 3) + assert(isClosure(outerClasses2(0))) + assert(isClosure(outerClasses3(0))) + assert(isClosure(outerClasses2(1))) + assert(isClosure(outerClasses3(1))) + assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope + assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope + assert(outerClasses2(2) === this.getClass) + assert(outerClasses3(2) === this.getClass) + assert(outerObjects2(2) === this) + assert(outerObjects3(2) === this) + } } test1() @@ -254,6 +261,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields") { + // Skip on Scala 2.12 for now, since we may not need to determine accessed fields anymore. + assume(!scala.util.Properties.versionString.startsWith("version 2.12.")) + val localValue = someSerializableValue val closure1 = () => 1 val closure2 = () => localValue @@ -292,6 +302,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri } test("find accessed fields with nesting") { + // Skip on Scala 2.12 for now, since we may not need to determine accessed fields anymore. + assume(!scala.util.Properties.versionString.startsWith("version 2.12.")) + val localValue = someSerializableValue val test1 = () => { @@ -538,13 +551,19 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2. // If we do not clean transitively, we will not null out this indirect reference. verifyCleaning( - inner2, serializableBefore = false, serializableAfter = false, transitive = false) + inner2, + serializableBefore = scala.util.Properties.versionString.startsWith("version 2.12."), + serializableAfter = scala.util.Properties.versionString.startsWith("version 2.12."), + transitive = false) // If we clean transitively, we will find that method `a` does not actually reference the // outer closure's parent (i.e. the ClosureCleanerSuite), so we can additionally null out // the outer closure's parent pointer. This will make `inner2` serializable. verifyCleaning( - inner2, serializableBefore = false, serializableAfter = true, transitive = true) + inner2, + serializableBefore = scala.util.Properties.versionString.startsWith("version 2.12."), + serializableAfter = true, + transitive = true) } // Same as above, but with more levels of nesting |