diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-02-02 00:25:54 -0800 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2011-02-02 00:25:54 -0800 |
commit | ec28b607fd58489024ea7a6e801a97507036c1b2 (patch) | |
tree | a891bbc8b806576ad7893279d38483a183a8691d | |
parent | 7f74ee99f68911bcd471f8b428e5c2055d1f95b0 (diff) | |
parent | 817e7223213b8d4ef8c783acca1345ef03e97f22 (diff) | |
download | spark-ec28b607fd58489024ea7a6e801a97507036c1b2.tar.gz spark-ec28b607fd58489024ea7a6e801a97507036c1b2.tar.bz2 spark-ec28b607fd58489024ea7a6e801a97507036c1b2.zip |
Merge branch 'master' into sbt
Conflicts:
Makefile
core/src/main/java/spark/compress/lzf/LZF.java
core/src/main/java/spark/compress/lzf/LZFInputStream.java
core/src/main/java/spark/compress/lzf/LZFOutputStream.java
core/src/main/native/spark_compress_lzf_LZF.c
run
-rw-r--r-- | LICENSE | 27 | ||||
-rw-r--r-- | core/lib/compress-lzf-0.6.0/LICENSE | 11 | ||||
-rw-r--r-- | core/lib/compress-lzf-0.6.0/compress-lzf-0.6.0.jar | bin | 0 -> 14497 bytes | |||
-rw-r--r-- | core/src/main/java/spark/compress/lzf/LZF.java | 27 | ||||
-rw-r--r-- | core/src/main/java/spark/compress/lzf/LZFInputStream.java | 180 | ||||
-rw-r--r-- | core/src/main/java/spark/compress/lzf/LZFOutputStream.java | 85 | ||||
-rw-r--r-- | core/src/main/native/.gitignore | 3 | ||||
-rw-r--r-- | core/src/main/native/Makefile | 40 | ||||
-rw-r--r-- | core/src/main/native/spark_compress_lzf_LZF.c | 90 | ||||
-rw-r--r-- | core/src/main/scala/spark/BitTorrentBroadcast.scala | 1233 | ||||
-rw-r--r-- | core/src/main/scala/spark/Broadcast.scala | 839 | ||||
-rw-r--r-- | core/src/main/scala/spark/ChainedBroadcast.scala | 870 | ||||
-rw-r--r-- | core/src/main/scala/spark/DfsBroadcast.scala | 132 | ||||
-rw-r--r-- | core/src/main/scala/spark/MesosScheduler.scala | 2 | ||||
-rw-r--r-- | core/src/main/scala/spark/SparkContext.scala | 6 | ||||
-rw-r--r-- | project/build/SparkProject.scala | 20 | ||||
-rwxr-xr-x | run | 1 |
17 files changed, 2371 insertions, 1195 deletions
diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..d17afa1fc6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2010, Regents of the University of California. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/core/lib/compress-lzf-0.6.0/LICENSE b/core/lib/compress-lzf-0.6.0/LICENSE new file mode 100644 index 0000000000..c5da4e1348 --- /dev/null +++ b/core/lib/compress-lzf-0.6.0/LICENSE @@ -0,0 +1,11 @@ +Copyright 2009-2010 Ning, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not +use this file except in compliance with the License. You may obtain a copy of +the License at http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS,WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations under +the License. diff --git a/core/lib/compress-lzf-0.6.0/compress-lzf-0.6.0.jar b/core/lib/compress-lzf-0.6.0/compress-lzf-0.6.0.jar Binary files differnew file mode 100644 index 0000000000..6cb5c4c92b --- /dev/null +++ b/core/lib/compress-lzf-0.6.0/compress-lzf-0.6.0.jar diff --git a/core/src/main/java/spark/compress/lzf/LZF.java b/core/src/main/java/spark/compress/lzf/LZF.java deleted file mode 100644 index 294a0494ec..0000000000 --- a/core/src/main/java/spark/compress/lzf/LZF.java +++ /dev/null @@ -1,27 +0,0 @@ -package spark.compress.lzf; - -public class LZF { - private static boolean loaded; - - static { - try { - System.loadLibrary("spark_native"); - loaded = true; - } catch(Throwable t) { - System.out.println("Failed to load native LZF library: " + t.toString()); - loaded = false; - } - } - - public static boolean isLoaded() { - return loaded; - } - - public static native int compress( - byte[] in, int inOff, int inLen, - byte[] out, int outOff, int outLen); - - public static native int decompress( - byte[] in, int inOff, int inLen, - byte[] out, int outOff, int outLen); -} diff --git a/core/src/main/java/spark/compress/lzf/LZFInputStream.java b/core/src/main/java/spark/compress/lzf/LZFInputStream.java deleted file mode 100644 index 16bc687489..0000000000 --- a/core/src/main/java/spark/compress/lzf/LZFInputStream.java +++ /dev/null @@ -1,180 +0,0 @@ -package spark.compress.lzf; - -import java.io.EOFException; -import java.io.FilterInputStream; -import java.io.IOException; -import java.io.InputStream; - -public class LZFInputStream extends FilterInputStream { - private static final int MAX_BLOCKSIZE = 1024 * 64 - 1; - private static final int MAX_HDR_SIZE = 7; - - private byte[] inBuf; // Holds data to decompress (including header) - private byte[] outBuf; // Holds decompressed data to output - private int outPos; // Current position in outBuf - private int outSize; // Total amount of data in outBuf - - private boolean closed; - private boolean reachedEof; - - private byte[] singleByte = new byte[1]; - - public LZFInputStream(InputStream in) { - super(in); - if (in == null) - throw new NullPointerException(); - inBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE]; - outBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE]; - outPos = 0; - outSize = 0; - } - - private void ensureOpen() throws IOException { - if (closed) throw new IOException("Stream closed"); - } - - @Override - public int read() throws IOException { - ensureOpen(); - int count = read(singleByte, 0, 1); - return (count == -1 ? -1 : singleByte[0] & 0xFF); - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - ensureOpen(); - if ((off | len | (off + len) | (b.length - (off + len))) < 0) - throw new IndexOutOfBoundsException(); - - int totalRead = 0; - - // Start with the current block in outBuf, and read and decompress any - // further blocks necessary. Instead of trying to decompress directly to b - // when b is large, we always use outBuf as an intermediate holding space - // in case GetPrimitiveArrayCritical decides to copy arrays instead of - // pinning them, which would cause b to be copied repeatedly into C-land. - while (len > 0) { - if (outPos == outSize) { - readNextBlock(); - if (reachedEof) - return totalRead == 0 ? -1 : totalRead; - } - int amtToCopy = Math.min(outSize - outPos, len); - System.arraycopy(outBuf, outPos, b, off, amtToCopy); - off += amtToCopy; - len -= amtToCopy; - outPos += amtToCopy; - totalRead += amtToCopy; - } - - return totalRead; - } - - // Read len bytes from this.in to a buffer, stopping only if EOF is reached - private int readFully(byte[] b, int off, int len) throws IOException { - int totalRead = 0; - while (len > 0) { - int amt = in.read(b, off, len); - if (amt == -1) - break; - off += amt; - len -= amt; - totalRead += amt; - } - return totalRead; - } - - // Read the next block from the underlying InputStream into outBuf, - // setting outPos and outSize, or set reachedEof if the stream ends. - private void readNextBlock() throws IOException { - // Read first 5 bytes of header - int count = readFully(inBuf, 0, 5); - if (count == 0) { - reachedEof = true; - return; - } else if (count < 5) { - throw new EOFException("Truncated LZF block header"); - } - - // Check magic bytes - if (inBuf[0] != 'Z' || inBuf[1] != 'V') - throw new IOException("Wrong magic bytes in LZF block header"); - - // Read the block - if (inBuf[2] == 0) { - // Uncompressed block - read directly to outBuf - int size = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF); - if (readFully(outBuf, 0, size) != size) - throw new EOFException("EOF inside LZF block"); - outPos = 0; - outSize = size; - } else if (inBuf[2] == 1) { - // Compressed block - read to inBuf and decompress - if (readFully(inBuf, 5, 2) != 2) - throw new EOFException("Truncated LZF block header"); - int csize = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF); - int usize = ((inBuf[5] & 0xFF) << 8) | (inBuf[6] & 0xFF); - if (readFully(inBuf, 7, csize) != csize) - throw new EOFException("Truncated LZF block"); - if (LZF.decompress(inBuf, 7, csize, outBuf, 0, usize) != usize) - throw new IOException("Corrupt LZF data stream"); - outPos = 0; - outSize = usize; - } else { - throw new IOException("Unknown block type in LZF block header"); - } - } - - /** - * Returns 0 after EOF has been reached, otherwise always return 1. - * - * Programs should not count on this method to return the actual number - * of bytes that could be read without blocking. - */ - @Override - public int available() throws IOException { - ensureOpen(); - return reachedEof ? 0 : 1; - } - - // TODO: Skip complete chunks without decompressing them? - @Override - public long skip(long n) throws IOException { - ensureOpen(); - if (n < 0) - throw new IllegalArgumentException("negative skip length"); - byte[] buf = new byte[512]; - long skipped = 0; - while (skipped < n) { - int len = (int) Math.min(n - skipped, buf.length); - len = read(buf, 0, len); - if (len == -1) { - reachedEof = true; - break; - } - skipped += len; - } - return skipped; - } - - @Override - public void close() throws IOException { - if (!closed) { - in.close(); - closed = true; - } - } - - @Override - public boolean markSupported() { - return false; - } - - @Override - public void mark(int readLimit) {} - - @Override - public void reset() throws IOException { - throw new IOException("mark/reset not supported"); - } -} diff --git a/core/src/main/java/spark/compress/lzf/LZFOutputStream.java b/core/src/main/java/spark/compress/lzf/LZFOutputStream.java deleted file mode 100644 index 5f65e95d2a..0000000000 --- a/core/src/main/java/spark/compress/lzf/LZFOutputStream.java +++ /dev/null @@ -1,85 +0,0 @@ -package spark.compress.lzf; - -import java.io.FilterOutputStream; -import java.io.IOException; -import java.io.OutputStream; - -public class LZFOutputStream extends FilterOutputStream { - private static final int BLOCKSIZE = 1024 * 64 - 1; - private static final int MAX_HDR_SIZE = 7; - - private byte[] inBuf; // Holds input data to be compressed - private byte[] outBuf; // Holds compressed data to be written - private int inPos; // Current position in inBuf - - public LZFOutputStream(OutputStream out) { - super(out); - inBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE]; - outBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE]; - inPos = MAX_HDR_SIZE; - } - - @Override - public void write(int b) throws IOException { - inBuf[inPos++] = (byte) b; - if (inPos == inBuf.length) - compressAndSendBlock(); - } - - @Override - public void write(byte[] b, int off, int len) throws IOException { - if ((off | len | (off + len) | (b.length - (off + len))) < 0) - throw new IndexOutOfBoundsException(); - - // If we're given a large array, copy it piece by piece into inBuf and - // write one BLOCKSIZE at a time. This is done to prevent the JNI code - // from copying the whole array repeatedly if GetPrimitiveArrayCritical - // decides to copy instead of pinning. - while (inPos + len >= inBuf.length) { - int amtToCopy = inBuf.length - inPos; - System.arraycopy(b, off, inBuf, inPos, amtToCopy); - inPos += amtToCopy; - compressAndSendBlock(); - off += amtToCopy; - len -= amtToCopy; - } - - // Copy the remaining (incomplete) block into inBuf - System.arraycopy(b, off, inBuf, inPos, len); - inPos += len; - } - - @Override - public void flush() throws IOException { - if (inPos > MAX_HDR_SIZE) - compressAndSendBlock(); - out.flush(); - } - - // Send the data in inBuf, and reset inPos to start writing a new block. - private void compressAndSendBlock() throws IOException { - int us = inPos - MAX_HDR_SIZE; - int maxcs = us > 4 ? us - 4 : us; - int cs = LZF.compress(inBuf, MAX_HDR_SIZE, us, outBuf, MAX_HDR_SIZE, maxcs); - if (cs != 0) { - // Compression made the data smaller; use type 1 header - outBuf[0] = 'Z'; - outBuf[1] = 'V'; - outBuf[2] = 1; - outBuf[3] = (byte) (cs >> 8); - outBuf[4] = (byte) (cs & 0xFF); - outBuf[5] = (byte) (us >> 8); - outBuf[6] = (byte) (us & 0xFF); - out.write(outBuf, 0, 7 + cs); - } else { - // Compression didn't help; use type 0 header and uncompressed data - inBuf[2] = 'Z'; - inBuf[3] = 'V'; - inBuf[4] = 0; - inBuf[5] = (byte) (us >> 8); - inBuf[6] = (byte) (us & 0xFF); - out.write(inBuf, 2, 5 + us); - } - inPos = MAX_HDR_SIZE; - } -} diff --git a/core/src/main/native/.gitignore b/core/src/main/native/.gitignore deleted file mode 100644 index b21d5dd963..0000000000 --- a/core/src/main/native/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -/libspark_native.dylib -/libspark_native.so -/spark_compress_lzf_LZF.h diff --git a/core/src/main/native/Makefile b/core/src/main/native/Makefile deleted file mode 100644 index 8975bb8593..0000000000 --- a/core/src/main/native/Makefile +++ /dev/null @@ -1,40 +0,0 @@ -CC = gcc -#JAVA_HOME = /usr/lib/jvm/java-6-sun -OS_NAME = linux - -CFLAGS += -fPIC -O3 -funroll-all-loops - -SPARK = ../../.. - -LZF = $(SPARK)/lib/liblzf-3.5 - -TARGET = $(SPARK)/target/scala_2.8.1 - -LIB = $(TARGET)/native/libspark_native.so - -OS_X_LIB = $(TARGET)/native/libspark_native.dylib - -all: $(LIB) - -spark_compress_lzf_LZF.h: $(TARGET)/classes/spark/compress/lzf/LZF.class -ifeq ($(JAVA_HOME),) - $(error JAVA_HOME is not set) -else - $(JAVA_HOME)/bin/javah -classpath $(SPARK)/target/scala_2.8.1/classes spark.compress.lzf.LZF -endif - -$(TARGET)/native: - mkdir -p $@ - -$(LIB): spark_compress_lzf_LZF.h spark_compress_lzf_LZF.c | $(TARGET)/native - $(CC) $(CFLAGS) -shared -o $@ spark_compress_lzf_LZF.c \ - -I $(JAVA_HOME)/include -I $(JAVA_HOME)/include/$(OS_NAME) \ - -I $(LZF) $(LZF)/lzf_c.c $(LZF)/lzf_d.c - -$(OS_X_LIB): $(LIB) - cp $< $@ - -clean: - rm -f spark_compress_lzf_LZF.h $(LIB) - -.PHONY: all clean diff --git a/core/src/main/native/spark_compress_lzf_LZF.c b/core/src/main/native/spark_compress_lzf_LZF.c deleted file mode 100644 index c2a59def3e..0000000000 --- a/core/src/main/native/spark_compress_lzf_LZF.c +++ /dev/null @@ -1,90 +0,0 @@ -#include "spark_compress_lzf_LZF.h" -#include <lzf.h> - - -/* Helper function to throw an exception */ -static void throwException(JNIEnv *env, const char* className) { - jclass cls = (*env)->FindClass(env, className); - if (cls != 0) /* If cls is null, an exception was already thrown */ - (*env)->ThrowNew(env, cls, ""); -} - - -/* - * Since LZF.compress() and LZF.decompress() have the same signatures - * and differ only in which lzf_ function they call, implement both in a - * single function and pass it a pointer to the correct lzf_ function. - */ -static jint callCompressionFunction - (unsigned int (*func)(const void *const, unsigned int, void *, unsigned int), - JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, - jbyteArray outArray, jint outOff, jint outLen) -{ - jint inCap; - jint outCap; - jbyte *inData = 0; - jbyte *outData = 0; - jint ret; - jint s; - - if (!inArray || !outArray) { - throwException(env, "java/lang/NullPointerException"); - goto cleanup; - } - - inCap = (*env)->GetArrayLength(env, inArray); - outCap = (*env)->GetArrayLength(env, outArray); - - // Check if any of the offset/length pairs is invalid; we do this by OR'ing - // things we don't want to be negative and seeing if the result is negative - s = inOff | inLen | (inOff + inLen) | (inCap - (inOff + inLen)) | - outOff | outLen | (outOff + outLen) | (outCap - (outOff + outLen)); - if (s < 0) { - throwException(env, "java/lang/IndexOutOfBoundsException"); - goto cleanup; - } - - inData = (*env)->GetPrimitiveArrayCritical(env, inArray, 0); - outData = (*env)->GetPrimitiveArrayCritical(env, outArray, 0); - - if (!inData || !outData) { - // Out of memory - JVM will throw OutOfMemoryError - goto cleanup; - } - - ret = func(inData + inOff, inLen, outData + outOff, outLen); - -cleanup: - if (inData) - (*env)->ReleasePrimitiveArrayCritical(env, inArray, inData, 0); - if (outData) - (*env)->ReleasePrimitiveArrayCritical(env, outArray, outData, 0); - - return ret; -} - -/* - * Class: spark_compress_lzf_LZF - * Method: compress - * Signature: ([B[B)I - */ -JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_compress - (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, - jbyteArray outArray, jint outOff, jint outLen) -{ - return callCompressionFunction(lzf_compress, env, cls, - inArray, inOff, inLen, outArray,outOff, outLen); -} - -/* - * Class: spark_compress_lzf_LZF - * Method: decompress - * Signature: ([B[B)I - */ -JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_decompress - (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, - jbyteArray outArray, jint outOff, jint outLen) -{ - return callCompressionFunction(lzf_decompress, env, cls, - inArray, inOff, inLen, outArray,outOff, outLen); -} diff --git a/core/src/main/scala/spark/BitTorrentBroadcast.scala b/core/src/main/scala/spark/BitTorrentBroadcast.scala new file mode 100644 index 0000000000..96d3643ffd --- /dev/null +++ b/core/src/main/scala/spark/BitTorrentBroadcast.scala @@ -0,0 +1,1233 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID} +import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} + +import scala.collection.mutable.{ListBuffer, Map, Set} + +@serializable +class BitTorrentBroadcast[T] (@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { + + def value = value_ + + BitTorrentBroadcast.synchronized { + BitTorrentBroadcast.values.put (uuid, value_) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var hasBlocksBitVector: BitSet = null + @transient var numCopiesSent: Array[Int] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + + @transient var listOfSources = ListBuffer[SourceInfo] () + + @transient var serveMR: ServeMultipleRequests = null + + // Used only in Master + @transient var guideMR: GuideMultipleRequests = null + + // Used only in Workers + @transient var ttGuide: TalkToGuide = null + + @transient var rxSpeeds = new SpeedTracker + @transient var txSpeeds = new SpeedTracker + + @transient var hostAddress = InetAddress.getLocalHost.getHostAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast + } + + def sendBroadcast (): Unit = { + logInfo ("Local host address: " + hostAddress) + + // Store a persistent copy in HDFS + // TODO: Turned OFF for now + // val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) + // out.writeObject (value_) + // out.close + // TODO: Fix this at some point + hasCopyInHDFS = true + + // Create a variableInfo object and store it in valueInfos + var variableInfo = blockifyObject (value_, BitTorrentBroadcast.BlockSize) + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + // Guide has all the blocks + hasBlocksBitVector = new BitSet (totalBlocks) + hasBlocksBitVector.set (0, totalBlocks) + + // Guide still hasn't sent any block + numCopiesSent = new Array[Int] (totalBlocks) + + guideMR = new GuideMultipleRequests + guideMR.setDaemon (true) + guideMR.start + logInfo ("GuideMultipleRequests started...") + + // Must always come AFTER guideMR is created + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + logInfo ("ServeMultipleRequests started...") + + // Must always come AFTER serveMR is created + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Must always come AFTER listenPort is created + val masterSource = + SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes) + hasBlocksBitVector.synchronized { + masterSource.hasBlocksBitVector = hasBlocksBitVector + } + + // In the beginning, this is the only known source to Guide + listOfSources = listOfSources + masterSource + + // Register with the Tracker + BitTorrentBroadcast.registerValue (uuid, + SourceInfo (hostAddress, guidePort, totalBlocks, totalBytes)) + } + + private def readObject (in: ObjectInputStream): Unit = { + in.defaultReadObject + BitTorrentBroadcast.synchronized { + val cachedVal = BitTorrentBroadcast.values.get (uuid) + + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Only the first worker in a node can ever be inside this 'else' + initializeWorkerVariables + + logInfo ("Local host address: " + hostAddress) + + // Start local ServeMultipleRequests thread first + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + logInfo ("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast (uuid) + // If does not succeed, then get from HDFS copy + if (receptionSucceeded) { + value_ = unBlockifyObject[T] + BitTorrentBroadcast.values.put (uuid, value_) + } else { + // TODO: This part won't work, cause HDFS writing is turned OFF + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + BitTorrentBroadcast.values.put(uuid, value_) + fileIn.close + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + // Initialize variables in the worker node. Master sends everything as 0/null + private def initializeWorkerVariables: Unit = { + arrayOfBlocks = null + hasBlocksBitVector = null + numCopiesSent = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + + listenPortLock = new Object + totalBlocksLock = new Object + + serveMR = null + ttGuide = null + + rxSpeeds = new SpeedTracker + txSpeeds = new SpeedTracker + + hostAddress = InetAddress.getLocalHost.getHostAddress + listenPort = -1 + + listOfSources = ListBuffer[SourceInfo] () + + stopBroadcast = false + } + + private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream (baos) + oos.writeObject (obj) + oos.close + baos.close + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream (byteArray) + + var blockNum = (byteArray.length / blockSize) + if (byteArray.length % blockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock] (blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, blockSize)) { + val thisBlockSize = Math.min (blockSize, byteArray.length - i) + var tempByteArray = new Array[Byte] (thisBlockSize) + val hasRead = bais.read (tempByteArray, 0, thisBlockSize) + + retVal (blockID) = new BroadcastBlock (blockID, tempByteArray) + blockID += 1 + } + bais.close + + var variableInfo = VariableInfo (retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + private def unBlockifyObject[A]: A = { + var retByteArray = new Array[Byte] (totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BitTorrentBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject (retByteArray) + } + + private def byteArrayToObject[A] (bytes: Array[Byte]): A = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) + val retVal = in.readObject.asInstanceOf[A] + in.close + return retVal + } + + private def getLocalSourceInfo: SourceInfo = { + // Wait till hostName and listenPort are OK + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Wait till totalBlocks and totalBytes are OK + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + var localSourceInfo = SourceInfo (hostAddress, listenPort, totalBlocks, + totalBytes) + + localSourceInfo.hasBlocks = hasBlocks + + hasBlocksBitVector.synchronized { + localSourceInfo.hasBlocksBitVector = hasBlocksBitVector + } + + return localSourceInfo + } + + // Add new SourceInfo to the listOfSources. Update if it exists already. + // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance + private def addToListOfSources (newSourceInfo: SourceInfo): Unit = { + listOfSources.synchronized { + if (listOfSources.contains(newSourceInfo)) { + listOfSources = listOfSources - newSourceInfo + } + listOfSources = listOfSources + newSourceInfo + } + } + + private def addToListOfSources (newSourceInfos: ListBuffer[SourceInfo]): Unit = { + newSourceInfos.foreach { newSourceInfo => + addToListOfSources (newSourceInfo) + } + } + + class TalkToGuide (gInfo: SourceInfo) + extends Thread with Logging { + override def run: Unit = { + + // Keep exchaning information until all blocks have been received + while (hasBlocks < totalBlocks) { + talkOnce + Thread.sleep (BitTorrentBroadcast.ranGen.nextInt ( + BitTorrentBroadcast.MaxKnockInterval - BitTorrentBroadcast.MinKnockInterval) + + BitTorrentBroadcast.MinKnockInterval) + } + + // Talk one more time to let the Guide know of reception completion + talkOnce + } + + // Connect to Guide and send this worker's information + private def talkOnce: Unit = { + var clientSocketToGuide: Socket = null + var oosGuide: ObjectOutputStream = null + var oisGuide: ObjectInputStream = null + + clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) + oosGuide = new ObjectOutputStream (clientSocketToGuide.getOutputStream) + oosGuide.flush + oisGuide = new ObjectInputStream (clientSocketToGuide.getInputStream) + + // Send local information + oosGuide.writeObject(getLocalSourceInfo) + oosGuide.flush + + // Receive source information from Guide + var suitableSources = + oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] + logInfo("Received suitableSources from Master " + suitableSources) + + addToListOfSources (suitableSources) + + oisGuide.close + oosGuide.close + clientSocketToGuide.close + } + } + + def getGuideInfo (variableUUID: UUID): SourceInfo = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var gInfo: SourceInfo = SourceInfo ("", SourceInfo.TxOverGoToHDFS, + SourceInfo.UnusedParam, SourceInfo.UnusedParam) + + var retriesLeft = BitTorrentBroadcast.MaxRetryCount + do { + try { + // Connect to the tracker to find out GuideInfo + val clientSocketToTracker = + new Socket(BitTorrentBroadcast.MasterHostAddress, BitTorrentBroadcast.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream (clientSocketToTracker.getOutputStream) + oosTracker.flush + val oisTracker = + new ObjectInputStream (clientSocketToTracker.getInputStream) + + // Send UUID and receive GuideInfo + oosTracker.writeObject (uuid) + oosTracker.flush + gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] + } catch { + case e: Exception => { + logInfo ("getGuideInfo had a " + e) + } + } finally { + if (oisTracker != null) { + oisTracker.close + } + if (oosTracker != null) { + oosTracker.close + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close + } + } + + Thread.sleep (BitTorrentBroadcast.ranGen.nextInt ( + BitTorrentBroadcast.MaxKnockInterval - BitTorrentBroadcast.MinKnockInterval) + + BitTorrentBroadcast.MinKnockInterval) + + retriesLeft -= 1 + } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) + + logInfo ("Got this guidePort from Tracker: " + gInfo.listenPort) + return gInfo + } + + def receiveBroadcast (variableUUID: UUID): Boolean = { + val gInfo = getGuideInfo (variableUUID) + + if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS || + gInfo.listenPort == SourceInfo.TxNotStartedRetry) { + // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + // Setup initial states of variables + totalBlocks = gInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) + hasBlocksBitVector = new BitSet (totalBlocks) + numCopiesSent = new Array[Int] (totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = gInfo.totalBytes + + // Start ttGuide to periodically talk to the Guide + var ttGuide = new TalkToGuide (gInfo) + ttGuide.setDaemon (true) + ttGuide.start + logInfo ("TalkToGuide started...") + + // Start pController to run TalkToPeer threads + var pcController = new PeerChatterController + pcController.setDaemon (true) + pcController.start + logInfo ("PeerChatterController started...") + + // TODO: Must fix this. This might never break if broadcast fails. + // We should be able to break and send false. Also need to kill threads + while (hasBlocks < totalBlocks) { + Thread.sleep(BitTorrentBroadcast.MaxKnockInterval) + } + + return true + } + + class PeerChatterController + extends Thread with Logging { + private var peersNowTalking = ListBuffer[SourceInfo] () + // TODO: There is a possible bug with blocksInRequestBitVector when a + // certain bit is NOT unset upon failure resulting in an infinite loop. + private var blocksInRequestBitVector = new BitSet (totalBlocks) + + override def run: Unit = { + var threadPool = + Broadcast.newDaemonFixedThreadPool (BitTorrentBroadcast.MaxTxPeers) + + while (hasBlocks < totalBlocks) { + var numThreadsToCreate = + Math.min (listOfSources.size, BitTorrentBroadcast.MaxTxPeers) - + threadPool.getActiveCount + + while (hasBlocks < totalBlocks && numThreadsToCreate > 0) { + var peerToTalkTo = pickPeerToTalkTo + if (peerToTalkTo != null) { + threadPool.execute (new TalkToPeer (peerToTalkTo)) + + // Add to peersNowTalking. Remove in the thread. We have to do this + // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once + peersNowTalking.synchronized { + peersNowTalking = peersNowTalking + peerToTalkTo + } + } + + numThreadsToCreate = numThreadsToCreate - 1 + } + + // Sleep for a while before starting some more threads + Thread.sleep (BitTorrentBroadcast.MinKnockInterval) + } + // Shutdown the thread pool + threadPool.shutdown + } + + // Right now picking the one that has the most blocks this peer wants + // Also picking peer randomly if no one has anything interesting + private def pickPeerToTalkTo: SourceInfo = { + var curPeer: SourceInfo = null + var curMax = 0 + + logInfo ("Picking peers to talk to...") + + // Find peers that are not connected right now + var peersNotInUse = ListBuffer[SourceInfo] () + synchronized { + peersNotInUse = listOfSources -- peersNowTalking + } + + peersNotInUse.foreach { eachSource => + var tempHasBlocksBitVector: BitSet = null + hasBlocksBitVector.synchronized { + tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + tempHasBlocksBitVector.flip (0, tempHasBlocksBitVector.size) + tempHasBlocksBitVector.and (eachSource.hasBlocksBitVector) + + if (tempHasBlocksBitVector.cardinality > curMax) { + curPeer = eachSource + curMax = tempHasBlocksBitVector.cardinality + } + } + + // Always pick randomly or randomly pick randomly? + // Now always picking randomly + if (curPeer == null && peersNotInUse.size > 0) { + // Pick uniformly the i'th required peer + var i = BitTorrentBroadcast.ranGen.nextInt (peersNotInUse.size) + + var peerIter = peersNotInUse.iterator + curPeer = peerIter.next + + while (i > 0) { + curPeer = peerIter.next + i = i - 1 + } + } + + if (curPeer != null) + logInfo ("Peer chosen: " + curPeer + " with " + curPeer.hasBlocksBitVector) + else + logInfo ("No peer chosen...") + + return curPeer + } + + class TalkToPeer (peerToTalkTo: SourceInfo) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + override def run: Unit = { + // TODO: There is a possible bug here regarding blocksInRequestBitVector + var blockToAskFor = -1 + + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections + } + } + + var timeOutTimer = new Timer + timeOutTimer.schedule (timeOutTask, BitTorrentBroadcast.MaxKnockInterval) + + logInfo ("TalkToPeer started... => " + peerToTalkTo) + + try { + // Connect to the source + peerSocketToSource = + new Socket (peerToTalkTo.hostAddress, peerToTalkTo.listenPort) + oosSource = + new ObjectOutputStream (peerSocketToSource.getOutputStream) + oosSource.flush + oisSource = + new ObjectInputStream (peerSocketToSource.getInputStream) + + // Receive latest SourceInfo from peerToTalkTo + var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] + // Update listOfSources + addToListOfSources (newPeerToTalkTo) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel + + // Send the latest SourceInfo + oosSource.writeObject(getLocalSourceInfo) + oosSource.flush + + var keepReceiving = true + + while (hasBlocks < totalBlocks && keepReceiving) { + blockToAskFor = + pickBlockToRequest (newPeerToTalkTo.hasBlocksBitVector) + + // No block to request + if (blockToAskFor < 0) { + // Nothing to receive from newPeerToTalkTo + keepReceiving = false + } else { + // Let other thread know that blockToAskFor is being requested + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set (blockToAskFor) + } + + // Start with sending the blockID + oosSource.writeObject(blockToAskFor) + oosSource.flush + + // Receive the requested block + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + // Expecting sender to send the block that was asked for + assert (bcBlock.blockID == blockToAskFor) + + logInfo ("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") + + if (!hasBlocksBitVector.get(bcBlock.blockID)) { + arrayOfBlocks(bcBlock.blockID) = bcBlock + + // Update the hasBlocksBitVector first + hasBlocksBitVector.synchronized { + hasBlocksBitVector.set (bcBlock.blockID) + } + hasBlocks += 1 + + rxSpeeds.addDataPoint (peerToTalkTo, receptionTime) + + // blockToAskFor has arrived. Not in request any more + // Probably no need to update it though + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set (bcBlock.blockID, false) + } + + // Reset blockToAskFor to -1. Else it will be considered missing + blockToAskFor = -1 + } + + // Send the latest SourceInfo + oosSource.writeObject(getLocalSourceInfo) + oosSource.flush + } + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo ("TalktoPeer had a " + e) + // TODO: Remove 'newPeerToTalkTo' from listOfSources + // We probably should have the following in some form, but not + // really here. This exception can happen if the sender just breaks connection + // listOfSources.synchronized { + // logInfo ("Exception in TalkToPeer. Removing source: " + peerToTalkTo) + // listOfSources = listOfSources - peerToTalkTo + // } + } + } finally { + // blockToAskFor != -1 => there was an exception + if (blockToAskFor != -1) { + blocksInRequestBitVector.synchronized { + blocksInRequestBitVector.set (blockToAskFor, false) + } + } + + cleanUpConnections + } + } + + // Right now it picks a block uniformly that this peer does not have + // TODO: Implement more intelligent block selection policies + private def pickBlockToRequest (txHasBlocksBitVector: BitSet): Int = { + var needBlocksBitVector: BitSet = null + + // Blocks already present + hasBlocksBitVector.synchronized { + needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] + } + + // Include blocks already in transmission ONLY IF + // BitTorrentBroadcast.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks / totalBlocks) < BitTorrentBroadcast.EndGameFraction) { + blocksInRequestBitVector.synchronized { + needBlocksBitVector.or (blocksInRequestBitVector) + } + } + + // Find blocks that are neither here nor in transit + needBlocksBitVector.flip (0, needBlocksBitVector.size) + + // Blocks that should be requested + needBlocksBitVector.and (txHasBlocksBitVector) + + if (needBlocksBitVector.cardinality == 0) { + return -1 + } else { + // Pick uniformly the i'th required block + var i = BitTorrentBroadcast.ranGen.nextInt (needBlocksBitVector.cardinality) + var pickedBlockIndex = needBlocksBitVector.nextSetBit (0) + + while (i > 0) { + pickedBlockIndex = + needBlocksBitVector.nextSetBit (pickedBlockIndex + 1) + i = i - 1 + } + + return pickedBlockIndex + } + } + + private def cleanUpConnections: Unit = { + if (oisSource != null) { + oisSource.close + } + if (oosSource != null) { + oosSource.close + } + if (peerSocketToSource != null) { + peerSocketToSource.close + } + + // Delete from peersNowTalking + peersNowTalking.synchronized { + peersNowTalking = peersNowTalking - peerToTalkTo + } + } + } + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo] () + + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (0) + guidePort = serverSocket.getLocalPort + logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { + guidePortLock.notifyAll + } + + try { + // Don't stop until there is a copy in HDFS + while (!stopBroadcast || !hasCopyInHDFS) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (BitTorrentBroadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // listOfSources.size - 1, because it includes the Guide itself + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + } + } + } + if (clientSocket != null) { + logInfo ("Guide: Accepted new client connection:" + clientSocket) + try { + threadPool.execute (new GuideSingleRequest (clientSocket)) + } catch { + // In failure, close the socket here; else, thread will close it + case ioe: IOException => { + clientSocket.close + } + } + } + } + + // Shutdown the thread pool + threadPool.shutdown + + logInfo ("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + BitTorrentBroadcast.unregisterValue (uuid) + } finally { + if (serverSocket != null) { + logInfo ("GuideMultipleRequests now stopping...") + serverSocket.close + } + } + } + + private def sendStopBroadcastNotifications: Unit = { + listOfSources.synchronized { + listOfSources.foreach { sourceInfo => + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream (guideSocketToSource.getOutputStream) + gosSource.flush + gisSource = + new ObjectInputStream (guideSocketToSource.getInputStream) + + // Throw away whatever comes in + gisSource.readObject.asInstanceOf[SourceInfo] + + // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast + gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast, + SourceInfo.UnusedParam, SourceInfo.UnusedParam)) + gosSource.flush + } catch { + case e: Exception => { + logInfo ("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close + } + if (gosSource != null) { + gosSource.close + } + if (guideSocketToSource != null) { + guideSocketToSource.close + } + } + } + } + } + + class GuideSingleRequest (val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + private var sourceInfo: SourceInfo = null + private var selectedSources: ListBuffer[SourceInfo] = null + + override def run: Unit = { + try { + logInfo ("new GuideSingleRequest is running") + // Connecting worker is sending in its information + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Select a suitable source and send it back to the worker + selectedSources = selectSuitableSources (sourceInfo) + logInfo ("Sending selectedSources:" + selectedSources) + oos.writeObject (selectedSources) + oos.flush + + // Add this source to the listOfSources + addToListOfSources (sourceInfo) + } catch { + case e: Exception => { + // Assuming exception caused by receiver failure: remove + if (listOfSources != null) { + listOfSources.synchronized { + listOfSources = listOfSources - sourceInfo + } + } + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + + // Randomly select some sources to send back + private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { + var selectedSources = ListBuffer[SourceInfo] () + + // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' + // then add skipSourceInfo to setOfCompletedSources. Return blank. + if (skipSourceInfo.hasBlocks == totalBlocks) { + setOfCompletedSources += skipSourceInfo + return selectedSources + } + + listOfSources.synchronized { + if (listOfSources.size <= BitTorrentBroadcast.MaxPeersInGuideResponse) { + selectedSources = listOfSources.clone + } else { + var picksLeft = BitTorrentBroadcast.MaxPeersInGuideResponse + var alreadyPicked = new BitSet (listOfSources.size) + + while (picksLeft > 0) { + var i = -1 + + do { + i = BitTorrentBroadcast.ranGen.nextInt (listOfSources.size) + } while (alreadyPicked.get(i)) + + var peerIter = listOfSources.iterator + var curPeer = peerIter.next + + while (i > 0) { + curPeer = peerIter.next + i = i - 1 + } + + selectedSources = selectedSources + curPeer + alreadyPicked.set (i) + + picksLeft = picksLeft - 1 + } + } + } + + // Remove the receiving source (if present) + selectedSources = selectedSources - skipSourceInfo + + return selectedSources + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + // Server at most BitTorrentBroadcast.MaxRxPeers peers + var threadPool = + Broadcast.newDaemonFixedThreadPool(BitTorrentBroadcast.MaxRxPeers) + + override def run: Unit = { + var serverSocket = new ServerSocket (0) + listenPort = serverSocket.getLocalPort + + logInfo ("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (BitTorrentBroadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("ServeMultipleRequests Timeout.") + } + } + if (clientSocket != null) { + logInfo ("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute (new ServeSingleRequest (clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo ("ServeMultipleRequests now stopping...") + serverSocket.close + } + } + // Shutdown the thread pool + threadPool.shutdown + } + + class ServeSingleRequest (val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + logInfo ("new ServeSingleRequest is running") + + override def run: Unit = { + try { + // Send latest local SourceInfo to the receiver + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(getLocalSourceInfo) + oos.flush + + // Receive latest SourceInfo from the receiver + var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) + + if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + addToListOfSources (rxSourceInfo) + } + + val startTime = System.currentTimeMillis + var curTime = startTime + var keepSending = true + var numBlocksToSend = BitTorrentBroadcast.MaxChatBlocks + + while (!stopBroadcast && keepSending && numBlocksToSend > 0) { + // Receive which block to send + val blockToSend = ois.readObject.asInstanceOf[Int] + + // Send the block + sendBlock (blockToSend) + rxSourceInfo.hasBlocksBitVector.set (blockToSend) + + numBlocksToSend = numBlocksToSend - 1 + + // Receive latest SourceInfo from the receiver + rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] + // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) + addToListOfSources (rxSourceInfo) + + curTime = System.currentTimeMillis + // Revoke sending only if there is anyone waiting in the queue + if (curTime - startTime >= BitTorrentBroadcast.MaxChatTime && + threadPool.getQueue.size > 0) { + keepSending = false + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo ("ServeSingleRequest had a " + e) + } + } finally { + logInfo ("ServeSingleRequest is closing streams and sockets") + ois.close + // TODO: The following line causes a "java.net.SocketException: Socket closed" + oos.close + clientSocket.close + } + } + + private def sendBlock (blockToSend: Int): Unit = { + try { + oos.writeObject (arrayOfBlocks(blockToSend)) + oos.flush + } catch { + case e: Exception => { + logInfo ("sendBlock had a " + e) + } + } + logInfo ("Sent block: " + blockToSend + " to " + clientSocket) + } + } + } +} + +class BitTorrentBroadcastFactory +extends BroadcastFactory { + def initialize (isMaster: Boolean) = BitTorrentBroadcast.initialize (isMaster) + def newBroadcast[T] (value_ : T, isLocal: Boolean) = + new BitTorrentBroadcast[T] (value_, isLocal) +} + +private object BitTorrentBroadcast +extends Logging { + val values = Cache.newKeySpace() + + var valueToGuideMap = Map[UUID, SourceInfo] () + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress + private var MasterTrackerPort_ : Int = 11111 + private var BlockSize_ : Int = 512 * 1024 + private var MaxRetryCount_ : Int = 2 + + private var TrackerSocketTimeout_ : Int = 50000 + private var ServerSocketTimeout_ : Int = 10000 + + private var trackMV: TrackMultipleValues = null + + // A peer syncs back to Guide after waiting randomly within following limits + // Also used thoughout the code for small and large waits/timeouts + private var MinKnockInterval_ = 500 + private var MaxKnockInterval_ = 999 + + private var MaxPeersInGuideResponse_ = 4 + + // Maximum number of receiving and sending threads of a peer + private var MaxRxPeers_ = 4 + private var MaxTxPeers_ = 4 + + // Peers can char at most this milliseconds or transfer this number of blocks + private var MaxChatTime_ = 250 + private var MaxChatBlocks_ = 1024 + + // Fraction of blocks to receive before entering the end game + private var EndGameFraction_ = 1.0 + + + def initialize (isMaster__ : Boolean): Unit = { + synchronized { + if (!initialized) { + MasterTrackerPort_ = + System.getProperty ("spark.broadcast.masterTrackerPort", "11111").toInt + BlockSize_ = + System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 + MaxRetryCount_ = + System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt + + TrackerSocketTimeout_ = + System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt + ServerSocketTimeout_ = + System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt + + MinKnockInterval_ = + System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt + MaxKnockInterval_ = + System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt + + MaxPeersInGuideResponse_ = + System.getProperty ("spark.broadcast.maxPeersInGuideResponse", "4").toInt + + MaxRxPeers_ = + System.getProperty ("spark.broadcast.maxRxPeers", "4").toInt + MaxTxPeers_ = + System.getProperty ("spark.broadcast.maxTxPeers", "4").toInt + + MaxChatTime_ = + System.getProperty ("spark.broadcast.maxChatTime", "250").toInt + MaxChatBlocks_ = + System.getProperty ("spark.broadcast.maxChatBlocks", "1024").toInt + + EndGameFraction_ = + System.getProperty ("spark.broadcast.endGameFraction", "1.0").toDouble + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon (true) + trackMV.start + logInfo ("TrackMultipleValues started...") + } + + // Initialize DfsBroadcast to be used for broadcast variable persistence + DfsBroadcast.initialize + + initialized = true + } + } + } + + def MasterHostAddress = MasterHostAddress_ + def MasterTrackerPort = MasterTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def isMaster = isMaster_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ + + def MaxRxPeers = MaxRxPeers_ + def MaxTxPeers = MaxTxPeers_ + + def MaxChatTime = MaxChatTime_ + def MaxChatBlocks = MaxChatBlocks_ + + def EndGameFraction = EndGameFraction_ + + def registerValue (uuid: UUID, gInfo: SourceInfo): Unit = { + valueToGuideMap.synchronized { + valueToGuideMap += (uuid -> gInfo) + logInfo ("New value registered with the Tracker " + valueToGuideMap) + } + } + + def unregisterValue (uuid: UUID): Unit = { + valueToGuideMap.synchronized { + valueToGuideMap (uuid) = SourceInfo ("", SourceInfo.TxOverGoToHDFS, + SourceInfo.UnusedParam, SourceInfo.UnusedParam) + logInfo ("Value unregistered from the Tracker " + valueToGuideMap) + } + } + + class TrackMultipleValues + extends Thread with Logging { + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (BitTorrentBroadcast.MasterTrackerPort) + logInfo ("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (TrackerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("TrackMultipleValues Timeout. Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute (new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + val ois = new ObjectInputStream (clientSocket.getInputStream) + try { + val uuid = ois.readObject.asInstanceOf[UUID] + var gInfo = + if (valueToGuideMap.contains (uuid)) { + valueToGuideMap (uuid) + } else SourceInfo ("", SourceInfo.TxNotStartedRetry, + SourceInfo.UnusedParam, SourceInfo.UnusedParam) + logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) + oos.writeObject (gInfo) + } catch { + case e: Exception => { + logInfo ("TrackMultipleValues had a " + e) + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => { + clientSocket.close + } + } + } + } + } finally { + serverSocket.close + } + // Shutdown the thread pool + threadPool.shutdown + } + } +} diff --git a/core/src/main/scala/spark/Broadcast.scala b/core/src/main/scala/spark/Broadcast.scala index 5089dca82e..fe2ab1ebf0 100644 --- a/core/src/main/scala/spark/Broadcast.scala +++ b/core/src/main/scala/spark/Broadcast.scala @@ -1,197 +1,107 @@ package spark -import java.io._ -import java.net._ -import java.util.{UUID, PriorityQueue, Comparator} - -import java.util.concurrent.{Executors, ExecutorService} - -import scala.actors.Actor -import scala.actors.Actor._ - -import scala.collection.mutable.Map - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} - -import spark.compress.lzf.{LZFInputStream, LZFOutputStream} +import java.util.{BitSet, UUID} +import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} @serializable -trait BroadcastRecipe { +trait Broadcast[T] { val uuid = UUID.randomUUID + def value: T + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. Possibly a Scala bug! - def sendBroadcast: Unit override def toString = "spark.Broadcast(" + uuid + ")" } -// TODO: Right, now no parallelization between multiple broadcasts -@serializable -class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) -extends BroadcastRecipe with Logging { - - def value = value_ +trait BroadcastFactory { + def initialize (isMaster: Boolean): Unit + def newBroadcast[T] (value_ : T, isLocal: Boolean): Broadcast[T] +} - BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) } - - if (!local) { sendBroadcast } - - def sendBroadcast () { - // Create a variableInfo object and store it in valueInfos - var variableInfo = blockifyObject (value_, BroadcastCS.blockSize) - // TODO: Even though this part is not in use now, there is problem in the - // following statement. Shouldn't use constant port and hostAddress anymore? - // val masterSource = - // new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort, - // variableInfo.totalBlocks, variableInfo.totalBytes, 0) - // variableInfo.pqOfSources.add (masterSource) - - BroadcastCS.synchronized { - // BroadcastCS.valueInfos.put (uuid, variableInfo) +private object Broadcast +extends Logging { + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + // Called by SparkContext or Executor before using Broadcast + def initialize (isMaster: Boolean): Unit = synchronized { + if (!initialized) { + val broadcastFactoryClass = System.getProperty("spark.broadcast.factory", + "spark.DfsBroadcastFactory") + val booleanArgs = Array[AnyRef] (isMaster.asInstanceOf[AnyRef]) + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - // TODO: Not using variableInfo in current implementation. Manually - // setting all the variables inside BroadcastCS object + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isMaster) - BroadcastCS.initializeVariable (variableInfo) + initialized = true } - - // Now store a persistent copy in HDFS, just in case - val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) - out.writeObject (value_) - out.close } - // Called by Java when deserializing an object - private def readObject (in: ObjectInputStream) { - in.defaultReadObject - BroadcastCS.synchronized { - val cachedVal = BroadcastCS.values.get (uuid) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - // Only a single worker (the first one) in the same node can ever be - // here. The rest will always get the value ready. - val start = System.nanoTime - - val retByteArray = BroadcastCS.receiveBroadcast (uuid) - // If does not succeed, then get from HDFS copy - if (retByteArray != null) { - value_ = byteArrayToObject[T] (retByteArray) - BroadcastCS.values.put (uuid, value_) - // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize) - // BroadcastCS.valueInfos.put (uuid, variableInfo) - } else { - val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BroadcastCH.values.put(uuid, value_) - fileIn.close - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") - } + def getBroadcastFactory: BroadcastFactory = { + if (broadcastFactory == null) { + throw new SparkException ("Broadcast.getBroadcastFactory called before initialize") } + broadcastFactory } - private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream (baos) - oos.writeObject (obj) - oos.close - baos.close - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream (byteArray) - - var blockNum = (byteArray.length / blockSize) - if (byteArray.length % blockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock] (blockNum) - var blockID = 0 - - // TODO: What happens in byteArray.length == 0 => blockNum == 0 - for (i <- 0 until (byteArray.length, blockSize)) { - val thisBlockSize = Math.min (blockSize, byteArray.length - i) - var tempByteArray = new Array[Byte] (thisBlockSize) - val hasRead = bais.read (tempByteArray, 0, thisBlockSize) - - retVal (blockID) = new BroadcastBlock (blockID, tempByteArray) - blockID += 1 - } - bais.close - - var variableInfo = VariableInfo (retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - private def byteArrayToObject[A] (bytes: Array[Byte]): A = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) - val retVal = in.readObject.asInstanceOf[A] - in.close - return retVal + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread (r) + t.setDaemon (true) + return t + } + } } - private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = { - val bOut = new ByteArrayOutputStream - val out = new ObjectOutputStream (bOut) - out.writeObject (obj) - out.close - bOut.close - return bOut - } -} - -@serializable -class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean) -extends BroadcastRecipe with Logging { + // Wrapper over newCachedThreadPool + def newDaemonCachedThreadPool: ThreadPoolExecutor = { + var threadPool = + Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - def value = value_ - - BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) } - - if (!local) { sendBroadcast } - - def sendBroadcast () { - val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) - out.writeObject (value_) - out.close - } - - // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject - BroadcastCH.synchronized { - val cachedVal = BroadcastCH.values.get(uuid) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - val start = System.nanoTime - - val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BroadcastCH.values.put(uuid, value_) - fileIn.close - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool + } } @serializable case class SourceInfo (val hostAddress: String, val listenPort: Int, - val totalBlocks: Int, val totalBytes: Int, val replicaID: Int) -extends Comparable[SourceInfo]{ + val totalBlocks: Int, val totalBytes: Int) +extends Comparable[SourceInfo] with Logging { var currentLeechers = 0 var receptionFailed = false - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) + var hasBlocks = 0 + var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) + + // Ascending sort based on leecher count + def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)} + +object SourceInfo { + // Constants for special values of listenPort + val TxNotStartedRetry = -1 + val TxOverGoToHDFS = 0 + // Other constants + val StopBroadcast = -2 + val UnusedParam = 0 } @serializable @@ -199,601 +109,32 @@ case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } @serializable case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], - val totalBlocks: Int, val totalBytes: Int) { - + val totalBlocks: Int, val totalBytes: Int) { @transient var hasBlocks = 0 - - val listenPortLock = new AnyRef - val totalBlocksLock = new AnyRef - val hasBlocksLock = new AnyRef - - @transient var pqOfSources = new PriorityQueue[SourceInfo] } -private object Broadcast { - private var initialized = false - - // Will be called by SparkContext or Executor before using Broadcast - // Calls all other initializers here - def initialize (isMaster: Boolean) { - synchronized { - if (!initialized) { - // Initialization for CentralizedHDFSBroadcast - BroadcastCH.initialize - // Initialization for ChainedStreamingBroadcast - // BroadcastCS.initialize (isMaster) - - initialized = true - } - } - } -} - -private object BroadcastCS extends Logging { - val values = Cache.newKeySpace() - - // private var valueToPort = Map[UUID, Int] () - - private var initialized = false - private var isMaster_ = false - - private var masterHostAddress_ = "127.0.0.1" - private var masterListenPort_ : Int = 11111 - private var blockSize_ : Int = 512 * 1024 - private var maxRetryCount_ : Int = 2 - private var serverSocketTimout_ : Int = 50000 - private var dualMode_ : Boolean = false - - private val hostAddress = InetAddress.getLocalHost.getHostAddress - private var listenPort = -1 - - var arrayOfBlocks: Array[BroadcastBlock] = null - var totalBytes = -1 - var totalBlocks = -1 - var hasBlocks = 0 - - val listenPortLock = new Object - val totalBlocksLock = new Object - val hasBlocksLock = new Object - - var pqOfSources = new PriorityQueue[SourceInfo] - - private var serveMR: ServeMultipleRequests = null - private var guideMR: GuideMultipleRequests = null - - def initialize (isMaster__ : Boolean) { - synchronized { - if (!initialized) { - masterHostAddress_ = - System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1") - masterListenPort_ = - System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt - blockSize_ = - System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 - maxRetryCount_ = - System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt - serverSocketTimout_ = - System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt - dualMode_ = - System.getProperty ("spark.broadcast.dualMode", "false").toBoolean - - isMaster_ = isMaster__ - - if (isMaster) { - guideMR = new GuideMultipleRequests - guideMR.setDaemon (true) - guideMR.start - logInfo("GuideMultipleRequests started") - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon (true) - serveMR.start - logInfo("ServeMultipleRequests started") - - logInfo("BroadcastCS object has been initialized") - - initialized = true +@serializable +class SpeedTracker { + // Mapping 'source' to '(totalTime, numBlocks)' + private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] () + + def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long): Unit = { + sourceToSpeedMap.synchronized { + if (!sourceToSpeedMap.contains(srcInfo)) { + sourceToSpeedMap += (srcInfo -> (timeInMillis, 1)) + } else { + val tTnB = sourceToSpeedMap (srcInfo) + sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1)) } } } - // TODO: This should change in future implementation. - // Called from the Master constructor to setup states for this particular that - // is being broadcasted - def initializeVariable (variableInfo: VariableInfo) { - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - // listenPort should already be valid - assert (listenPort != -1) - - pqOfSources = new PriorityQueue[SourceInfo] - val masterSource_0 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) - BroadcastCS.pqOfSources.add (masterSource_0) - // Add one more time to have two replicas of any seeds in the PQ - if (BroadcastCS.dualMode) { - val masterSource_1 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) - BroadcastCS.pqOfSources.add (masterSource_1) - } - } - - def masterHostAddress = masterHostAddress_ - def masterListenPort = masterListenPort_ - def blockSize = blockSize_ - def maxRetryCount = maxRetryCount_ - def serverSocketTimout = serverSocketTimout_ - def dualMode = dualMode_ - - def isMaster = isMaster_ - - def receiveBroadcast (variableUUID: UUID): Array[Byte] = { - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - // NO need to wait; ServeMultipleRequests is created much further ahead - while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait - } - } - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = BroadcastCS.maxRetryCount - var retByteArray: Array[Byte] = null - do { - // Connect to Master and send this worker's Information - val clientSocketToMaster = - new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort) - logInfo("Connected to Master's guiding object") - // TODO: Guiding object connection is reusable - val oisMaster = - new ObjectInputStream (clientSocketToMaster.getInputStream) - val oosMaster = - new ObjectOutputStream (clientSocketToMaster.getOutputStream) - - oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0)) - oosMaster.flush - - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll - } - totalBytes = sourceInfo.totalBytes - logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) - - retByteArray = receiveSingleTransmission (sourceInfo) - - logInfo("I got this from receiveSingleTransmission: " + retByteArray) - - // TODO: Update sourceInfo to add error notifactions for Master - if (retByteArray == null) { sourceInfo.receptionFailed = true } - - // TODO: Supposed to update values here, but we don't support advanced - // statistics right now. Master can handle leecherCount by itself. - - // Send back statistics to the Master - oosMaster.writeObject (sourceInfo) - - oisMaster.close - oosMaster.close - clientSocketToMaster.close - - retriesLeft -= 1 - } while (retriesLeft > 0 && retByteArray == null) - - return retByteArray - } - - // Tries to receive broadcast from the Master and returns Boolean status. - // This might be called multiple times to retry a defined number of times. - private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = { - var clientSocketToSource: Socket = null - var oisSource: ObjectInputStream = null - var oosSource: ObjectOutputStream = null - - var retByteArray:Array[Byte] = null - - try { - // Connect to the source to get the object itself - clientSocketToSource = - new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = - new ObjectOutputStream (clientSocketToSource.getOutputStream) - oisSource = - new ObjectInputStream (clientSocketToSource.getInputStream) - - logInfo("Inside receiveSingleTransmission") - logInfo("totalBlocks: " + totalBlocks + " " + "hasBlocks: " + hasBlocks) - retByteArray = new Array[Byte] (totalBytes) - for (i <- 0 until totalBlocks) { - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - System.arraycopy (bcBlock.byteArray, 0, retByteArray, - i * BroadcastCS.blockSize, bcBlock.byteArray.length) - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - hasBlocksLock.synchronized { - hasBlocksLock.notifyAll - } - logInfo("Received block: " + i + " " + bcBlock) - } - assert (hasBlocks == totalBlocks) - logInfo("After the receive loop") - } catch { - case e: Exception => { - retByteArray = null - logInfo("receiveSingleTransmission had a " + e) - } - } finally { - if (oisSource != null) { oisSource.close } - if (oosSource != null) { oosSource.close } - if (clientSocketToSource != null) { clientSocketToSource.close } - } - - return retByteArray - } - -// class TrackMultipleValues extends Thread with Logging { -// override def run = { -// var threadPool = Executors.newCachedThreadPool -// var serverSocket: ServerSocket = null -// -// serverSocket = new ServerSocket (BroadcastCS.masterListenPort) -// logInfo("TrackMultipleVariables" + serverSocket + " " + listenPort) -// -// var keepAccepting = true -// try { -// while (keepAccepting) { -// var clientSocket: Socket = null -// try { -// serverSocket.setSoTimeout (serverSocketTimout) -// clientSocket = serverSocket.accept -// } catch { -// case e: Exception => { -// logInfo("TrackMultipleValues Timeout. Stopping listening...") -// keepAccepting = false -// } -// } -// logInfo("TrackMultipleValues:Got new request:" + clientSocket) -// if (clientSocket != null) { -// try { -// threadPool.execute (new Runnable { -// def run = { -// val oos = new ObjectOutputStream (clientSocket.getOutputStream) -// val ois = new ObjectInputStream (clientSocket.getInputStream) -// try { -// val variableUUID = ois.readObject.asInstanceOf[UUID] -// var contactPort = 0 -// // TODO: Add logic and data structures to find out UUID->port -// // mapping. 0 = missed the broadcast, read from HDFS; <0 = -// // Haven't started yet, wait & retry; >0 = Read from this port -// oos.writeObject (contactPort) -// } catch { -// case e: Exception => { } -// } finally { -// ois.close -// oos.close -// clientSocket.close -// } -// } -// }) -// } catch { -// // In failure, close the socket here; else, the thread will close it -// case ioe: IOException => clientSocket.close -// } -// } -// } -// } finally { -// serverSocket.close -// } -// } -// } -// -// class TrackSingleValue { -// -// } - -// public static ExecutorService newCachedThreadPool() { -// return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60L, TimeUnit.SECONDS, -// new SynchronousQueue<Runnable>()); -// } - - - class GuideMultipleRequests extends Thread with Logging { - override def run = { - var threadPool = Executors.newCachedThreadPool - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket (BroadcastCS.masterListenPort) - // listenPort = BroadcastCS.masterListenPort - logInfo("GuideMultipleRequests" + serverSocket + " " + listenPort) - - var keepAccepting = true - try { - while (keepAccepting) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout (serverSocketTimout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("GuideMultipleRequests Timeout. Stopping listening...") - keepAccepting = false - } - } - if (clientSocket != null) { - logInfo("Guide:Accepted new client connection:" + clientSocket) - try { - threadPool.execute (new GuideSingleRequest (clientSocket)) - } catch { - // In failure, close the socket here; else, the thread will close it - case ioe: IOException => clientSocket.close - } - } - } - } finally { - serverSocket.close - } - } - - class GuideSingleRequest (val clientSocket: Socket) - extends Runnable with Logging { - private val oos = new ObjectOutputStream (clientSocket.getOutputStream) - private val ois = new ObjectInputStream (clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - def run = { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. ReplicaID is 0 and other fields are invalid (-1) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource (sourceInfo) - logInfo("Sending selectedSourceInfo:" + selectedSourceInfo) - oos.writeObject (selectedSourceInfo) - oos.flush - - // Add this new (if it can finish) source to the PQ of sources - thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, 0) - logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo) - pqOfSources.synchronized { - pqOfSources.add (thisWorkerInfo) - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in pqOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - pqOfSources.synchronized { - // This should work since SourceInfo is a case class - assert (pqOfSources.contains (selectedSourceInfo)) - - // Remove first - pqOfSources.remove (selectedSourceInfo) - // TODO: Removing a source based on just one failure notification! - // Update leecher count and put it back in IF reception succeeded - if (!sourceInfo.receptionFailed) { - selectedSourceInfo.currentLeechers -= 1 - pqOfSources.add (selectedSourceInfo) - - // No need to find and update thisWorkerInfo, but add its replica - if (BroadcastCS.dualMode) { - pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress, - thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1)) - } - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - // Assuming that exception caused due to receiver worker failure - // Remove failed worker from pqOfSources and update leecherCount of - // corresponding source worker - pqOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - pqOfSources.remove (selectedSourceInfo) - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - pqOfSources.add (selectedSourceInfo) - } - - // Remove thisWorkerInfo - if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) } - } - } - } finally { - ois.close - oos.close - clientSocket.close - } - } - - // TODO: If a worker fails to get the broadcasted variable from a source and - // comes back to Master, this function might choose the worker itself as a - // source tp create a dependency cycle (this worker was put into pqOfSources - // as a streming source when it first arrived). The length of this cycle can - // be arbitrarily long. - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - // Select one with the lowest number of leechers - pqOfSources.synchronized { - // take is a blocking call removing the element from PQ - var selectedSource = pqOfSources.poll - assert (selectedSource != null) - // Update leecher count - selectedSource.currentLeechers += 1 - // Add it back and then return - pqOfSources.add (selectedSource) - return selectedSource - } - } - } - } - - class ServeMultipleRequests extends Thread with Logging { - override def run = { - var threadPool = Executors.newCachedThreadPool - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket (0) - listenPort = serverSocket.getLocalPort - logInfo("ServeMultipleRequests" + serverSocket + " " + listenPort) - - listenPortLock.synchronized { - listenPortLock.notifyAll - } - - var keepAccepting = true - try { - while (keepAccepting) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout (serverSocketTimout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout. Stopping listening...") - keepAccepting = false - } - } - if (clientSocket != null) { - logInfo("Serve:Accepted new client connection:" + clientSocket) - try { - threadPool.execute (new ServeSingleRequest (clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close - } - } - } - } finally { - serverSocket.close - } - } - - class ServeSingleRequest (val clientSocket: Socket) - extends Runnable with Logging { - private val oos = new ObjectOutputStream (clientSocket.getOutputStream) - private val ois = new ObjectInputStream (clientSocket.getInputStream) - - def run = { - try { - logInfo("new ServeSingleRequest is running") - sendObject - } catch { - // TODO: Need to add better exception handling here - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close - oos.close - clientSocket.close - } - } - - private def sendObject = { - // Wait till receiving the SourceInfo from Master - while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait - } - } - - for (i <- 0 until totalBlocks) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { - hasBlocksLock.wait - } - } - try { - oos.writeObject (arrayOfBlocks(i)) - oos.flush - } catch { - case e: Exception => { } - } - logInfo("Send block: " + i + " " + arrayOfBlocks(i)) - } - } + def getTimePerBlock (srcInfo: SourceInfo): Double = { + sourceToSpeedMap.synchronized { + val tTnB = sourceToSpeedMap (srcInfo) + return tTnB._1 / tTnB._2 } } -} - -private object BroadcastCH extends Logging { - val values = Cache.newKeySpace() - - private var initialized = false - - private var fileSystem: FileSystem = null - private var workDir: String = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - - def initialize () { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val dfs = System.getProperty("spark.dfs", "file:///") - if (!dfs.startsWith("file://")) { - val conf = new Configuration() - conf.setInt("io.file.buffer.size", bufferSize) - val rep = System.getProperty("spark.dfs.replication", "3").toInt - conf.setInt("dfs.replication", rep) - fileSystem = FileSystem.get(new URI(dfs), conf) - } - workDir = System.getProperty("spark.dfs.workdir", "/tmp") - compress = System.getProperty("spark.compress", "false").toBoolean - - initialized = true - } - } - } - - private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) - - def openFileForReading(uuid: UUID): InputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.open(getPath(uuid)) - } else { - // Local filesystem - new FileInputStream(getPath(uuid).toString) - } - if (compress) - new LZFInputStream(fileStream) // LZF stream does its own buffering - else if (fileSystem == null) - new BufferedInputStream(fileStream, bufferSize) - else - fileStream // Hadoop streams do their own buffering - } - - def openFileForWriting(uuid: UUID): OutputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.create(getPath(uuid)) - } else { - // Local filesystem - new FileOutputStream(getPath(uuid).toString) - } - if (compress) - new LZFOutputStream(fileStream) // LZF stream does its own buffering - else if (fileSystem == null) - new BufferedOutputStream(fileStream, bufferSize) - else - fileStream // Hadoop streams do their own buffering - } + + override def toString = sourceToSpeedMap.toString } diff --git a/core/src/main/scala/spark/ChainedBroadcast.scala b/core/src/main/scala/spark/ChainedBroadcast.scala new file mode 100644 index 0000000000..afd3c0293c --- /dev/null +++ b/core/src/main/scala/spark/ChainedBroadcast.scala @@ -0,0 +1,870 @@ +package spark + +import java.io._ +import java.net._ +import java.util.{Comparator, PriorityQueue, Random, UUID} + +import scala.collection.mutable.{Map, Set} + +@serializable +class ChainedBroadcast[T] (@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { + + def value = value_ + + ChainedBroadcast.synchronized { + ChainedBroadcast.values.put (uuid, value_) + } + + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + @transient var hasBlocksLock = new Object + + @transient var pqOfSources = new PriorityQueue[SourceInfo] + + @transient var serveMR: ServeMultipleRequests = null + @transient var guideMR: GuideMultipleRequests = null + + @transient var hostAddress = InetAddress.getLocalHost.getHostAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + + @transient var hasCopyInHDFS = false + @transient var stopBroadcast = false + + // Must call this after all the variables have been created/initialized + if (!isLocal) { + sendBroadcast + } + + def sendBroadcast (): Unit = { + logInfo ("Local host address: " + hostAddress) + + // Store a persistent copy in HDFS + // TODO: Turned OFF for now + // val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid)) + // out.writeObject (value_) + // out.close + // TODO: Fix this at some point + hasCopyInHDFS = true + + // Create a variableInfo object and store it in valueInfos + var variableInfo = blockifyObject (value_, ChainedBroadcast.BlockSize) + + guideMR = new GuideMultipleRequests + guideMR.setDaemon (true) + guideMR.start + logInfo ("GuideMultipleRequests started...") + + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + logInfo ("ServeMultipleRequests started...") + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + pqOfSources = new PriorityQueue[SourceInfo] + val masterSource_0 = + SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes) + pqOfSources.add (masterSource_0) + + // Register with the Tracker + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + ChainedBroadcast.registerValue (uuid, guidePort) + } + + private def readObject (in: ObjectInputStream): Unit = { + in.defaultReadObject + ChainedBroadcast.synchronized { + val cachedVal = ChainedBroadcast.values.get (uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // Initializing everything because Master will only send null/0 values + initializeSlaveVariables + + logInfo ("Local host address: " + hostAddress) + + serveMR = new ServeMultipleRequests + serveMR.setDaemon (true) + serveMR.start + logInfo ("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast (uuid) + // If does not succeed, then get from HDFS copy + if (receptionSucceeded) { + value_ = unBlockifyObject[T] + ChainedBroadcast.values.put (uuid, value_) + } else { + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + ChainedBroadcast.values.put(uuid, value_) + fileIn.close + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } + + private def initializeSlaveVariables: Unit = { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + + listenPortLock = new Object + totalBlocksLock = new Object + hasBlocksLock = new Object + + serveMR = null + + hostAddress = InetAddress.getLocalHost.getHostAddress + listenPort = -1 + + stopBroadcast = false + } + + private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream (baos) + oos.writeObject (obj) + oos.close + baos.close + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream (byteArray) + + var blockNum = (byteArray.length / blockSize) + if (byteArray.length % blockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock] (blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, blockSize)) { + val thisBlockSize = Math.min (blockSize, byteArray.length - i) + var tempByteArray = new Array[Byte] (thisBlockSize) + val hasRead = bais.read (tempByteArray, 0, thisBlockSize) + + retVal (blockID) = new BroadcastBlock (blockID, tempByteArray) + blockID += 1 + } + bais.close + + var variableInfo = VariableInfo (retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + private def unBlockifyObject[A]: A = { + var retByteArray = new Array[Byte] (totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy (arrayOfBlocks(i).byteArray, 0, retByteArray, + i * ChainedBroadcast.BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject (retByteArray) + } + + private def byteArrayToObject[A] (bytes: Array[Byte]): A = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) + val retVal = in.readObject.asInstanceOf[A] + in.close + return retVal + } + + def getMasterListenPort (variableUUID: UUID): Int = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var masterListenPort: Int = SourceInfo.TxOverGoToHDFS + + var retriesLeft = ChainedBroadcast.MaxRetryCount + do { + try { + // Connect to the tracker to find out the guide + val clientSocketToTracker = + new Socket(ChainedBroadcast.MasterHostAddress, ChainedBroadcast.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream (clientSocketToTracker.getOutputStream) + oosTracker.flush + val oisTracker = + new ObjectInputStream (clientSocketToTracker.getInputStream) + + // Send UUID and receive masterListenPort + oosTracker.writeObject (uuid) + oosTracker.flush + masterListenPort = oisTracker.readObject.asInstanceOf[Int] + } catch { + case e: Exception => { + logInfo ("getMasterListenPort had a " + e) + } + } finally { + if (oisTracker != null) { + oisTracker.close + } + if (oosTracker != null) { + oosTracker.close + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close + } + } + retriesLeft -= 1 + + Thread.sleep (ChainedBroadcast.ranGen.nextInt ( + ChainedBroadcast.MaxKnockInterval - ChainedBroadcast.MinKnockInterval) + + ChainedBroadcast.MinKnockInterval) + + } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) + + logInfo ("Got this guidePort from Tracker: " + masterListenPort) + return masterListenPort + } + + def receiveBroadcast (variableUUID: UUID): Boolean = { + val masterListenPort = getMasterListenPort (variableUUID) + + if (masterListenPort == SourceInfo.TxOverGoToHDFS || + masterListenPort == SourceInfo.TxNotStartedRetry) { + // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go + // to HDFS anyway when receiveBroadcast returns false + return false + } + + // Wait until hostAddress and listenPort are created by the + // ServeMultipleRequests thread + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + var clientSocketToMaster: Socket = null + var oosMaster: ObjectOutputStream = null + var oisMaster: ObjectInputStream = null + + // Connect and receive broadcast from the specified source, retrying the + // specified number of times in case of failures + var retriesLeft = ChainedBroadcast.MaxRetryCount + do { + // Connect to Master and send this worker's Information + clientSocketToMaster = + new Socket(ChainedBroadcast.MasterHostAddress, masterListenPort) + // TODO: Guiding object connection is reusable + oosMaster = + new ObjectOutputStream (clientSocketToMaster.getOutputStream) + oosMaster.flush + oisMaster = + new ObjectInputStream (clientSocketToMaster.getInputStream) + + logInfo ("Connected to Master's guiding object") + + // Send local source information + oosMaster.writeObject(SourceInfo (hostAddress, listenPort, + SourceInfo.UnusedParam, SourceInfo.UnusedParam)) + oosMaster.flush + + // Receive source information from Master + var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + totalBlocks = sourceInfo.totalBlocks + arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks) + totalBlocksLock.synchronized { + totalBlocksLock.notifyAll + } + totalBytes = sourceInfo.totalBytes + + logInfo ("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + + val start = System.nanoTime + val receptionSucceeded = receiveSingleTransmission (sourceInfo) + val time = (System.nanoTime - start) / 1e9 + + // Updating some statistics in sourceInfo. Master will be using them later + if (!receptionSucceeded) { + sourceInfo.receptionFailed = true + } + + // Send back statistics to the Master + oosMaster.writeObject (sourceInfo) + + if (oisMaster != null) { + oisMaster.close + } + if (oosMaster != null) { + oosMaster.close + } + if (clientSocketToMaster != null) { + clientSocketToMaster.close + } + + retriesLeft -= 1 + } while (retriesLeft > 0 && hasBlocks < totalBlocks) + + return (hasBlocks == totalBlocks) + } + + // Tries to receive broadcast from the source and returns Boolean status. + // This might be called multiple times to retry a defined number of times. + private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { + var clientSocketToSource: Socket = null + var oosSource: ObjectOutputStream = null + var oisSource: ObjectInputStream = null + + var receptionSucceeded = false + try { + // Connect to the source to get the object itself + clientSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = + new ObjectOutputStream (clientSocketToSource.getOutputStream) + oosSource.flush + oisSource = + new ObjectInputStream (clientSocketToSource.getInputStream) + + logInfo ("Inside receiveSingleTransmission") + logInfo ("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + + // Send the range + oosSource.writeObject((hasBlocks, totalBlocks)) + oosSource.flush + + for (i <- hasBlocks until totalBlocks) { + val recvStartTime = System.currentTimeMillis + val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] + val receptionTime = (System.currentTimeMillis - recvStartTime) + + logInfo ("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") + + arrayOfBlocks(hasBlocks) = bcBlock + hasBlocks += 1 + // Set to true if at least one block is received + receptionSucceeded = true + hasBlocksLock.synchronized { + hasBlocksLock.notifyAll + } + } + } catch { + case e: Exception => { + logInfo ("receiveSingleTransmission had a " + e) + } + } finally { + if (oisSource != null) { + oisSource.close + } + if (oosSource != null) { + oosSource.close + } + if (clientSocketToSource != null) { + clientSocketToSource.close + } + } + + return receptionSucceeded + } + + class GuideMultipleRequests + extends Thread with Logging { + // Keep track of sources that have completed reception + private var setOfCompletedSources = Set[SourceInfo] () + + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (0) + guidePort = serverSocket.getLocalPort + logInfo ("GuideMultipleRequests => " + serverSocket + " " + guidePort) + + guidePortLock.synchronized { + guidePortLock.notifyAll + } + + try { + // Don't stop until there is a copy in HDFS + while (!stopBroadcast || !hasCopyInHDFS) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("GuideMultipleRequests Timeout.") + + // Stop broadcast if at least one worker has connected and + // everyone connected so far are done. Comparing with + // pqOfSources.size - 1, because it includes the Guide itself + if (pqOfSources.size > 1 && + setOfCompletedSources.size == pqOfSources.size - 1) { + stopBroadcast = true + } + } + } + if (clientSocket != null) { + logInfo ("Guide: Accepted new client connection: " + clientSocket) + try { + threadPool.execute (new GuideSingleRequest (clientSocket)) + } catch { + // In failure, close the socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + + logInfo ("Sending stopBroadcast notifications...") + sendStopBroadcastNotifications + + ChainedBroadcast.unregisterValue (uuid) + } finally { + if (serverSocket != null) { + logInfo ("GuideMultipleRequests now stopping...") + serverSocket.close + } + } + + // Shutdown the thread pool + threadPool.shutdown + } + + private def sendStopBroadcastNotifications: Unit = { + pqOfSources.synchronized { + var pqIter = pqOfSources.iterator + while (pqIter.hasNext) { + var sourceInfo = pqIter.next + + var guideSocketToSource: Socket = null + var gosSource: ObjectOutputStream = null + var gisSource: ObjectInputStream = null + + try { + // Connect to the source + guideSocketToSource = + new Socket (sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = + new ObjectOutputStream (guideSocketToSource.getOutputStream) + gosSource.flush + gisSource = + new ObjectInputStream (guideSocketToSource.getInputStream) + + // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 + gosSource.writeObject ((SourceInfo.StopBroadcast, + SourceInfo.StopBroadcast)) + gosSource.flush + } catch { + case e: Exception => { + logInfo ("sendStopBroadcastNotifications had a " + e) + } + } finally { + if (gisSource != null) { + gisSource.close + } + if (gosSource != null) { + gosSource.close + } + if (guideSocketToSource != null) { + guideSocketToSource.close + } + } + } + } + } + + class GuideSingleRequest (val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + private var selectedSourceInfo: SourceInfo = null + private var thisWorkerInfo:SourceInfo = null + + override def run: Unit = { + try { + logInfo ("new GuideSingleRequest is running") + // Connecting worker is sending in its hostAddress and listenPort it will + // be listening to. Other fields are invalid (SourceInfo.UnusedParam) + var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // Select a suitable source and send it back to the worker + selectedSourceInfo = selectSuitableSource (sourceInfo) + logInfo ("Sending selectedSourceInfo: " + selectedSourceInfo) + oos.writeObject (selectedSourceInfo) + oos.flush + + // Add this new (if it can finish) source to the PQ of sources + thisWorkerInfo = SourceInfo (sourceInfo.hostAddress, + sourceInfo.listenPort, totalBlocks, totalBytes) + logInfo ("Adding possible new source to pqOfSources: " + thisWorkerInfo) + pqOfSources.add (thisWorkerInfo) + } + + // Wait till the whole transfer is done. Then receive and update source + // statistics in pqOfSources + sourceInfo = ois.readObject.asInstanceOf[SourceInfo] + + pqOfSources.synchronized { + // This should work since SourceInfo is a case class + assert (pqOfSources.contains (selectedSourceInfo)) + + // Remove first + pqOfSources.remove (selectedSourceInfo) + // TODO: Removing a source based on just one failure notification! + + // Update sourceInfo and put it back in, IF reception succeeded + if (!sourceInfo.receptionFailed) { + // Add thisWorkerInfo to sources that have completed reception + setOfCompletedSources += thisWorkerInfo + + selectedSourceInfo.currentLeechers -= 1 + + // Put it back + pqOfSources.add (selectedSourceInfo) + } + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + // Assuming that exception caused due to receiver worker failure. + // Remove failed worker from pqOfSources and update leecherCount of + // corresponding source worker + pqOfSources.synchronized { + if (selectedSourceInfo != null) { + // Remove first + pqOfSources.remove (selectedSourceInfo) + // Update leecher count and put it back in + selectedSourceInfo.currentLeechers -= 1 + pqOfSources.add (selectedSourceInfo) + } + + // Remove thisWorkerInfo + if (pqOfSources != null) { + pqOfSources.remove (thisWorkerInfo) + } + } + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + + // TODO: Caller must have a synchronized block on pqOfSources + // TODO: If a worker fails to get the broadcasted variable from a source and + // comes back to Master, this function might choose the worker itself as a + // source tp create a dependency cycle (this worker was put into pqOfSources + // as a streming source when it first arrived). The length of this cycle can + // be arbitrarily long. + private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { + // Select one based on the ordering strategy (e.g., least leechers etc.) + // take is a blocking call removing the element from PQ + var selectedSource = pqOfSources.poll + assert (selectedSource != null) + // Update leecher count + selectedSource.currentLeechers += 1 + // Add it back and then return + pqOfSources.add (selectedSource) + return selectedSource + } + } + } + + class ServeMultipleRequests + extends Thread with Logging { + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (0) + listenPort = serverSocket.getLocalPort + logInfo ("ServeMultipleRequests started with " + serverSocket) + + listenPortLock.synchronized { + listenPortLock.notifyAll + } + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (ChainedBroadcast.ServerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("ServeMultipleRequests Timeout.") + } + } + if (clientSocket != null) { + logInfo ("Serve: Accepted new client connection: " + clientSocket) + try { + threadPool.execute (new ServeSingleRequest (clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + } finally { + if (serverSocket != null) { + logInfo ("ServeMultipleRequests now stopping...") + serverSocket.close + } + } + + // Shutdown the thread pool + threadPool.shutdown + } + + class ServeSingleRequest (val clientSocket: Socket) + extends Thread with Logging { + private val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + private var sendFrom = 0 + private var sendUntil = totalBlocks + + override def run: Unit = { + try { + logInfo ("new ServeSingleRequest is running") + + // Receive range to send + var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] + sendFrom = rangeToSend._1 + sendUntil = rangeToSend._2 + + if (sendFrom == SourceInfo.StopBroadcast && + sendUntil == SourceInfo.StopBroadcast) { + stopBroadcast = true + } else { + // Carry on + sendObject + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + case e: Exception => { + logInfo ("ServeSingleRequest had a " + e) + } + } finally { + logInfo ("ServeSingleRequest is closing streams and sockets") + ois.close + oos.close + clientSocket.close + } + } + + private def sendObject: Unit = { + // Wait till receiving the SourceInfo from Master + while (totalBlocks == -1) { + totalBlocksLock.synchronized { + totalBlocksLock.wait + } + } + + for (i <- sendFrom until sendUntil) { + while (i == hasBlocks) { + hasBlocksLock.synchronized { + hasBlocksLock.wait + } + } + try { + oos.writeObject (arrayOfBlocks(i)) + oos.flush + } catch { + case e: Exception => { + logInfo ("sendObject had a " + e) + } + } + logInfo ("Sent block: " + i + " to " + clientSocket) + } + } + } + } +} + +class ChainedBroadcastFactory +extends BroadcastFactory { + def initialize (isMaster: Boolean) = ChainedBroadcast.initialize (isMaster) + def newBroadcast[T] (value_ : T, isLocal: Boolean) = + new ChainedBroadcast[T] (value_, isLocal) +} + +private object ChainedBroadcast +extends Logging { + val values = Cache.newKeySpace() + + var valueToGuidePortMap = Map[UUID, Int] () + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var MasterHostAddress_ = InetAddress.getLocalHost.getHostAddress + private var MasterTrackerPort_ : Int = 22222 + private var BlockSize_ : Int = 512 * 1024 + private var MaxRetryCount_ : Int = 2 + + private var TrackerSocketTimeout_ : Int = 50000 + private var ServerSocketTimeout_ : Int = 10000 + + private var trackMV: TrackMultipleValues = null + + private var MinKnockInterval_ = 500 + private var MaxKnockInterval_ = 999 + + def initialize (isMaster__ : Boolean): Unit = { + synchronized { + if (!initialized) { + MasterTrackerPort_ = + System.getProperty ("spark.broadcast.masterTrackerPort", "22222").toInt + BlockSize_ = + System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 + MaxRetryCount_ = + System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt + + TrackerSocketTimeout_ = + System.getProperty ("spark.broadcast.trackerSocketTimeout", "50000").toInt + ServerSocketTimeout_ = + System.getProperty ("spark.broadcast.serverSocketTimeout", "10000").toInt + + MinKnockInterval_ = + System.getProperty ("spark.broadcast.minKnockInterval", "500").toInt + MaxKnockInterval_ = + System.getProperty ("spark.broadcast.maxKnockInterval", "999").toInt + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon (true) + trackMV.start + logInfo ("TrackMultipleValues started...") + } + + // Initialize DfsBroadcast to be used for broadcast variable persistence + DfsBroadcast.initialize + + initialized = true + } + } + } + + def MasterHostAddress = MasterHostAddress_ + def MasterTrackerPort = MasterTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def isMaster = isMaster_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + def registerValue (uuid: UUID, guidePort: Int): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap += (uuid -> guidePort) + logInfo ("New value registered with the Tracker " + valueToGuidePortMap) + } + } + + def unregisterValue (uuid: UUID): Unit = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap (uuid) = SourceInfo.TxOverGoToHDFS + logInfo ("Value unregistered from the Tracker " + valueToGuidePortMap) + } + } + + class TrackMultipleValues + extends Thread with Logging { + override def run: Unit = { + var threadPool = Broadcast.newDaemonCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (ChainedBroadcast.MasterTrackerPort) + logInfo ("TrackMultipleValues" + serverSocket) + + try { + while (true) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout (TrackerSocketTimeout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + logInfo ("TrackMultipleValues Timeout. Stopping listening...") + } + } + + if (clientSocket != null) { + try { + threadPool.execute (new Thread { + override def run: Unit = { + val oos = new ObjectOutputStream (clientSocket.getOutputStream) + oos.flush + val ois = new ObjectInputStream (clientSocket.getInputStream) + try { + val uuid = ois.readObject.asInstanceOf[UUID] + var guidePort = + if (valueToGuidePortMap.contains (uuid)) { + valueToGuidePortMap (uuid) + } else SourceInfo.TxNotStartedRetry + logInfo ("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) + oos.writeObject (guidePort) + } catch { + case e: Exception => { + logInfo ("TrackMultipleValues had a " + e) + } + } finally { + ois.close + oos.close + clientSocket.close + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => clientSocket.close + } + } + } + } finally { + serverSocket.close + } + + // Shutdown the thread pool + threadPool.shutdown + } + } +} diff --git a/core/src/main/scala/spark/DfsBroadcast.scala b/core/src/main/scala/spark/DfsBroadcast.scala new file mode 100644 index 0000000000..480d6dd9b1 --- /dev/null +++ b/core/src/main/scala/spark/DfsBroadcast.scala @@ -0,0 +1,132 @@ +package spark + +import java.io._ +import java.net._ +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} + +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} + +@serializable +class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean) +extends Broadcast[T] with Logging { + + def value = value_ + + DfsBroadcast.synchronized { + DfsBroadcast.values.put(uuid, value_) + } + + if (!isLocal) { + sendBroadcast + } + + def sendBroadcast (): Unit = { + val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid)) + out.writeObject (value_) + out.close + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject + DfsBroadcast.synchronized { + val cachedVal = DfsBroadcast.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + logInfo( "Started reading Broadcasted variable " + uuid) + val start = System.nanoTime + + val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + DfsBroadcast.values.put(uuid, value_) + fileIn.close + + val time = (System.nanoTime - start) / 1e9 + logInfo( "Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } +} + +class DfsBroadcastFactory +extends BroadcastFactory { + def initialize (isMaster: Boolean) = DfsBroadcast.initialize + def newBroadcast[T] (value_ : T, isLocal: Boolean) = + new DfsBroadcast[T] (value_, isLocal) +} + +private object DfsBroadcast +extends Logging { + val values = Cache.newKeySpace() + + private var initialized = false + + private var fileSystem: FileSystem = null + private var workDir: String = null + private var compress: Boolean = false + private var bufferSize: Int = 65536 + + def initialize (): Unit = { + synchronized { + if (!initialized) { + bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val dfs = System.getProperty("spark.dfs", "file:///") + if (!dfs.startsWith("file://")) { + val conf = new Configuration() + conf.setInt("io.file.buffer.size", bufferSize) + val rep = System.getProperty("spark.dfs.replication", "3").toInt + conf.setInt("dfs.replication", rep) + fileSystem = FileSystem.get(new URI(dfs), conf) + } + workDir = System.getProperty("spark.dfs.workDir", "/tmp") + compress = System.getProperty("spark.compress", "false").toBoolean + + initialized = true + } + } + } + + private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) + + def openFileForReading(uuid: UUID): InputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.open(getPath(uuid)) + } else { + // Local filesystem + new FileInputStream(getPath(uuid).toString) + } + + if (compress) { + // LZF stream does its own buffering + new LZFInputStream(fileStream) + } else if (fileSystem == null) { + new BufferedInputStream(fileStream, bufferSize) + } else { + // Hadoop streams do their own buffering + fileStream + } + } + + def openFileForWriting(uuid: UUID): OutputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.create(getPath(uuid)) + } else { + // Local filesystem + new FileOutputStream(getPath(uuid).toString) + } + + if (compress) { + // LZF stream does its own buffering + new LZFOutputStream(fileStream) + } else if (fileSystem == null) { + new BufferedOutputStream(fileStream, bufferSize) + } else { + // Hadoop streams do their own buffering + fileStream + } + } +} diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala index c45eff64d4..6a592d13c3 100644 --- a/core/src/main/scala/spark/MesosScheduler.scala +++ b/core/src/main/scala/spark/MesosScheduler.scala @@ -98,7 +98,7 @@ extends MScheduler with spark.Scheduler with Logging params("env." + key) = System.getenv(key) } } - new ExecutorInfo(execScript, createExecArg()) + new ExecutorInfo(execScript, createExecArg(), params) } /** diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 02e80c7756..bf70b5fcb1 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -97,9 +97,9 @@ extends Logging { def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) - // TODO: Keep around a weak hash map of values to Cached versions? - def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, isLocal) - //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, isLocal) + // Keep around a weak hash map of values to Cached versions? + def broadcast[T](value: T) = + Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal) // Stop the SparkContext def stop() { diff --git a/project/build/SparkProject.scala b/project/build/SparkProject.scala index bf4294cb3e..be4f263891 100644 --- a/project/build/SparkProject.scala +++ b/project/build/SparkProject.scala @@ -18,23 +18,9 @@ extends ParentProject(info) with IdeaProject val TEST_REPORT_DIR = TARGET / "test-report" - val NATIVE_DIR = path("src") / "main" / "native" - - val NATIVE_SOURCES = NATIVE_DIR * "*.c" - - val NATIVE_LIB = { - if (System.getProperty("os.name") == "Mac OS X") - "libspark_native.dylib" - else - "libspark_native.so" - } - - lazy val native = fileTask(TARGET / NATIVE_LIB from NATIVE_SOURCES) { - val makeTarget = " ../../../target/scala_2.8.1/native/" + NATIVE_LIB - (("make -C " + NATIVE_DIR + " " + makeTarget) ! log) - None - }.dependsOn(compile).describedAs("Compiles native library.") - + // Create an XML test report using ScalaTest's -u option. Unfortunately + // there is currently no way to call this directly from SBT without + // executing a subprocess. lazy val testReport = task { log.info("Creating " + TEST_REPORT_DIR + "...") if (!TEST_REPORT_DIR.exists) { @@ -51,6 +51,7 @@ CLASSPATH+=:$CORE_DIR/lib/jetty-7.1.6.v20100715/servlet-api-2.5.jar CLASSPATH+=:$CORE_DIR/lib/apache-log4j-1.2.16/log4j-1.2.16.jar CLASSPATH+=:$CORE_DIR/lib/slf4j-1.6.1/slf4j-api-1.6.1.jar CLASSPATH+=:$CORE_DIR/lib/slf4j-1.6.1/slf4j-log4j12-1.6.1.jar +CLASSPATH+=:$CORE_DIR/lib/compress-lzf-0.6.0/compress-lzf-0.6.0.jar CLASSPATH+=:$EXAMPLES_DIR/target/scala_2.8.1/classes for jar in $CORE_DIR/lib/hadoop-0.20.0/lib/*.jar; do CLASSPATH+=:$jar |