aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala27
-rw-r--r--network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java28
-rw-r--r--network/common/src/test/java/org/apache/spark/network/StreamSuite.java23
5 files changed, 75 insertions, 24 deletions
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 68701f609f..c8fa870f50 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -27,7 +27,7 @@ import javax.annotation.Nullable
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
-import scala.util.{DynamicVariable, Failure, Success}
+import scala.util.{DynamicVariable, Failure, Success, Try}
import scala.util.control.NonFatal
import org.apache.spark.{Logging, SecurityManager, SparkConf}
@@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv(
@volatile private var error: Throwable = _
- def setError(e: Throwable): Unit = error = e
+ def setError(e: Throwable): Unit = {
+ error = e
+ source.close()
+ }
override def read(dst: ByteBuffer): Int = {
- if (error != null) {
- throw error
+ val result = if (error == null) {
+ Try(source.read(dst))
+ } else {
+ Failure(error)
+ }
+
+ result match {
+ case Success(bytesRead) => bytesRead
+ case Failure(error) => throw error
}
- source.read(dst)
}
override def close(): Unit = source.close()
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
index eb1d2604fb..a2768b4252 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
}
- require(file != null, s"File not found: $streamId")
+ require(file != null && file.isFile(), s"File not found: $streamId")
new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 2b664c6313..6cc958a5f6 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
val tempDir = Utils.createTempDir()
val file = new File(tempDir, "file")
Files.write(UUID.randomUUID().toString(), file, UTF_8)
+ val empty = new File(tempDir, "empty")
+ Files.write("", empty, UTF_8);
val jar = new File(tempDir, "jar")
Files.write(UUID.randomUUID().toString(), jar, UTF_8)
val fileUri = env.fileServer.addFile(file)
+ val emptyUri = env.fileServer.addFile(empty)
val jarUri = env.fileServer.addJar(jar)
val destDir = Utils.createTempDir()
- val destFile = new File(destDir, file.getName())
- val destJar = new File(destDir, jar.getName())
-
val sm = new SecurityManager(conf)
val hc = SparkHadoopUtil.get.conf
- Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
- Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
- assert(Files.equal(file, destFile))
- assert(Files.equal(jar, destJar))
+ val files = Seq(
+ (file, fileUri),
+ (empty, emptyUri),
+ (jar, jarUri))
+ files.foreach { case (f, uri) =>
+ val destFile = new File(destDir, f.getName())
+ Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false)
+ assert(Files.equal(f, destFile))
+ }
+
+ // Try to download files that do not exist.
+ Seq("files", "jars").foreach { root =>
+ intercept[Exception] {
+ val uri = env.address.toSparkURL + s"/$root/doesNotExist"
+ Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false)
+ }
+ }
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index be181e0660..4c15045363 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -185,16 +185,24 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
StreamResponse resp = (StreamResponse) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
- StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
- callback);
- try {
- TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
- channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
- frameDecoder.setInterceptor(interceptor);
- streamActive = true;
- } catch (Exception e) {
- logger.error("Error installing stream handler.", e);
- deactivateStream();
+ if (resp.byteCount > 0) {
+ StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
+ callback);
+ try {
+ TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+ channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+ frameDecoder.setInterceptor(interceptor);
+ streamActive = true;
+ } catch (Exception e) {
+ logger.error("Error installing stream handler.", e);
+ deactivateStream();
+ }
+ } else {
+ try {
+ callback.onComplete(resp.streamId);
+ } catch (Exception e) {
+ logger.warn("Error in stream handler onComplete().", e);
+ }
}
} else {
logger.error("Could not find callback for StreamResponse.");
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
index 00158fd081..538f3efe8d 100644
--- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -51,13 +51,14 @@ import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
public class StreamSuite {
- private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" };
+ private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" };
private static TransportServer server;
private static TransportClientFactory clientFactory;
private static File testFile;
private static File tempDir;
+ private static ByteBuffer emptyBuffer;
private static ByteBuffer smallBuffer;
private static ByteBuffer largeBuffer;
@@ -73,6 +74,7 @@ public class StreamSuite {
@BeforeClass
public static void setUp() throws Exception {
tempDir = Files.createTempDir();
+ emptyBuffer = createBuffer(0);
smallBuffer = createBuffer(100);
largeBuffer = createBuffer(100000);
@@ -103,6 +105,8 @@ public class StreamSuite {
return new NioManagedBuffer(largeBuffer);
case "smallBuffer":
return new NioManagedBuffer(smallBuffer);
+ case "emptyBuffer":
+ return new NioManagedBuffer(emptyBuffer);
case "file":
return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
default:
@@ -139,6 +143,18 @@ public class StreamSuite {
}
@Test
+ public void testZeroLengthStream() throws Throwable {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
+ task.run();
+ task.check();
+ } finally {
+ client.close();
+ }
+ }
+
+ @Test
public void testSingleStream() throws Throwable {
TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
try {
@@ -226,6 +242,11 @@ public class StreamSuite {
outFile = File.createTempFile("data", ".tmp", tempDir);
out = new FileOutputStream(outFile);
break;
+ case "emptyBuffer":
+ baos = new ByteArrayOutputStream();
+ out = baos;
+ srcBuffer = emptyBuffer;
+ break;
default:
throw new IllegalArgumentException(streamId);
}