From 9e01dcc6446f8648e61062f8afe62589b9d4b5ab Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Feb 2016 17:25:07 -0800 Subject: [SPARK-13529][BUILD] Move network/* modules into common/network-* ## What changes were proposed in this pull request? As the title says, this moves the three modules currently in network/ into common/network-*. This removes one top level, non-user-facing folder. ## How was this patch tested? Compilation and existing tests. We should run both SBT and Maven. Author: Reynold Xin Closes #11409 from rxin/SPARK-13529. --- .../spark/network/sasl/SaslIntegrationSuite.java | 294 +++++++++++++++++++ .../shuffle/BlockTransferMessagesSuite.java | 44 +++ .../shuffle/ExternalShuffleBlockHandlerSuite.java | 127 +++++++++ .../shuffle/ExternalShuffleBlockResolverSuite.java | 156 ++++++++++ .../shuffle/ExternalShuffleCleanupSuite.java | 149 ++++++++++ .../shuffle/ExternalShuffleIntegrationSuite.java | 301 ++++++++++++++++++++ .../shuffle/ExternalShuffleSecuritySuite.java | 124 ++++++++ .../shuffle/OneForOneBlockFetcherSuite.java | 176 ++++++++++++ .../network/shuffle/RetryingBlockFetcherSuite.java | 313 +++++++++++++++++++++ .../network/shuffle/TestShuffleDataContext.java | 117 ++++++++ 10 files changed, 1801 insertions(+) create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java (limited to 'common/network-shuffle/src/test') diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000..0ea631ea14 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.shuffle.BlockFetchingListener; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver; +import org.apache.spark.network.shuffle.OneForOneBlockFetcher; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private static final long TIMEOUT_MS = 10_000; + + static TransportServer server; + static TransportConf conf; + static TransportContext context; + static SecretKeyHolder secretKeyHolder; + + TransportClientFactory clientFactory; + + @BeforeClass + public static void beforeAll() throws IOException { + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + context = new TransportContext(conf, new TestRpcHandler()); + + secretKeyHolder = mock(SecretKeyHolder.class); + when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1"); + when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1"); + when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2"); + when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2"); + when(secretKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); + when(secretKeyHolder.getSecretKey(anyString())).thenReturn("correct-password"); + + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); + server = context.createServer(Arrays.asList(bootstrap)); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS); + assertEquals(msg, JavaUtils.bytesToString(resp)); + } + + @Test + public void testBadClient() { + SecretKeyHolder badKeyHolder = mock(SecretKeyHolder.class); + when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); + when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password"); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); + + try { + // Bootstrap should fail on startup. + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + fail("Connection should have failed."); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** + * This test is not actually testing SASL behavior, but testing that the shuffle service + * performs correct authorization checks based on the SASL authentication data. + */ + @Test + public void testAppIsolation() throws Exception { + // Start a new server with the correct RPC handler to serve block data. + ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class); + ExternalShuffleBlockHandler blockHandler = new ExternalShuffleBlockHandler( + new OneForOneStreamManager(), blockResolver); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); + TransportContext blockServerContext = new TransportContext(conf, blockHandler); + TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap)); + + TransportClient client1 = null; + TransportClient client2 = null; + TransportClientFactory clientFactory2 = null; + try { + // Create a client, and make a request to fetch blocks from a different app. + clientFactory = blockServerContext.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + client1 = clientFactory.createClient(TestUtils.getLocalHost(), + blockServer.getPort()); + + final AtomicReference exception = new AtomicReference<>(); + + BlockFetchingListener listener = new BlockFetchingListener() { + @Override + public synchronized void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + notifyAll(); + } + + @Override + public synchronized void onBlockFetchFailure(String blockId, Throwable t) { + exception.set(t); + notifyAll(); + } + }; + + String[] blockIds = new String[] { "shuffle_2_3_4", "shuffle_6_7_8" }; + OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", + blockIds, listener); + synchronized (listener) { + fetcher.start(); + listener.wait(); + } + checkSecurityException(exception.get()); + + // Register an executor so that the next steps work. + ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo( + new String[] { System.getProperty("java.io.tmpdir") }, 1, "sort"); + RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); + client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); + + // Make a successful request to fetch blocks, which creates a new stream. But do not actually + // fetch any blocks, to keep the stream open. + OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); + ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); + StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); + long streamId = stream.streamId; + + // Create a second client, authenticated with a different app ID, and try to read from + // the stream created for the previous app. + clientFactory2 = blockServerContext.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); + client2 = clientFactory2.createClient(TestUtils.getLocalHost(), + blockServer.getPort()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public synchronized void onSuccess(int chunkIndex, ManagedBuffer buffer) { + notifyAll(); + } + + @Override + public synchronized void onFailure(int chunkIndex, Throwable t) { + exception.set(t); + notifyAll(); + } + }; + + exception.set(null); + synchronized (callback) { + client2.fetchChunk(streamId, 0, callback); + callback.wait(); + } + checkSecurityException(exception.get()); + } finally { + if (client1 != null) { + client1.close(); + } + if (client2 != null) { + client2.close(); + } + if (clientFactory2 != null) { + clientFactory2.close(); + } + blockServer.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } + + private void checkSecurityException(Throwable t) { + assertNotNull("No exception was caught.", t); + assertTrue("Expected SecurityException.", + t.getMessage().contains(SecurityException.class.getName())); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java new file mode 100644 index 0000000000..86c8609e70 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.shuffle.protocol.*; + +/** Verifies that all BlockTransferMessages can be serialized correctly. */ +public class BlockTransferMessagesSuite { + @Test + public void serializeOpenShuffleBlocks() { + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); + } + + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); + assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java new file mode 100644 index 0000000000..9379412155 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; + +public class ExternalShuffleBlockHandlerSuite { + TransportClient client = mock(TransportClient.class); + + OneForOneStreamManager streamManager; + ExternalShuffleBlockResolver blockResolver; + RpcHandler handler; + + @Before + public void beforeEach() { + streamManager = mock(OneForOneStreamManager.class); + blockResolver = mock(ExternalShuffleBlockResolver.class); + handler = new ExternalShuffleBlockHandler(streamManager, blockResolver); + } + + @Test + public void testRegisterExecutor() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); + ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); + handler.receive(client, registerMessage, callback); + verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); + + verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void testOpenShuffleBlocks() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); + ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) + .toByteBuffer(); + handler.receive(client, openBlocks, callback); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); + + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure((Throwable) any()); + + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); + assertEquals(2, handle.numChunks); + + @SuppressWarnings("unchecked") + ArgumentCaptor> stream = (ArgumentCaptor>) + (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(anyString(), stream.capture()); + Iterator buffers = stream.getValue(); + assertEquals(block0Marker, buffers.next()); + assertEquals(block1Marker, buffers.next()); + assertFalse(buffers.hasNext()); + } + + @Test + public void testBadMessages() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 }); + try { + handler.receive(client, unserializableMsg, callback); + fail("Should have thrown"); + } catch (Exception e) { + // pass + } + + ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer(); + try { + handler.receive(client, unexpectedMsg, callback); + fail("Should have thrown"); + } catch (UnsupportedOperationException e) { + // pass + } + + verify(callback, never()).onSuccess(any(ByteBuffer.class)); + verify(callback, never()).onFailure(any(Throwable.class)); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java new file mode 100644 index 0000000000..60a1b8b045 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.CharStreams; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExternalShuffleBlockResolverSuite { + static String sortBlock0 = "Hello!"; + static String sortBlock1 = "World!"; + + static String hashBlock0 = "Elementary"; + static String hashBlock1 = "Tabular"; + + static TestShuffleDataContext dataContext; + + static TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + + @BeforeClass + public static void beforeAll() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort and hash data. + dataContext.insertSortShuffleData(0, 0, + new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); + dataContext.insertHashShuffleData(1, 0, + new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void testBadRequests() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + // Unregistered executor + try { + resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (RuntimeException e) { + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); + } + + // Invalid shuffle manager + resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); + try { + resolver.getBlockData("app0", "exec2", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (UnsupportedOperationException e) { + // pass + } + + // Nonexistent shuffle block + resolver.registerExecutor("app0", "exec3", + dataContext.createExecutorInfo("sort")); + try { + resolver.getBlockData("app0", "exec3", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (Exception e) { + // pass + } + } + + @Test + public void testSortShuffleBlocks() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + resolver.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("sort")); + + InputStream block0Stream = + resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(sortBlock1, block1); + } + + @Test + public void testHashShuffleBlocks() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + resolver.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("hash")); + + InputStream block0Stream = + resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(hashBlock0, block0); + + InputStream block1Stream = + resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(hashBlock1, block1); + } + + @Test + public void jsonSerializationOfExecutorRegistration() throws IOException { + ObjectMapper mapper = new ObjectMapper(); + AppExecId appId = new AppExecId("foo", "bar"); + String appIdJson = mapper.writeValueAsString(appId); + AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class); + assertEquals(parsedAppId, appId); + + ExecutorShuffleInfo shuffleInfo = + new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); + String shuffleJson = mapper.writeValueAsString(shuffleInfo); + ExecutorShuffleInfo parsedShuffleInfo = + mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); + assertEquals(parsedShuffleInfo, shuffleInfo); + + // Intentionally keep these hard-coded strings in here, to check backwards-compatability. + // its not legacy yet, but keeping this here in case anybody changes it + String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; + assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); + String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + + "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; + assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new file mode 100644 index 0000000000..532d7ab8d0 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + resolver.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + resolver.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + resolver.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. + Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + resolver.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + + resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + resolver.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + resolver.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + resolver.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java new file mode 100644 index 0000000000..5e706bf401 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleIntegrationSuite { + + static String APP_ID = "app-id"; + static String SORT_MANAGER = "sort"; + static String HASH_MANAGER = "hash"; + + // Executor 0 is sort-based + static TestShuffleDataContext dataContext0; + // Executor 1 is hash-based + static TestShuffleDataContext dataContext1; + + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + + static byte[][] exec0Blocks = new byte[][] { + new byte[123], + new byte[12345], + new byte[1234567], + }; + + static byte[][] exec1Blocks = new byte[][] { + new byte[321], + new byte[54321], + }; + + @BeforeClass + public static void beforeAll() throws IOException { + Random rand = new Random(); + + for (byte[] block : exec0Blocks) { + rand.nextBytes(block); + } + for (byte[] block: exec1Blocks) { + rand.nextBytes(block); + } + + dataContext0 = new TestShuffleDataContext(2, 5); + dataContext0.create(); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + + dataContext1 = new TestShuffleDataContext(6, 2); + dataContext1.create(); + dataContext1.insertHashShuffleData(1, 0, exec1Blocks); + + conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + handler = new ExternalShuffleBlockHandler(conf, null); + TransportContext transportContext = new TransportContext(conf, handler); + server = transportContext.createServer(); + } + + @AfterClass + public static void afterAll() { + dataContext0.cleanup(); + dataContext1.cleanup(); + server.close(); + } + + @After + public void afterEach() { + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); + } + + class FetchResult { + public Set successBlocks; + public Set failedBlocks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + // Fetch a set of blocks from a pre-registered executor. + private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { + return fetchBlocks(execId, blockIds, server.getPort()); + } + + // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, + // to allow connecting to invalid servers. + private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + final FetchResult res = new FetchResult(); + res.successBlocks = Collections.synchronizedSet(new HashSet()); + res.failedBlocks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + final Semaphore requestsRemaining = new Semaphore(0); + + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); + client.init(APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } + } + } + }); + + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void testFetchOneSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchHash() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); + assertTrue(execFetch.failedBlocks.isEmpty()); + assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); + execFetch.releaseBuffers(); + } + + @Test + public void testFetchWrongShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } + + @Test + public void testFetchInvalidShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager")); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongBlockId() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "rdd_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNonexistent() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_2_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); + // Both still fail, as we start by checking for all block. + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchUnregisteredExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-2", + new String[] { "shuffle_0_0_0", "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNoServer() throws Exception { + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } + } + + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); + client.init(APP_ID); + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), + executorId, executorInfo); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i)))); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000..08ddb3755b --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Arrays; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() throws IOException { + TransportContext context = + new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, + new TestSecretKeyHolder("my-app-id", "secret")); + this.server = context.createServer(Arrays.asList(bootstrap)); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() throws IOException { + validate("my-app-id", "secret", false); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret", false); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret", false); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testEncryption() throws IOException { + validate("my-app-id", "secret", true); + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey, boolean encrypt) throws IOException { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); + client.init(appId); + // Registration either succeeds or throws an exception. + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java new file mode 100644 index 0000000000..2590b9ce4c --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.collect.Maps; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; + +public class OneForOneBlockFetcherSuite { + @Test + public void testFetchOne() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); + } + + @Test + public void testFetchThree() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + for (int i = 0; i < 3; i ++) { + verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); + } + } + + @Test + public void testFailure() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", null); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // Each failure will cause a failure to be invoked in all remaining block fetches. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testFailureAndSuccess() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // We may call both success and failure for the same block. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); + verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testEmptyBlockFetch() { + try { + fetchBlocks(Maps.newLinkedHashMap()); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Zero-sized blockIds array", e.getMessage()); + } + } + + /** + * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which + * simply returns the given (BlockId, Block) pairs. + * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order + * that they were inserted in. + * + * If a block's buffer is "null", an exception will be thrown instead. + */ + private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { + TransportClient client = mock(TransportClient.class); + BlockFetchingListener listener = mock(BlockFetchingListener.class); + final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + OneForOneBlockFetcher fetcher = + new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); + + // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); + RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); + return null; + } + }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); + + // Respond to each chunk request with a single buffer from our blocks array. + final AtomicInteger expectedChunkIndex = new AtomicInteger(0); + final Iterator blockIterator = blocks.values().iterator(); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + try { + long streamId = (Long) invocation.getArguments()[0]; + int myChunkIndex = (Integer) invocation.getArguments()[1]; + assertEquals(123, streamId); + assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); + + ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; + ManagedBuffer result = blockIterator.next(); + if (result != null) { + callback.onSuccess(myChunkIndex, result); + } else { + callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); + } + } catch (Exception e) { + e.printStackTrace(); + fail("Unexpected failure"); + } + return null; + } + }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); + + fetcher.start(); + return listener; + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java new file mode 100644 index 0000000000..3a6ef0d3f8 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWait", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWait"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // Immediately return both blocks successfully. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // b0 throws a non-IOException error, so it will be failed without retry. + ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // IOException will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. + ImmutableMap.builder() + .put("b1", block1) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + List> interactions = Arrays.asList( + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build() + ); + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private static void performInteractions(List> interactions, + BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. + BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. + if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java new file mode 100644 index 0000000000..7ac1ca128a --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import com.google.common.io.Closeables; +import com.google.common.io.Files; + +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; + +/** + * Manages some sort- and hash-based shuffle data, including the creation + * and cleanup of directories that can be read by the {@link ExternalShuffleBlockResolver}. + */ +public class TestShuffleDataContext { + public final String[] localDirs; + public final int subDirsPerLocalDir; + + public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { + this.localDirs = new String[numLocalDirs]; + this.subDirsPerLocalDir = subDirsPerLocalDir; + } + + public void create() { + for (int i = 0; i < localDirs.length; i ++) { + localDirs[i] = Files.createTempDir().getAbsolutePath(); + + for (int p = 0; p < subDirsPerLocalDir; p ++) { + new File(localDirs[i], String.format("%02x", p)).mkdirs(); + } + } + } + + public void cleanup() { + for (String localDir : localDirs) { + deleteRecursively(new File(localDir)); + } + } + + /** Creates reducer blocks in a sort-based data format within our local dirs. */ + public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + + OutputStream dataStream = null; + DataOutputStream indexStream = null; + boolean suppressExceptionsDuringClose = true; + + try { + dataStream = new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; + indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + suppressExceptionsDuringClose = false; + } finally { + Closeables.close(dataStream, suppressExceptionsDuringClose); + Closeables.close(indexStream, suppressExceptionsDuringClose); + } + } + + /** Creates reducer blocks in a hash-based data format within our local dirs. */ + public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + for (int i = 0; i < blocks.length; i ++) { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; + Files.write(blocks[i], + ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId)); + } + } + + /** + * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this + * context's directories. + */ + public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) { + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } + + private static void deleteRecursively(File f) { + assert f != null; + if (f.isDirectory()) { + File[] children = f.listFiles(); + if (children != null) { + for (File child : children) { + deleteRecursively(child); + } + } + } + f.delete(); + } +} -- cgit v1.2.3