diff options
author | Reynold Xin <rxin@databricks.com> | 2015-12-30 18:07:07 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-12-30 18:07:07 -0800 |
commit | ee8f8d318417c514fbb26e57157483d466ddbfae (patch) | |
tree | 7da3d291a1014f63789679f8a22c726ece3634de /core/src | |
parent | f76ee109d87e727710d2721e4be47fdabc21582c (diff) | |
download | spark-ee8f8d318417c514fbb26e57157483d466ddbfae.tar.gz spark-ee8f8d318417c514fbb26e57157483d466ddbfae.tar.bz2 spark-ee8f8d318417c514fbb26e57157483d466ddbfae.zip |
[SPARK-12588] Remove HttpBroadcast in Spark 2.0.
We switched to TorrentBroadcast in Spark 1.1, and HttpBroadcast has been undocumented since then. It's time to remove it in Spark 2.0.
Author: Reynold Xin <rxin@databricks.com>
Closes #10531 from rxin/SPARK-12588.
Diffstat (limited to 'core/src')
8 files changed, 13 insertions, 457 deletions
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 6a187b4062..7f35ac4747 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -24,14 +24,12 @@ import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi /** - * :: DeveloperApi :: * An interface for all the broadcast implementations in Spark (to allow * multiple broadcast implementations). SparkContext uses a user-specified * BroadcastFactory implementation to instantiate a particular broadcast for the * entire Spark job. */ -@DeveloperApi -trait BroadcastFactory { +private[spark] trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index fac6666bb3..61343607a1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SecurityManager} + private[spark] class BroadcastManager( val isDriver: Boolean, @@ -39,15 +39,8 @@ private[spark] class BroadcastManager( private def initialize() { synchronized { if (!initialized) { - val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - - broadcastFactory = - Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory = new TorrentBroadcastFactory broadcastFactory.initialize(isDriver, conf, securityManager) - initialized = true } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala deleted file mode 100644 index b69af639f7..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} -import java.io.{BufferedInputStream, BufferedOutputStream} -import java.net.{URL, URLConnection, URI} -import java.util.concurrent.TimeUnit - -import scala.reflect.ClassTag - -import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} -import org.apache.spark.io.CompressionCodec -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} - -/** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server - * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a - * task) is deserialized in the executor, the broadcasted data is fetched from the driver - * (through a HTTP server running at the driver) and stored in the BlockManager of the - * executor to speed up future accesses. - */ -private[spark] class HttpBroadcast[T: ClassTag]( - @transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) with Logging with Serializable { - - override protected def getValue() = value_ - - private val blockId = BroadcastBlockId(id) - - /* - * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster - * does not need to be told about this block as not only need to know about this data block. - */ - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - } - - if (!isLocal) { - HttpBroadcast.write(id, value_) - } - - /** - * Remove all persisted state associated with this HTTP broadcast on the executors. - */ - override protected def doUnpersist(blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) - } - - /** - * Remove all persisted state associated with this HTTP broadcast on the executors and driver. - */ - override protected def doDestroy(blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) - } - - /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { - assertValid() - out.defaultWriteObject() - } - - /** Used by the JVM when deserializing this object. */ - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => value_ = x.asInstanceOf[T] - case None => { - logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime - value_ = HttpBroadcast.read[T](id) - /* - * We cache broadcast data in the BlockManager so that subsequent tasks using it - * do not need to re-fetch. This data is only used locally and no other node - * needs to fetch this block, so we don't notify the master. - */ - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - } -} - -private[broadcast] object HttpBroadcast extends Logging { - private var initialized = false - private var broadcastDir: File = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - private var serverUri: String = null - private var server: HttpServer = null - private var securityManager: SecurityManager = null - - // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - private val files = new TimeStampedHashSet[File] - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private var compressionCodec: CompressionCodec = null - private var cleaner: MetadataCleaner = null - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - synchronized { - if (!initialized) { - bufferSize = conf.getInt("spark.buffer.size", 65536) - compress = conf.getBoolean("spark.broadcast.compress", true) - securityManager = securityMgr - if (isDriver) { - createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) - } - serverUri = conf.get("spark.httpBroadcast.uri") - cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) - compressionCodec = CompressionCodec.createCodec(conf) - initialized = true - } - } - } - - def stop() { - synchronized { - if (server != null) { - server.stop() - server = null - } - if (cleaner != null) { - cleaner.cancel() - cleaner = null - } - compressionCodec = null - initialized = false - } - } - - private def createServer(conf: SparkConf) { - broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast") - val broadcastPort = conf.getInt("spark.broadcast.port", 0) - server = - new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") - server.start() - serverUri = server.uri - logInfo("Broadcast server started at " + serverUri) - } - - def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) - - private def write(id: Long, value: Any) { - val file = getFile(id) - val fileOutputStream = new FileOutputStream(file) - Utils.tryWithSafeFinally { - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(fileOutputStream) - } else { - new BufferedOutputStream(fileOutputStream, bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - Utils.tryWithSafeFinally { - serOut.writeObject(value) - } { - serOut.close() - } - files += file - } { - fileOutputStream.close() - } - } - - private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) - val url = serverUri + "/" + BroadcastBlockId(id).name - - var uc: URLConnection = null - if (securityManager.isAuthenticationEnabled()) { - logDebug("broadcast security enabled") - val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL.openConnection() - uc.setConnectTimeout(httpReadTimeout) - uc.setAllowUserInteraction(false) - } else { - logDebug("broadcast not using security") - uc = new URL(url).openConnection() - uc.setConnectTimeout(httpReadTimeout) - } - Utils.setupSecureURLConnection(uc, securityManager) - - val in = { - uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream - if (compress) { - compressionCodec.compressedInputStream(inputStream) - } else { - new BufferedInputStream(inputStream, bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serIn = ser.deserializeStream(in) - Utils.tryWithSafeFinally { - serIn.readObject[T]() - } { - serIn.close() - } - } - - /** - * Remove all persisted blocks associated with this HTTP broadcast on the executors. - * If removeFromDriver is true, also remove these persisted blocks on the driver - * and delete the associated broadcast file. - */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) - if (removeFromDriver) { - val file = getFile(id) - files.remove(file) - deleteBroadcastFile(file) - } - } - - /** - * Periodically clean up old broadcasts by removing the associated map entries and - * deleting the associated files. - */ - private def cleanup(cleanupTime: Long) { - val iterator = files.internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val (file, time) = (entry.getKey, entry.getValue) - if (time < cleanupTime) { - iterator.remove() - deleteBroadcastFile(file) - } - } - } - - private def deleteBroadcastFile(file: File) { - try { - if (file.exists) { - if (file.delete()) { - logInfo("Deleted broadcast file: %s".format(file)) - } else { - logWarning("Could not delete broadcast file: %s".format(file)) - } - } - } catch { - case e: Exception => - logError("Exception while deleting broadcast file: %s".format(file), e) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala deleted file mode 100644 index cf3ae36f27..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import scala.reflect.ClassTag - -import org.apache.spark.{SecurityManager, SparkConf} - -/** - * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a - * HTTP server as the broadcast mechanism. Refer to - * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. - */ -class HttpBroadcastFactory extends BroadcastFactory { - override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = - new HttpBroadcast[T](value_, isLocal, id) - - override def stop() { HttpBroadcast.stop() } - - /** - * Remove all persisted state associated with the HTTP broadcast with the given ID. - * @param removeFromDriver Whether to remove state from the driver - * @param blocking Whether to block until unbroadcasted - */ - override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver, blocking) - } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7e3764d802..9bd69727f6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream * BlockManager, ready for other executors to fetch from. * * This prevents the driver from being the bottleneck in sending out multiple copies of the - * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. + * broadcast data (one per executor). * * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. * diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 96d8dd7990..b11f9ba171 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SecurityManager, SparkConf} * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. */ -class TorrentBroadcastFactory extends BroadcastFactory { +private[spark] class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index eed9937b30..1b4538e6af 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -34,7 +34,6 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ @@ -107,7 +106,6 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index ba21075ce6..88fdbbdaec 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -45,39 +45,8 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { class BroadcastSuite extends SparkFunSuite with LocalSparkContext { - private val httpConf = broadcastConf("HttpBroadcastFactory") - private val torrentConf = broadcastConf("TorrentBroadcastFactory") - - test("Using HttpBroadcast locally") { - sc = new SparkContext("local", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === Set((1, 10), (2, 10))) - } - - test("Accessing HttpBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) - } - - test("Accessing HttpBroadcast variables in a local cluster") { - val numSlaves = 4 - val conf = httpConf.clone - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - } - test("Using TorrentBroadcast locally") { - sc = new SparkContext("local", "test", torrentConf) + sc = new SparkContext("local", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) @@ -85,7 +54,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { } test("Accessing TorrentBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", torrentConf) + sc = new SparkContext("local[10]", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) @@ -94,7 +63,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - val conf = torrentConf.clone + val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -124,31 +93,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 - val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") val rdd = sc.parallelize(1 to numSlaves) - val results = new DummyBroadcastClass(rdd).doSomething() assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet) } - test("Unpersisting HttpBroadcast on executors only in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) - } - - test("Unpersisting HttpBroadcast on executors only in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) - } - test("Unpersisting TorrentBroadcast on executors only in local mode") { testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) } @@ -180,66 +131,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { } /** - * Verify the persistence of state associated with an HttpBroadcast in either local mode or - * local-cluster mode (when distributed = true). - * - * This test creates a broadcast variable, uses it on all executors, and then unpersists it. - * In between each step, this test verifies that the broadcast blocks and the broadcast file - * are present only on the expected nodes. - */ - private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { - val numSlaves = if (distributed) 2 else 0 - - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.isDriver, "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } - if (distributed) { - // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") - } - } - - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === numSlaves + 1) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } - - // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver - // is true. In the latter case, also verify that the broadcast file is deleted on the driver. - def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - if (distributed && removeFromDriver) { - // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, - "Broadcast file should%s be deleted".format(possiblyNot)) - } - } - - testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, - afterUsingBroadcast, afterUnpersist, removeFromDriver) - } - - /** * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. * * This test creates a broadcast variable, uses it on all executors, and then unpersists it. @@ -284,7 +175,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -300,7 +191,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private def testUnpersistBroadcast( distributed: Boolean, numSlaves: Int, // used only when distributed = true - broadcastConf: SparkConf, afterCreation: (Long, BlockManagerMaster) => Unit, afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, afterUnpersist: (Long, BlockManagerMaster) => Unit, @@ -308,7 +198,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") // Wait until all salves are up try { _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) @@ -319,7 +209,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { throw e } } else { - new SparkContext("local", "test", broadcastConf) + new SparkContext("local", "test") } val blockManagerMaster = sc.env.blockManager.master val list = List[Int](1, 2, 3, 4) @@ -356,13 +246,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) } } - - /** Helper method to create a SparkConf that uses the given broadcast factory. */ - private def broadcastConf(factoryName: String): SparkConf = { - val conf = new SparkConf - conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) - conf - } } package object testPackage extends Assertions { |