-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                                   |  16
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                                        |   6
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala                            |  21
-rw-r--r--  core/src/test/scala/org/apache/spark/FailureSuite.scala                                   |  17
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala  |  94
-rw-r--r--  core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala                       |  68
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala                            |   2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala                 |   8
8 files changed, 14 insertions(+), 218 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 545807ffbc..76305237b0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1002,9 +1002,7 @@ class SparkContext(config: SparkConf) extends Logging {
require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
}
val callSite = getCallSite
- // There's no need to check this function for serializability,
- // since it will be run right away.
- val cleanedFunc = clean(func, false)
+ val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1137,18 +1135,14 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}
-
+
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
- *
- * @param f closure to be cleaned and optionally serialized
- * @param captureNow whether or not to serialize this closure and capture any free
- * variables immediately; defaults to true. If this is set and f is not serializable,
- * it will raise an exception.
*/
- private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
- ClosureCleaner.clean(f, captureNow)
+ private[spark] def clean[F <: AnyRef](f: F): F = {
+ ClosureCleaner.clean(f)
+ f
}
/**
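A minimal sketch of how the simplified clean is meant to be called (illustrative only, not part of this patch; clean is private[spark], so the hypothetical helper is written as if it lived inside org.apache.spark). Cleaning now only rewrites the closure's outer references and returns the same object; any serialization problem surfaces later, from runJob, when tasks are shipped:

    package org.apache.spark

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Hypothetical helper, for illustration: clean a closure and submit a job with it.
    object CleanAndRunSketch {
      def submit[T, U: ClassTag](sc: SparkContext, rdd: RDD[T])(func: Iterator[T] => U): Array[U] = {
        val cleanedFunc = sc.clean(func)  // nulls out unused $outer references; nothing is serialized here
        sc.runJob(rdd, cleanedFunc)       // a non-serializable closure now fails here, at job time
      }
    }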
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e363ea777d..3437b2cac1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -660,16 +660,14 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
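Since foreach and foreachPartition no longer clean (or eagerly serialize) their closures, a closure capturing something non-serializable is only rejected when the action runs, as the SparkException exercised in FailureSuite below. A user-level sketch of that behavior (illustrative; the class and app names are made up):

    import org.apache.spark.{SparkConf, SparkContext, SparkException}

    object ForeachFailureSketch {
      class NonSerializableThing  // deliberately not Serializable

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("sketch"))
        val a  = new NonSerializableThing

        try {
          // Building the closure is fine; nothing is serialized at this point.
          sc.parallelize(1 to 10, 2).foreach(x => println(a))
        } catch {
          // The failure only appears once tasks are serialized for the job.
          case e: SparkException => println("foreach failed: " + e.getMessage)
        }
        sc.stop()
      }
    }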
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index e474b1a850..cdbbc65292 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,14 +22,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set
-import scala.reflect.ClassTag
-
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
-import org.apache.spark.SparkException
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -105,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}
- def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
+ def clean(func: AnyRef) {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -154,21 +150,6 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}
-
- if (captureNow) {
- cloneViaSerializing(func)
- } else {
- func
- }
- }
-
- private def cloneViaSerializing[T: ClassTag](func: T): T = {
- try {
- val serializer = SparkEnv.get.closureSerializer.newInstance()
- serializer.deserialize[T](serializer.serialize[T](func))
- } catch {
- case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
- }
}
private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
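For reference, the removed cloneViaSerializing was a serialize/deserialize round trip that both checked serializability and snapshotted the closure's free variables. A stand-alone approximation using plain Java serialization (an assumption for illustration; the deleted code went through SparkEnv's closure serializer instead):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    // Rough stand-in for the deleted cloneViaSerializing: round-trip an object through
    // Java serialization, producing a detached copy with free variables captured as-is.
    object SerializationCloneSketch {
      def cloneBySerializing[T](value: T): T = {
        val bytes = new ByteArrayOutputStream()
        val out   = new ObjectOutputStream(bytes)
        out.writeObject(value)  // throws NotSerializableException if the closure cannot be shipped
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
        in.readObject().asInstanceOf[T]
      }
    }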
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 4f9300419e..12dbebcb28 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
- test("failure because closure in final-stage task is not serializable") {
+ test("failure because task closure is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable
@@ -118,13 +118,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))
- FailureSuiteState.clear()
- }
-
- test("failure because closure in early-stage task is not serializable") {
- sc = new SparkContext("local[1,1]", "test")
- val a = new NonSerializable
-
// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -132,13 +125,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))
- FailureSuiteState.clear()
- }
-
- test("failure because closure in foreach task is not serializable") {
- sc = new SparkContext("local[1,1]", "test")
- val a = new NonSerializable
-
// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -149,6 +135,5 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
-
// TODO: Need to add tests with shuffle fetch failures.
}
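The consolidated test above leans on ScalaTest's intercept to assert both the exception type and its message. A minimal, self-contained illustration of that pattern (names are hypothetical):

    import org.scalatest.FunSuite

    // intercept runs the block, fails the test unless the given exception type is
    // thrown, and returns the exception so further assertions can inspect it.
    class InterceptPatternSketch extends FunSuite {
      test("intercept returns the thrown exception") {
        val thrown = intercept[IllegalArgumentException] {
          require(false, "boom")
        }
        assert(thrown.getMessage.contains("boom"))
      }
    }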
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
deleted file mode 100644
index 76662264e7..0000000000
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.serializer;
-
-import java.io.NotSerializableException
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkException
-import org.apache.spark.SharedSparkContext
-
-/* A trivial (but unserializable) container for trivial functions */
-class UnserializableClass {
- def op[T](x: T) = x.toString
-
- def pred[T](x: T) = x.toString.length % 2 == 0
-}
-
-class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
-
- def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
-
- test("throws expected serialization exceptions on actions") {
- val (data, uc) = fixture
-
- val ex = intercept[SparkException] {
- data.map(uc.op(_)).count
- }
-
- assert(ex.getMessage.matches(".*Task not serializable.*"))
- }
-
- // There is probably a cleaner way to eliminate boilerplate here, but we're
- // iterating over a map from transformation names to functions that perform that
- // transformation on a given RDD, creating one test case for each
-
- for (transformation <-
- Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
- "mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
- "mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
- val (name, xf) = transformation
-
- test(s"$name transformations throw proactive serialization exceptions") {
- val (data, uc) = fixture
-
- val ex = intercept[SparkException] {
- xf(data, uc)
- }
-
- assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
- }
- }
-
- def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.map(y => uc.op(y))
-
- def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapWith(x => x.toString)((x,y) => x + uc.op(y))
-
- def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.flatMap(y=>Seq(uc.op(y)))
-
- def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.filter(y=>uc.pred(y))
-
- def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.filterWith(x => x.toString)((x,y) => uc.pred(y))
-
- def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitions(_.map(y => uc.op(y)))
-
- def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
-
- def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
-
-}
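With the proactive check removed, a caller who still wants to fail fast at definition time can perform an equivalent check by hand. A minimal user-side helper under that assumption (not a Spark API; the name is hypothetical):

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    // Verify up front that a closure can be Java-serialized, instead of waiting
    // for the first action on the RDD to fail.
    object EagerSerializabilityCheck {
      def ensureSerializable(closure: AnyRef): Unit = {
        val out = new ObjectOutputStream(new ByteArrayOutputStream())
        try out.writeObject(closure)
        catch {
          case e: NotSerializableException =>
            throw new IllegalArgumentException("closure is not serializable: " + e, e)
        }
        finally out.close()
      }
    }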
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index c635da6cac..439e5644e2 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,27 +50,6 @@ class ClosureCleanerSuite extends FunSuite {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}
-
- test("capturing free variables in closures at RDD definition") {
- val obj = new TestCaptureVarClass()
- val (ones, onesPlusZeroes) = obj.run()
-
- assert(ones === onesPlusZeroes)
- }
-
- test("capturing free variable fields in closures at RDD definition") {
- val obj = new TestCaptureFieldClass()
- val (ones, onesPlusZeroes) = obj.run()
-
- assert(ones === onesPlusZeroes)
- }
-
- test("capturing arrays in closures at RDD definition") {
- val obj = new TestCaptureArrayEltClass()
- val (observed, expected) = obj.run()
-
- assert(observed === expected)
- }
}
// A non-serializable class we create in closures to make sure that we aren't
@@ -164,50 +143,3 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}
-
-class TestCaptureFieldClass extends Serializable {
- class ZeroBox extends Serializable {
- var zero = 0
- }
-
- def run(): (Int, Int) = {
- val zb = new ZeroBox
-
- withSpark(new SparkContext("local", "test")) {sc =>
- val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
- val onesPlusZeroes = ones.map(_ + zb.zero)
-
- zb.zero = 5
-
- (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
- }
- }
-}
-
-class TestCaptureArrayEltClass extends Serializable {
- def run(): (Int, Int) = {
- withSpark(new SparkContext("local", "test")) {sc =>
- val rdd = sc.parallelize(1 to 10)
- val data = Array(1, 2, 3)
- val expected = data(0)
- val mapped = rdd.map(x => data(0))
- data(0) = 4
- (mapped.first, expected)
- }
- }
-}
-
-class TestCaptureVarClass extends Serializable {
- def run(): (Int, Int) = {
- var zero = 0
-
- withSpark(new SparkContext("local", "test")) {sc =>
- val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
- val onesPlusZeroes = ones.map(_ + zero)
-
- zero = 5
-
- (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
- }
- }
-}
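The deleted tests pinned down the capture-at-definition semantics this revert removes. A user-level sketch of the difference (illustrative local-mode app, mirroring the deleted TestCaptureVarClass): after the revert, mutating a captured variable between defining an RDD and running a job is visible to the job again, because free variables are read when tasks execute rather than copied when the RDD is defined.

    import org.apache.spark.{SparkConf, SparkContext}

    object CaptureSemanticsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
        var zero = 0

        val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
        val onesPlusZeroes = ones.map(_ + zero)  // the closure refers to the var; no copy is taken here

        zero = 5  // mutation after the RDD is defined...

        // ...is observed by the job: this prints 5 and 30. The deleted tests asserted
        // the two sums were equal, which only held with eager capture at definition.
        println(ones.reduce(_ + _))
        println(onesPlusZeroes.reduce(_ + _))
        sc.stop()
      }
    }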
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index c65e36636f..28d34dd9a1 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert( graph.edges.count() === rawEdges.size )
// Vertices not explicitly provided but referenced by edges should be created automatically
assert( graph.vertices.count() === 100)
- graph.triplets.collect.map { et =>
+ graph.triplets.map { et =>
assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 4759b629a9..d043200f71 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,7 +539,7 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
- transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
+ transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
}
/**
@@ -547,7 +547,7 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
- val cleanedF = context.sparkContext.clean(transformFunc, false)
+ val cleanedF = context.sparkContext.clean(transformFunc)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -562,7 +562,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
- val cleanedF = ssc.sparkContext.clean(transformFunc, false)
+ val cleanedF = ssc.sparkContext.clean(transformFunc)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}
@@ -573,7 +573,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
- val cleanedF = ssc.sparkContext.clean(transformFunc, false)
+ val cleanedF = ssc.sparkContext.clean(transformFunc)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]