aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/spark/network/netty/ShuffleCopier.scala')
-rw-r--r--core/src/main/scala/spark/network/netty/ShuffleCopier.scala82
1 files changed, 58 insertions, 24 deletions
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
index a91f5a886d..b01f6369f6 100644
--- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package spark.network.netty
import java.util.concurrent.Executors
@@ -9,19 +26,36 @@ import io.netty.util.CharsetUtil
import spark.Logging
import spark.network.ConnectionManagerId
+import scala.collection.JavaConverters._
+
private[spark] class ShuffleCopier extends Logging {
- def getBlock(cmId: ConnectionManagerId, blockId: String,
+ def getBlock(host: String, port: Int, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
- val fc = new FileClient(handler)
- fc.init()
- fc.connect(cmId.host, cmId.port)
- fc.sendRequest(blockId)
- fc.waitForClose()
- fc.close()
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
@@ -44,20 +78,18 @@ private[spark] object ShuffleCopier extends Logging {
logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- }
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
- logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
}
- def runGetBlock(host:String, port:Int, file:String){
- val handler = new ShuffleClientHandler(echoResultCollectCallBack)
- val fc = new FileClient(handler)
- fc.init();
- fc.connect(host, port)
- fc.sendRequest(file)
- fc.waitForClose();
- fc.close()
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
}
def main(args: Array[String]) {
@@ -71,14 +103,16 @@ private[spark] object ShuffleCopier extends Logging {
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
- for (i <- Range(0, threads)) {
- val runnable = new Runnable() {
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
def run() {
- runGetBlock(host, port, file)
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
}
- }
- copiers.execute(runnable)
- }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
copiers.shutdown
+ System.exit(0)
}
}