author    Reynold Xin <rxin@databricks.com>  2015-12-30 18:07:07 -0800
committer Reynold Xin <rxin@databricks.com>  2015-12-30 18:07:07 -0800
commit    ee8f8d318417c514fbb26e57157483d466ddbfae (patch)
tree      7da3d291a1014f63789679f8a22c726ece3634de /core/src
parent    f76ee109d87e727710d2721e4be47fdabc21582c (diff)
[SPARK-12588] Remove HttpBroadcast in Spark 2.0.
We switched to TorrentBroadcast in Spark 1.1, and HttpBroadcast has been undocumented since then. It's time to remove it in Spark 2.0.

Author: Reynold Xin <rxin@databricks.com>

Closes #10531 from rxin/SPARK-12588.
Diffstat (limited to 'core/src')
 core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala        |   4
 core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala        |  13
 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala           | 269
 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala    |  47
 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala        |   2
 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala |   2
 core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala         |   2
 core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala          | 131
 8 files changed, 13 insertions(+), 457 deletions(-)
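The user-facing API is unchanged by this commit: sc.broadcast has gone through TorrentBroadcastFactory by default since 1.1, and after this patch it is the only path. A minimal usage sketch (names and values are illustrative, not from this patch):

    import org.apache.spark.{SparkConf, SparkContext}

    object BroadcastExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("broadcast-example"))
        // Ships the map to each executor once, rather than once per task.
        val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))
        val resolved = sc.parallelize(Seq(1, 2)).map(i => lookup.value(i)).collect()
        println(resolved.mkString(","))  // prints: a,b
        lookup.destroy()                 // drop the blocks on driver and executors
        sc.stop()
      }
    }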
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 6a187b4062..7f35ac4747 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -24,14 +24,12 @@ import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
/**
- * :: DeveloperApi ::
* An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a user-specified
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
-@DeveloperApi
-trait BroadcastFactory {
+private[spark] trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
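The hunk above shows only the first method of the trait; the remaining methods can be read off the factory implementations later in this patch. A reconstructed sketch of the full trait (not the verbatim file):

    import scala.reflect.ClassTag
    import org.apache.spark.{SecurityManager, SparkConf}

    private[spark] trait BroadcastFactory {
      def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
      // Create a broadcast of `value`; `id` is assigned by BroadcastManager.
      def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
      // Remove persisted state for broadcast `id` on executors, and on the
      // driver too when removeFromDriver is true.
      def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
      def stop(): Unit
    }

Making the trait private[spark] means third-party factories can no longer be plugged in, which is consistent with dropping the spark.broadcast.factory lookup in the next file.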
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index fac6666bb3..61343607a1 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong
import scala.reflect.ClassTag
-import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+
private[spark] class BroadcastManager(
val isDriver: Boolean,
@@ -39,15 +39,8 @@ private[spark] class BroadcastManager(
private def initialize() {
synchronized {
if (!initialized) {
- val broadcastFactoryClass =
- conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
-
- broadcastFactory =
- Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
-
- // Initialize appropriate BroadcastFactory and BroadcastObject
+ broadcastFactory = new TorrentBroadcastFactory
broadcastFactory.initialize(isDriver, conf, securityManager)
-
initialized = true
}
}
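With the reflective lookup gone, the spark.broadcast.factory key is simply never read. A sketch of the behavioral difference, assuming a 1.x-era configuration:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // Spark 1.x: this selected the implementation via Utils.classForName.
    // After this patch: the key is ignored; TorrentBroadcastFactory is always used.
    conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")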
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
deleted file mode 100644
index b69af639f7..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
-import java.io.{BufferedInputStream, BufferedOutputStream}
-import java.net.{URL, URLConnection, URI}
-import java.util.concurrent.TimeUnit
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
-import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
-
-/**
- * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server
- * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a
- * task) is deserialized in the executor, the broadcasted data is fetched from the driver
- * (through a HTTP server running at the driver) and stored in the BlockManager of the
- * executor to speed up future accesses.
- */
-private[spark] class HttpBroadcast[T: ClassTag](
- @transient var value_ : T, isLocal: Boolean, id: Long)
- extends Broadcast[T](id) with Logging with Serializable {
-
- override protected def getValue() = value_
-
- private val blockId = BroadcastBlockId(id)
-
- /*
- * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster
- * does not need to be told about this block, since no node other than the driver needs to know about it.
- */
- HttpBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
- }
-
- if (!isLocal) {
- HttpBroadcast.write(id, value_)
- }
-
- /**
- * Remove all persisted state associated with this HTTP broadcast on the executors.
- */
- override protected def doUnpersist(blocking: Boolean) {
- HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
- }
-
- /**
- * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
- */
- override protected def doDestroy(blocking: Boolean) {
- HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
- }
-
- /** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
- assertValid()
- out.defaultWriteObject()
- }
-
- /** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
- in.defaultReadObject()
- HttpBroadcast.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) => value_ = x.asInstanceOf[T]
- case None => {
- logInfo("Started reading broadcast variable " + id)
- val start = System.nanoTime
- value_ = HttpBroadcast.read[T](id)
- /*
- * We cache broadcast data in the BlockManager so that subsequent tasks using it
- * do not need to re-fetch. This data is only used locally and no other node
- * needs to fetch this block, so we don't notify the master.
- */
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
- }
-}
-
-private[broadcast] object HttpBroadcast extends Logging {
- private var initialized = false
- private var broadcastDir: File = null
- private var compress: Boolean = false
- private var bufferSize: Int = 65536
- private var serverUri: String = null
- private var server: HttpServer = null
- private var securityManager: SecurityManager = null
-
- // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
- private val files = new TimeStampedHashSet[File]
- private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
- private var compressionCodec: CompressionCodec = null
- private var cleaner: MetadataCleaner = null
-
- def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- synchronized {
- if (!initialized) {
- bufferSize = conf.getInt("spark.buffer.size", 65536)
- compress = conf.getBoolean("spark.broadcast.compress", true)
- securityManager = securityMgr
- if (isDriver) {
- createServer(conf)
- conf.set("spark.httpBroadcast.uri", serverUri)
- }
- serverUri = conf.get("spark.httpBroadcast.uri")
- cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
- compressionCodec = CompressionCodec.createCodec(conf)
- initialized = true
- }
- }
- }
-
- def stop() {
- synchronized {
- if (server != null) {
- server.stop()
- server = null
- }
- if (cleaner != null) {
- cleaner.cancel()
- cleaner = null
- }
- compressionCodec = null
- initialized = false
- }
- }
-
- private def createServer(conf: SparkConf) {
- broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
- val broadcastPort = conf.getInt("spark.broadcast.port", 0)
- server =
- new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
- server.start()
- serverUri = server.uri
- logInfo("Broadcast server started at " + serverUri)
- }
-
- def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name)
-
- private def write(id: Long, value: Any) {
- val file = getFile(id)
- val fileOutputStream = new FileOutputStream(file)
- Utils.tryWithSafeFinally {
- val out: OutputStream = {
- if (compress) {
- compressionCodec.compressedOutputStream(fileOutputStream)
- } else {
- new BufferedOutputStream(fileOutputStream, bufferSize)
- }
- }
- val ser = SparkEnv.get.serializer.newInstance()
- val serOut = ser.serializeStream(out)
- Utils.tryWithSafeFinally {
- serOut.writeObject(value)
- } {
- serOut.close()
- }
- files += file
- } {
- fileOutputStream.close()
- }
- }
-
- private def read[T: ClassTag](id: Long): T = {
- logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
- val url = serverUri + "/" + BroadcastBlockId(id).name
-
- var uc: URLConnection = null
- if (securityManager.isAuthenticationEnabled()) {
- logDebug("broadcast security enabled")
- val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
- uc = newuri.toURL.openConnection()
- uc.setConnectTimeout(httpReadTimeout)
- uc.setAllowUserInteraction(false)
- } else {
- logDebug("broadcast not using security")
- uc = new URL(url).openConnection()
- uc.setConnectTimeout(httpReadTimeout)
- }
- Utils.setupSecureURLConnection(uc, securityManager)
-
- val in = {
- uc.setReadTimeout(httpReadTimeout)
- val inputStream = uc.getInputStream
- if (compress) {
- compressionCodec.compressedInputStream(inputStream)
- } else {
- new BufferedInputStream(inputStream, bufferSize)
- }
- }
- val ser = SparkEnv.get.serializer.newInstance()
- val serIn = ser.deserializeStream(in)
- Utils.tryWithSafeFinally {
- serIn.readObject[T]()
- } {
- serIn.close()
- }
- }
-
- /**
- * Remove all persisted blocks associated with this HTTP broadcast on the executors.
- * If removeFromDriver is true, also remove these persisted blocks on the driver
- * and delete the associated broadcast file.
- */
- def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized {
- SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
- if (removeFromDriver) {
- val file = getFile(id)
- files.remove(file)
- deleteBroadcastFile(file)
- }
- }
-
- /**
- * Periodically clean up old broadcasts by removing the associated map entries and
- * deleting the associated files.
- */
- private def cleanup(cleanupTime: Long) {
- val iterator = files.internalMap.entrySet().iterator()
- while (iterator.hasNext) {
- val entry = iterator.next()
- val (file, time) = (entry.getKey, entry.getValue)
- if (time < cleanupTime) {
- iterator.remove()
- deleteBroadcastFile(file)
- }
- }
- }
-
- private def deleteBroadcastFile(file: File) {
- try {
- if (file.exists) {
- if (file.delete()) {
- logInfo("Deleted broadcast file: %s".format(file))
- } else {
- logWarning("Could not delete broadcast file: %s".format(file))
- }
- }
- } catch {
- case e: Exception =>
- logError("Exception while deleting broadcast file: %s".format(file), e)
- }
- }
-}
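The readObject path in the deleted class is a read-through cache: consult the local BlockManager first, otherwise fetch from the driver's HTTP server and cache the result without notifying the master. TorrentBroadcast keeps the same shape with a different fetch step. A condensed sketch of the pattern (fetchRemotely is a hypothetical stand-in for the transport-specific step):

    import org.apache.spark.SparkEnv
    import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}

    def readThroughCache[T](blockId: BroadcastBlockId, fetchRemotely: () => T): T =
      SparkEnv.get.blockManager.getSingle(blockId) match {
        case Some(cached) => cached.asInstanceOf[T]
        case None =>
          val value = fetchRemotely()
          // Cached for local reuse only; no other node fetches this block from
          // us, so the master is not told about it.
          SparkEnv.get.blockManager.putSingle(
            blockId, value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
          value
      }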
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
deleted file mode 100644
index cf3ae36f27..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{SecurityManager, SparkConf}
-
-/**
- * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a
- * HTTP server as the broadcast mechanism. Refer to
- * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
- */
-class HttpBroadcastFactory extends BroadcastFactory {
- override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- HttpBroadcast.initialize(isDriver, conf, securityMgr)
- }
-
- override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] =
- new HttpBroadcast[T](value_, isLocal, id)
-
- override def stop() { HttpBroadcast.stop() }
-
- /**
- * Remove all persisted state associated with the HTTP broadcast with the given ID.
- * @param removeFromDriver Whether to remove state from the driver
- * @param blocking Whether to block until unbroadcasted
- */
- override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
- HttpBroadcast.unpersist(id, removeFromDriver, blocking)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 7e3764d802..9bd69727f6 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream
* BlockManager, ready for other executors to fetch from.
*
* This prevents the driver from being the bottleneck in sending out multiple copies of the
- * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ * broadcast data (one per executor).
*
* When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
*
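The comment above summarizes the torrent mechanism: the driver splits the serialized value into fixed-size blocks, and executors that have fetched a block serve it to others. A minimal sketch of the chunking step (the 4 MB figure mirrors the spark.broadcast.blockSize default; the helper is illustrative):

    // Split a serialized payload into fixed-size blocks, as TorrentBroadcast
    // does before handing each block to the local BlockManager.
    def chunk(payload: Array[Byte], blockSize: Int = 4 * 1024 * 1024): Seq[Array[Byte]] =
      payload.grouped(blockSize).toSeq

Executors then fetch blocks in random order from whichever peer already holds them, so the driver stops being the sole source after the first round of fetches.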
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index 96d8dd7990..b11f9ba171 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
* protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
* [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
*/
-class TorrentBroadcastFactory extends BroadcastFactory {
+private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index eed9937b30..1b4538e6af 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -34,7 +34,6 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.spark._
import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
@@ -107,7 +106,6 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer())
kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
- kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
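The deleted registration paired the removed class with Kryo's Java-serialization fallback. Application classes use the same mechanism through a custom registrator; a hedged sketch (class names illustrative):

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoRegistrator

    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo): Unit = {
        kryo.register(classOf[Array[Int]])  // an example application class
      }
    }

    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", "MyRegistrator")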
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index ba21075ce6..88fdbbdaec 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -45,39 +45,8 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
- private val httpConf = broadcastConf("HttpBroadcastFactory")
- private val torrentConf = broadcastConf("TorrentBroadcastFactory")
-
- test("Using HttpBroadcast locally") {
- sc = new SparkContext("local", "test", httpConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === Set((1, 10), (2, 10)))
- }
-
- test("Accessing HttpBroadcast variables from multiple threads") {
- sc = new SparkContext("local[10]", "test", httpConf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
- }
-
- test("Accessing HttpBroadcast variables in a local cluster") {
- val numSlaves = 4
- val conf = httpConf.clone
- conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- conf.set("spark.broadcast.compress", "true")
- sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
- val list = List[Int](1, 2, 3, 4)
- val broadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
- assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
- }
-
test("Using TorrentBroadcast locally") {
- sc = new SparkContext("local", "test", torrentConf)
+ sc = new SparkContext("local", "test")
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
@@ -85,7 +54,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
}
test("Accessing TorrentBroadcast variables from multiple threads") {
- sc = new SparkContext("local[10]", "test", torrentConf)
+ sc = new SparkContext("local[10]", "test")
val list = List[Int](1, 2, 3, 4)
val broadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
@@ -94,7 +63,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
test("Accessing TorrentBroadcast variables in a local cluster") {
val numSlaves = 4
- val conf = torrentConf.clone
+ val conf = new SparkConf
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.broadcast.compress", "true")
sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
@@ -124,31 +93,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
test("Test Lazy Broadcast variables with TorrentBroadcast") {
val numSlaves = 2
- val conf = torrentConf.clone
- sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf)
+ sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test")
val rdd = sc.parallelize(1 to numSlaves)
-
val results = new DummyBroadcastClass(rdd).doSomething()
assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet)
}
- test("Unpersisting HttpBroadcast on executors only in local mode") {
- testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
- }
-
- test("Unpersisting HttpBroadcast on executors and driver in local mode") {
- testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
- }
-
- test("Unpersisting HttpBroadcast on executors only in distributed mode") {
- testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
- }
-
- test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
- testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
- }
-
test("Unpersisting TorrentBroadcast on executors only in local mode") {
testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
}
@@ -180,66 +131,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
}
/**
- * Verify the persistence of state associated with an HttpBroadcast in either local mode or
- * local-cluster mode (when distributed = true).
- *
- * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
- * In between each step, this test verifies that the broadcast blocks and the broadcast file
- * are present only on the expected nodes.
- */
- private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
- val numSlaves = if (distributed) 2 else 0
-
- // Verify that the broadcast file is created, and blocks are persisted only on the driver
- def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
- val blockId = BroadcastBlockId(broadcastId)
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === 1)
- statuses.head match { case (bm, status) =>
- assert(bm.isDriver, "Block should only be on the driver")
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store on the driver")
- assert(status.diskSize === 0, "Block should not be in disk store on the driver")
- }
- if (distributed) {
- // this file is only generated in distributed mode
- assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
- }
- }
-
- // Verify that blocks are persisted in both the executors and the driver
- def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
- val blockId = BroadcastBlockId(broadcastId)
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === numSlaves + 1)
- statuses.foreach { case (_, status) =>
- assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
- assert(status.memSize > 0, "Block should be in memory store")
- assert(status.diskSize === 0, "Block should not be in disk store")
- }
- }
-
- // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
- // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
- def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
- val blockId = BroadcastBlockId(broadcastId)
- val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- val expectedNumBlocks = if (removeFromDriver) 0 else 1
- val possiblyNot = if (removeFromDriver) "" else " not"
- assert(statuses.size === expectedNumBlocks,
- "Block should%s be unpersisted on the driver".format(possiblyNot))
- if (distributed && removeFromDriver) {
- // this file is only generated in distributed mode
- assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
- "Broadcast file should%s be deleted".format(possiblyNot))
- }
- }
-
- testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
- afterUsingBroadcast, afterUnpersist, removeFromDriver)
- }
-
- /**
* Verify the persistence of state associated with a TorrentBroadcast in a local-cluster.
*
* This test creates a broadcast variable, uses it on all executors, and then unpersists it.
@@ -284,7 +175,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
assert(statuses.size === expectedNumBlocks)
}
- testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
+ testUnpersistBroadcast(distributed, numSlaves, afterCreation,
afterUsingBroadcast, afterUnpersist, removeFromDriver)
}
@@ -300,7 +191,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
private def testUnpersistBroadcast(
distributed: Boolean,
numSlaves: Int, // used only when distributed = true
- broadcastConf: SparkConf,
afterCreation: (Long, BlockManagerMaster) => Unit,
afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
afterUnpersist: (Long, BlockManagerMaster) => Unit,
@@ -308,7 +198,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
sc = if (distributed) {
val _sc =
- new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf)
+ new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test")
// Wait until all slaves are up
try {
_sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000)
@@ -319,7 +209,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
throw e
}
} else {
- new SparkContext("local", "test", broadcastConf)
+ new SparkContext("local", "test")
}
val blockManagerMaster = sc.env.blockManager.master
val list = List[Int](1, 2, 3, 4)
@@ -356,13 +246,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
}
}
-
- /** Helper method to create a SparkConf that uses the given broadcast factory. */
- private def broadcastConf(factoryName: String): SparkConf = {
- val conf = new SparkConf
- conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
- conf
- }
}
package object testPackage extends Assertions {