author     Reynold Xin <rxin@cs.berkeley.edu>    2013-05-30 14:50:06 -0700
committer  Reynold Xin <rxin@cs.berkeley.edu>    2013-05-30 14:50:06 -0700
commit     ef77bb73c66ce938e409cbbac32b67badaa5c57d (patch)
tree       ed052460e9f82d887b165c876a13ab01414622f1
parent     8cb817820f134eac19985ee86cbc92f6bc2f2f4d (diff)
parent     3b0cd173430188254e068bd72890e86b864792cd (diff)
Merge pull request #627 from shivaram/master
Netty and shuffle bug fixes
-rw-r--r--  core/src/main/java/spark/network/netty/FileServer.java  |  45
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala        |   3
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala             |  25
3 files changed, 51 insertions, 22 deletions
diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java
index dd3f12561c..dd3a557ae5 100644
--- a/core/src/main/java/spark/network/netty/FileServer.java
+++ b/core/src/main/java/spark/network/netty/FileServer.java
@@ -37,29 +37,33 @@ class FileServer {
.childHandler(new FileServerChannelInitializer(pResolver));
// Start the server.
channelFuture = bootstrap.bind(addr);
- this.port = addr.getPort();
+ try {
+ // Get the address we bound to.
+ InetSocketAddress boundAddress =
+ ((InetSocketAddress) channelFuture.sync().channel().localAddress());
+ this.port = boundAddress.getPort();
+ } catch (InterruptedException ie) {
+ this.port = 0;
+ }
}
/**
* Start the file server asynchronously in a new thread.
*/
public void start() {
- try {
- blockingThread = new Thread() {
- public void run() {
- try {
- Channel channel = channelFuture.sync().channel();
- channel.closeFuture().sync();
- } catch (InterruptedException e) {
- LOG.error("File server start got interrupted", e);
- }
+ blockingThread = new Thread() {
+ public void run() {
+ try {
+ channelFuture.channel().closeFuture().sync();
+ LOG.info("FileServer exiting");
+ } catch (InterruptedException e) {
+ LOG.error("File server start got interrupted", e);
}
- };
- blockingThread.setDaemon(true);
- blockingThread.start();
- } finally {
- bootstrap.shutdown();
- }
+ // NOTE: bootstrap is shutdown in stop()
+ }
+ };
+ blockingThread.setDaemon(true);
+ blockingThread.start();
}
public int getPort() {
@@ -67,17 +71,16 @@ class FileServer {
}
public void stop() {
- if (blockingThread != null) {
- blockingThread.stop();
- blockingThread = null;
- }
+ // Close the bound channel.
if (channelFuture != null) {
- channelFuture.channel().closeFuture();
+ channelFuture.channel().close();
channelFuture = null;
}
+ // Shutdown bootstrap.
if (bootstrap != null) {
bootstrap.shutdown();
bootstrap = null;
}
+ // TODO: Shutdown all accepted channels as well ?
}
}
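
The FileServer change above replaces two fragile pieces of behaviour: the constructor now syncs on the bind future and records the port Netty actually bound (instead of reading it off the requested address), and stop() closes the bound channel and shuts down the bootstrap rather than calling Thread.stop() on the blocking thread. A minimal Scala sketch of the same bound-port pattern, assuming Netty 4's ServerBootstrap API (the object and method names here are illustrative, not the Spark code):

import java.net.InetSocketAddress
import io.netty.bootstrap.ServerBootstrap

object BoundPortSketch {
  // Bind to an OS-assigned port, wait for the bind to complete, then read the
  // real port back from the channel's local address rather than from the
  // address we asked for (which would still report port 0).
  def bindAndGetPort(bootstrap: ServerBootstrap): Int = {
    val channelFuture = bootstrap.bind(new InetSocketAddress(0))
    try {
      channelFuture.sync().channel().localAddress()
        .asInstanceOf[InetSocketAddress].getPort
    } catch {
      case _: InterruptedException => 0 // same fallback the patch uses
    }
  }
}
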
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 57d4dafefc..c7281200e7 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -59,6 +59,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
@@ -68,6 +70,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def revertPartialWrites() {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
+ objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
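
The DiskStore change adds objOut.flush() ahead of bs.flush() in both commit() and revertPartialWrites(), so bytes still sitting in the serializer's own buffer reach the file channel before the code reads channel.position() or truncates back to lastValidPosition; this is what the new "shuffle non-zero block size" test below guards against. A small self-contained sketch of the ordering issue, using plain java.io streams as stand-ins for the serializer and the buffered/compressed stream (hypothetical names, not the Spark classes):

import java.io.{BufferedOutputStream, File, FileOutputStream, ObjectOutputStream}

object FlushOrderSketch extends App {
  val file = File.createTempFile("flush-order", ".bin")
  val bufferedOut = new BufferedOutputStream(new FileOutputStream(file)) // stands in for `bs`
  val objOut = new ObjectOutputStream(bufferedOut)                       // stands in for `objOut`

  objOut.writeObject("some record")
  println(s"before flushing: ${file.length()} bytes visible on disk")

  objOut.flush()      // drain the serializer's internal buffer into bufferedOut
  bufferedOut.flush() // then drain bufferedOut down to the file
  println(s"after flushing outer then inner stream: ${file.length()} bytes visible on disk")
}
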
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 4e50ae2ca9..b967016cf7 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -305,9 +305,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(c.partitioner.get === p)
}
+ test("shuffle non-zero block size") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val NUM_BLOCKS = 3
+
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+ }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non serializable class.
+ val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
+ classOf[spark.KryoSerializer].getName)
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+
+ assert(c.count === 10)
+
+ // All blocks must have non-zero size
+ (0 until NUM_BLOCKS).foreach { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ assert(statuses.forall(s => s._2 > 0))
+ }
+ }
+
test("shuffle serializer") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[1,2,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
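
The new "shuffle non-zero block size" test relies on a class that Java serialization cannot handle, so the shuffle only succeeds if the Kryo serializer named in the ShuffledRDD is actually used, and it then checks every map output status for a non-zero block size. A tiny sketch of the serialization distinction the test leans on, with an assumed helper class shape (the real ShuffleSuite.NonJavaSerializableClass may be defined differently):

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object KryoNeededSketch extends App {
  // Deliberately does NOT extend java.io.Serializable.
  class NonJavaSerializable(val value: Int)

  val out = new ObjectOutputStream(new ByteArrayOutputStream())
  try {
    out.writeObject(new NonJavaSerializable(42))
    println("unexpected: Java serialization accepted it")
  } catch {
    case _: NotSerializableException =>
      println("Java serialization rejects it, so shuffling it requires Kryo")
  }
}
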