aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java19
-rwxr-xr-xcore/src/main/java/org/apache/spark/network/netty/PathResolver.java11
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala168
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala128
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala184
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala301
-rw-r--r--core/src/main/scala/org/apache/spark/storage/FileSegment.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala54
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala20
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala6
-rw-r--r--examples/pom.xml28
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java98
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala28
-rw-r--r--project/SparkBuild.scala9
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jarbin1358063 -> 0 bytes
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom9
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml12
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha11
-rw-r--r--streaming/pom.xml38
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala20
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala33
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala61
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java16
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala8
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala32
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala7
39 files changed, 848 insertions, 543 deletions
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
index cfd8132891..172c6e4b1c 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
@@ -25,6 +25,7 @@ import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -37,40 +38,34 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
- String path = pResolver.getAbsolutePath(blockId.name());
- // if getFilePath returns null, close the channel
- if (path == null) {
+ FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+ // if getBlockLocation returns null, close the channel
+ if (fileSegment == null) {
//ctx.close();
return;
}
- File file = new File(path);
+ File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
- //logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
- long length = file.length();
+ long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
- //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
- //logger.info("Sending block "+blockId+" filelen = "+len);
- //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
- .getChannel(), 0, file.length()));
+ .getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
- //logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
- //logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
index 94c034cad0..9f7ced44cf 100755
--- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
+++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
@@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
public interface PathResolver {
- /**
- * Get the absolute path of the file
- *
- * @param fileId
- * @return the absolute path of file
- */
- public String getAbsolutePath(String fileId);
+ /** Get the file segment in which the given block resides. */
+ public FileSegment getBlockLocation(BlockId blockId);
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1e3f1ebfaf..5e465fa22c 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,13 +20,11 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
-import akka.remote._
import akka.util.Duration
@@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+ extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
- sender ! tracker.getSerializedLocations(shuffleId)
+ sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var epoch: Long = 0
- private val epochLock = new java.lang.Object
+ protected var epoch: Long = 0
+ protected val epochLock = new java.lang.Object
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
+ private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
+ private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
- }
- }
-
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses(shuffleId)
- array.synchronized {
- array(mapId) = status
- }
- }
-
- def registerMapOutputs(
- shuffleId: Int,
- statuses: Array[MapStatus],
- changeEpoch: Boolean = false) {
- mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeEpoch) {
- incrementEpoch()
- }
- }
-
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var arrayOpt = mapStatuses.get(shuffleId)
- if (arrayOpt.isDefined && arrayOpt.get != null) {
- var array = arrayOpt.get
- array.synchronized {
- if (array(mapId) != null && array(mapId).location == bmAddress) {
- array(mapId) = null
- }
- }
- incrementEpoch()
- } else {
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
- }
-
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
- private def cleanup(cleanupTime: Long) {
+ protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the epoch number
- def incrementEpoch() {
- epochLock.synchronized {
- epoch += 1
- logDebug("Increasing epoch to " + epoch)
- }
- }
-
- // Called on master or workers to get current epoch number
+ // Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -228,14 +177,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
- // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
- mapStatuses.clear()
epoch = newEpoch
+ mapStatuses.clear()
+ }
+ }
+ }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ private var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
}
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
- def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
+ }
+ }
+
+ def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -253,7 +250,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeStatuses(statuses)
+ val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -261,13 +258,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
- return bytes
+ bytes
+ }
+
+ protected override def cleanup(cleanupTime: Long) {
+ super.cleanup(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+ override def stop() {
+ super.stop()
+ cachedSerializedStatuses.clear()
+ }
+
+ override def updateEpoch(newEpoch: Long) {
+ // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -278,18 +293,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
- // Opposite of serializeStatuses.
- def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ // Opposite of serializeMapStatuses.
+ def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().
- // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
- // comment this out - nulls could be due to missing location ?
- asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
}
-}
-
-private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 29968c273c..aaab717bcf 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -187,10 +187,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- val mapOutputTracker = new MapOutputTracker()
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster()
+ } else {
+ new MapOutputTracker()
+ }
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerActor(mapOutputTracker))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 1586dff254..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
@@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
- return file.getAbsolutePath
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d84f5968df..e58ff37b9b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -52,13 +52,14 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
- mapOutputTracker: MapOutputTracker,
+ mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends Logging {
def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setDAGScheduler(this)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 40baea69e8..24d97da6eb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -167,8 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
- writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
@@ -191,6 +190,7 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
+ buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1fe0d0e4e2..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- def run(attemptId: Long): T = {
+ final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 768e5a647f..e6329cbd47 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -28,7 +28,7 @@ import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -102,18 +102,19 @@ private[spark] class BlockManager(
}
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -512,16 +513,20 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
+ * The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
+ val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
+ diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
+ myInfo.markReady(writer.fileSegment().length)
})
writer
}
@@ -862,13 +867,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 76c92cefd8..32d2dd0694 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -59,12 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
*/
- def size(): Long
+ def fileSegment(): FileSegment
/**
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long
}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private var initialPosition = 0L
+ private var lastValidPosition = 0L
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ initialPosition = channel.position
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..bcb58ad946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ // Stores only Blocks which have been specifically mapped to segments of files
+ // (rather than the default, which maps a Block to a whole file).
+ // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
+ private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
+
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
+
+ addShutdownHook()
+
+ /**
+ * Creates a logical mapping from the given BlockId to a segment of a file.
+ * This will cause any accesses of the logical BlockId to be directed to the specified
+ * physical location.
+ */
+ def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
+ blockToFileSegmentMap.put(blockId, fileSegment)
+ }
+
+ /**
+ * Returns the phyiscal file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
+ blockToFileSegmentMap.get(blockId).get
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ /**
+ * Simply returns a File to place the given Block into. This does not physically create the file.
+ * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
+ * a unique filename.
+ */
+ def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
+ val actualFilename = if (filename == "") blockId.name else filename
+ val file = getFile(actualFilename)
+ if (!allowAppending && file.exists()) {
+ throw new IllegalStateException(
+ "Attempted to create file that already exists: " + actualFilename)
+ }
+ file
+ }
+
+ private def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ blockToFileSegmentMap.clearOldValues(cleanupTime)
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 2a9a3f61bd..a3c496f9e0 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,158 +17,25 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
- private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
- def timeWriting = _timeWriting
- private var _timeWriting = 0L
-
- private def callWithTiming(f: => Unit) = {
- val start = System.nanoTime()
- f
- _timeWriting += (System.nanoTime() - start)
- }
-
- def write(i: Int): Unit = callWithTiming(out.write(i))
- override def write(b: Array[Byte]) = callWithTiming(out.write(b))
- override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
- }
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
- private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var fos: FileOutputStream = null
- private var ts: TimeTrackingOutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
- private var _timeWriting = 0L
-
- override def open(): DiskBlockObjectWriter = {
- fos = new FileOutputStream(f, true)
- ts = new TimeTrackingOutputStream(fos)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(ts, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- if (syncWrites) {
- // Force outstanding writes to disk and track how long it takes
- objOut.flush()
- val start = System.nanoTime()
- fos.getFD.sync()
- _timeWriting += System.nanoTime() - start
- objOut.close()
- } else {
- objOut.close()
- }
-
- _timeWriting += ts.timeWriting
-
- channel = null
- bs = null
- fos = null
- ts = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
-
- // Only valid if called after close()
- override def timeWriting = _timeWriting
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
- }
-
override def getSize(blockId: BlockId): Long = {
- getFile(blockId).length()
+ diskManager.getBlockLocation(blockId).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
@@ -177,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
@@ -209,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -231,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
@@ -249,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def remove(blockId: BlockId): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
+ }
false
}
}
override def contains(blockId: BlockId): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: BlockId): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId.name)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
- }
- }
-
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop()
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
- if (!blockId.isShuffle) null
- else DiskStore.this.getFile(blockId).getAbsolutePath
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based off an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f39fcd87fb..229178c095 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,12 +17,13 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.spark.serializer.Serializer
private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
-
+class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
@@ -30,24 +31,61 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
+ * per reducer.
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
+ * it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+ // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+
+ var nextFileId = new AtomicInteger(0)
+ val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val fileId = getUnusedFileId()
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
}
- new ShuffleWriterGroup(mapId, writers)
+ new ShuffleWriterGroup(mapId, fileId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ override def releaseWriters(group: ShuffleWriterGroup) {
+ recycleFileId(group.fileId)
}
}
}
+
+ private def getUnusedFileId(): Int = {
+ val fileId = unusedFileIds.poll()
+ if (fileId == null) nextFileId.getAndIncrement() else fileId
+ }
+
+ private def recycleFileId(fileId: Int) {
+ if (!consolidateShuffleFiles) { return } // ensures we always generate new file id
+ unusedFileIds.add(fileId)
+ }
+
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 5f30383fd0..1b074e5ec7 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -44,7 +44,7 @@ object StoragePerfTester {
}
buckets.writers.map {w =>
w.commit()
- total.addAndGet(w.size())
+ total.addAndGet(w.fileSegment().length)
w.close()
}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0ce1394c77..3f963727d9 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "BroadcastVars") {
+ "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6013320eaa..b7eb268bd5 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
- // As if we had two simulatenous fetch failures
+ // As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@@ -102,9 +100,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
- val masterTracker = new MapOutputTracker()
+ val masterTracker = new MapOutputTrackerMaster()
masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2a2f828be6..00f2fdd657 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
-import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
@@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def defaultParallelism() = 2
}
- var mapOutputTracker: MapOutputTracker = null
+ var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@@ -99,7 +99,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTracker()
+ mapOutputTracker = new MapOutputTrackerMaster()
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
diff --git a/examples/pom.xml b/examples/pom.xml
index 15399a8a33..aee371fbc7 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -32,13 +32,20 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
- <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
- <id>lib</id>
- <url>file://${project.basedir}/lib</url>
+ <id>apache-repo</id>
+ <name>Apache Repository</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
</repository>
</repositories>
+
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -81,9 +88,18 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
- <scope>provided</scope>
+ <artifactId>kafka_2.9.2</artifactId>
+ <version>0.8.0-beta1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.jmx</groupId>
+ <artifactId>jmxri</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.jdmk</groupId>
+ <artifactId>jmxtools</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
new file mode 100644
index 0000000000..9a8e4209ed
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples;
+
+import java.util.Map;
+import java.util.HashMap;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import scala.Tuple2;
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: JavaKafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <zkQuorum> is a list of one or more zookeeper servers that make quorum
+ * <group> is the name of kafka consumer group
+ * <topics> is a list of one or more kafka topics to consume from
+ * <numThreads> is the number of threads the kafka consumer should use
+ *
+ * Example:
+ * `./run-example org.apache.spark.streaming.examples.JavaKafkaWordCount local[2] zoo01,zoo02,
+ * zoo03 my-consumer-group topic1,topic2 1`
+ */
+
+public class JavaKafkaWordCount {
+ public static void main(String[] args) {
+ if (args.length < 5) {
+ System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>");
+ System.exit(1);
+ }
+
+ // Create the context with a 1 second batch size
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+
+ int numThreads = Integer.parseInt(args[4]);
+ Map<String, Integer> topicMap = new HashMap<String, Integer>();
+ String[] topics = args[3].split(",");
+ for (String topic: topics) {
+ topicMap.put(topic, numThreads);
+ }
+
+ JavaPairDStream<String, String> messages = ssc.kafkaStream(args[1], args[2], topicMap);
+
+ JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(Tuple2<String, String> tuple2) throws Exception {
+ return tuple2._2();
+ }
+ });
+
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(x.split(" "));
+ }
+ });
+
+ JavaPairDStream<String, Integer> wordCounts = words.map(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) throws Exception {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ });
+
+ wordCounts.print();
+ ssc.start();
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
index 12f939d5a7..570ba4c81a 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
@@ -18,13 +18,11 @@
package org.apache.spark.streaming.examples
import java.util.Properties
-import kafka.message.Message
-import kafka.producer.SyncProducerConfig
+
import kafka.producer._
-import org.apache.spark.SparkContext
+
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.util.RawTextHelper._
/**
@@ -54,9 +52,10 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
+ val lines = ssc.kafkaStream(zkQuorum, group, topicpMap).map(_._2)
val words = lines.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
+ val wordCounts = words.map(x => (x, 1l))
+ .reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
ssc.start()
@@ -68,15 +67,16 @@ object KafkaWordCountProducer {
def main(args: Array[String]) {
if (args.length < 2) {
- System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>")
+ System.err.println("Usage: KafkaWordCountProducer <metadataBrokerList> <topic> " +
+ "<messagesPerSec> <wordsPerMessage>")
System.exit(1)
}
- val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args
+ val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
// Zookeper connection properties
val props = new Properties()
- props.put("zk.connect", zkQuorum)
+ props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
val config = new ProducerConfig(props)
@@ -85,11 +85,13 @@ object KafkaWordCountProducer {
// Send some messages
while(true) {
val messages = (1 to messagesPerSec.toInt).map { messageNum =>
- (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
+ val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString)
+ .mkString(" ")
+
+ new KeyedMessage[String, String](topic, str)
}.toArray
- println(messages.mkString(","))
- val data = new ProducerData[String, String](topic, messages)
- producer.send(data)
+
+ producer.send(messages: _*)
Thread.sleep(100)
}
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 965c4f3a63..17f480e3f0 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -278,13 +278,16 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
resolvers ++= Seq(
- "Akka Repository" at "http://repo.akka.io/releases/"
+ "Akka Repository" at "http://repo.akka.io/releases/",
+ "Apache repo" at "https://repository.apache.org/content/repositories/releases"
),
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
- "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
- "com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty)
+ "com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty),
+ "org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
+ exclude("com.sun.jdmk", "jmxtools")
+ exclude("com.sun.jmx", "jmxri")
)
)
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
deleted file mode 100644
index 65f79925a4..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
+++ /dev/null
Binary files differ
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
deleted file mode 100644
index 29f45f4adb..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
+++ /dev/null
@@ -1 +0,0 @@
-18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
deleted file mode 100644
index e3bd62bac0..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
+++ /dev/null
@@ -1 +0,0 @@
-06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
deleted file mode 100644
index 082d35726a..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
+++ /dev/null
@@ -1,9 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
- <modelVersion>4.0.0</modelVersion>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version>
- <description>POM was created from install:install-file</description>
-</project>
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
deleted file mode 100644
index 92c4132b5b..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
+++ /dev/null
@@ -1 +0,0 @@
-7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
deleted file mode 100644
index 8a1d8a097a..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
+++ /dev/null
@@ -1 +0,0 @@
-d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
deleted file mode 100644
index 720cd51c2f..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<metadata>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <versioning>
- <release>0.7.2-spark</release>
- <versions>
- <version>0.7.2-spark</version>
- </versions>
- <lastUpdated>20130121015225</lastUpdated>
- </versioning>
-</metadata>
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
deleted file mode 100644
index a4ce5dc9e8..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
+++ /dev/null
@@ -1 +0,0 @@
-e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
deleted file mode 100644
index b869eaf2a6..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
+++ /dev/null
@@ -1 +0,0 @@
-2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file
diff --git a/streaming/pom.xml b/streaming/pom.xml
index bcbed1644a..339fcd2a39 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -32,10 +32,16 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
- <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
- <id>lib</id>
- <url>file://${project.basedir}/lib</url>
+ <id>apache-repo</id>
+ <name>Apache Repository</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
</repository>
</repositories>
@@ -56,9 +62,18 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
- <scope>provided</scope>
+ <artifactId>kafka_2.9.2</artifactId>
+ <version>0.8.0-beta1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.jmx</groupId>
+ <artifactId>jmxri</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.jdmk</groupId>
+ <artifactId>jmxtools</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.apache.flume</groupId>
@@ -76,17 +91,6 @@
</exclusions>
</dependency>
<dependency>
- <groupId>com.github.sgroschupf</groupId>
- <artifactId>zkclient</artifactId>
- <version>0.1</version>
- <exclusions>
- <exclusion>
- <groupId>org.jboss.netty</groupId>
- <artifactId>netty</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
- <dependency>
<groupId>org.twitter4j</groupId>
<artifactId>twitter4j-stream</artifactId>
<version>3.0.3</version>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 098081d245..ee265ab4e9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -256,10 +256,14 @@ class StreamingContext private (
groupId: String,
topics: Map[String, Int],
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
- ): DStream[String] = {
+ ): DStream[(String, String)] = {
val kafkaParams = Map[String, String](
- "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
- kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
+ "zookeeper.connect" -> zkQuorum, "group.id" -> groupId,
+ "zookeeper.connection.timeout.ms" -> "10000")
+ kafkaStream[String, String, kafka.serializer.StringDecoder, kafka.serializer.StringDecoder](
+ kafkaParams,
+ topics,
+ storageLevel)
}
/**
@@ -270,12 +274,16 @@ class StreamingContext private (
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
*/
- def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
+ def kafkaStream[
+ K: ClassManifest,
+ V: ClassManifest,
+ U <: kafka.serializer.Decoder[_]: Manifest,
+ T <: kafka.serializer.Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ): DStream[T] = {
- val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
+ ): DStream[(K, V)] = {
+ val inputStream = new KafkaInputDStream[K, V, U, T](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 54ba3e6025..6423b916b0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -141,7 +141,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
- : JavaDStream[String] = {
+ : JavaPairDStream[String, String] = {
implicit val cmt: ClassManifest[String] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
@@ -162,7 +162,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
groupId: String,
topics: JMap[String, JInt],
storageLevel: StorageLevel)
- : JavaDStream[String] = {
+ : JavaPairDStream[String, String] = {
implicit val cmt: ClassManifest[String] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
@@ -171,25 +171,34 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create an input stream that pulls messages form a Kafka Broker.
- * @param typeClass Type of RDD
- * @param decoderClass Type of kafka decoder
+ * @param keyTypeClass Key type of RDD
+ * @param valueTypeClass value type of RDD
+ * @param keyDecoderClass Type of kafka key decoder
+ * @param valueDecoderClass Type of kafka value decoder
* @param kafkaParams Map of kafka configuration paramaters.
* See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
- def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
- typeClass: Class[T],
- decoderClass: Class[D],
+ def kafkaStream[K, V, U <: kafka.serializer.Decoder[_], T <: kafka.serializer.Decoder[_]](
+ keyTypeClass: Class[K],
+ valueTypeClass: Class[V],
+ keyDecoderClass: Class[U],
+ valueDecoderClass: Class[T],
kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
storageLevel: StorageLevel)
- : JavaDStream[T] = {
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
- ssc.kafkaStream[T, D](
+ : JavaPairDStream[K, V] = {
+ implicit val keyCmt: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val valueCmt: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+
+ implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]]
+ implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]]
+
+ ssc.kafkaStream[K, V, U, T](
kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
storageLevel)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
index 51e913675d..a5de5e1fb5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
@@ -19,22 +19,18 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Time, DStreamCheckpointData, StreamingContext}
+import org.apache.spark.streaming.StreamingContext
import java.util.Properties
import java.util.concurrent.Executors
import kafka.consumer._
-import kafka.message.{Message, MessageSet, MessageAndMetadata}
import kafka.serializer.Decoder
-import kafka.utils.{Utils, ZKGroupTopicDirs}
-import kafka.utils.ZkUtils._
+import kafka.utils.VerifiableProperties
import kafka.utils.ZKStringSerializer
import org.I0Itec.zkclient._
import scala.collection.Map
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
/**
@@ -46,25 +42,32 @@ import scala.collection.JavaConversions._
* @param storageLevel RDD storage level.
*/
private[streaming]
-class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
+class KafkaInputDStream[
+ K: ClassManifest,
+ V: ClassManifest,
+ U <: Decoder[_]: Manifest,
+ T <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+ ) extends NetworkInputDStream[(K, V)](ssc_) with Logging {
-
- def getReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
- .asInstanceOf[NetworkReceiver[T]]
+ def getReceiver(): NetworkReceiver[(K, V)] = {
+ new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
+ .asInstanceOf[NetworkReceiver[(K, V)]]
}
}
private[streaming]
-class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
- kafkaParams: Map[String, String],
- topics: Map[String, Int],
- storageLevel: StorageLevel
+class KafkaReceiver[
+ K: ClassManifest,
+ V: ClassManifest,
+ U <: Decoder[_]: Manifest,
+ T <: Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
@@ -83,27 +86,34 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
- logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
+ logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("group.id"))
// Kafka connection properties
val props = new Properties()
kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
- logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
+ logInfo("Connecting to Zookeper: " + kafkaParams("zookeeper.connect"))
val consumerConfig = new ConsumerConfig(props)
consumerConnector = Consumer.create(consumerConfig)
- logInfo("Connected to " + kafkaParams("zk.connect"))
+ logInfo("Connected to " + kafkaParams("zookeeper.connect"))
// When autooffset.reset is defined, it is our responsibility to try and whack the
// consumer group zk node.
- if (kafkaParams.contains("autooffset.reset")) {
- tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
+ if (kafkaParams.contains("auto.offset.reset")) {
+ tryZookeeperConsumerGroupCleanup(kafkaParams("zookeeper.connect"), kafkaParams("group.id"))
}
// Create Threads for each Topic/Message Stream we are listening
- val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
- val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
+ val keyDecoder = manifest[U].erasure.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[K]]
+ val valueDecoder = manifest[T].erasure.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[V]]
+
+ val topicMessageStreams = consumerConnector.createMessageStreams(
+ topics, keyDecoder, valueDecoder)
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
@@ -112,11 +122,12 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
}
// Handles Kafka Messages
- private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
+ private class MessageHandler[K: ClassManifest, V: ClassManifest](stream: KafkaStream[K, V])
+ extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
for (msgAndMetadata <- stream) {
- blockGenerator += msgAndMetadata.message
+ blockGenerator += (msgAndMetadata.key, msgAndMetadata.message)
}
}
}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index c0d729ff87..dc01f1e8aa 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -1220,14 +1220,20 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
- JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
- JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
+ JavaPairDStream<String, String> test1 = ssc.kafkaStream("localhost:12345", "group", topics);
+ JavaPairDStream<String, String> test2 = ssc.kafkaStream("localhost:12345", "group", topics,
StorageLevel.MEMORY_AND_DISK());
HashMap<String, String> kafkaParams = Maps.newHashMap();
- kafkaParams.put("zk.connect","localhost:12345");
- kafkaParams.put("groupid","consumer-group");
- JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
+ kafkaParams.put("zookeeper.connect","localhost:12345");
+ kafkaParams.put("group.id","consumer-group");
+ JavaPairDStream<String, String> test3 = ssc.kafkaStream(
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topics,
StorageLevel.MEMORY_AND_DISK());
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 42e3e51e3f..c29b75ece6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -268,8 +268,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
// Test specifying decoder
- val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
- val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+ val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group")
+ val test3 = ssc.kafkaStream[
+ String,
+ String,
+ kafka.serializer.StringDecoder,
+ kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 8afb3e39cb..1a380ae714 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -265,11 +265,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val env = new HashMap[String, String]()
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
- Apps.addToEnvironment(env, Environment.CLASSPATH.name,
- Environment.PWD.$() + Path.SEPARATOR + "*")
-
- Client.populateHadoopClasspath(yarnConf, env)
+ Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_JAR_PATH") =
localResources("spark.jar").getResource().getScheme.toString() + "://" +
@@ -451,4 +447,30 @@ object Client {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
}
}
+
+ def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
+ // If log4j present, ensure ours overrides all others
+ if (addLog4j) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "log4j.properties")
+ }
+ // normally the users app.jar is last in case conflicts with spark jars
+ val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false")
+ .toBoolean
+ if (userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "app.jar")
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "spark.jar")
+ Client.populateHadoopClasspath(conf, env)
+
+ if (!userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "app.jar")
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "*")
+ }
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index 8dac9e02ac..ba352daac4 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -121,7 +121,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
// TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
" -XX:OnOutOfMemoryError='kill %p' " +
JAVA_OPTS +
- " org.apache.spark.executor.StandaloneExecutorBackend " +
+ " org.apache.spark.executor.CoarseGrainedExecutorBackend " +
masterAddress + " " +
slaveId + " " +
hostname + " " +
@@ -216,10 +216,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
- Apps.addToEnvironment(env, Environment.CLASSPATH.name,
- Environment.PWD.$() + Path.SEPARATOR + "*")
- Client.populateHadoopClasspath(yarnConf, env)
+ Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env)
// allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))