author | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-03-29 16:17:55 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-03-29 16:17:55 -0700 |
commit | df29d0ea4c8b7137fdd1844219c7d489e3b0d9c9 (patch) | |
tree | 3f925c0d109b789ce845762a9e09d24329749eb8 /src | |
Initial commit
Diffstat (limited to 'src')
44 files changed, 4342 insertions, 0 deletions
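The commit bundles the runtime pieces (SparkContext, the local and Nexus schedulers, distributed files, accumulators, cached/broadcast variables, LZF compression) together with example jobs. For orientation, a driver program against this API looks roughly like the bundled HdfsTest and SparkPi examples; the sketch below is not part of the commit, assumes the spark package built from this tree is on the classpath, and uses a placeholder object name and argument layout.

```scala
// Minimal driver sketch (not part of this commit), mirroring HdfsTest and
// SparkPi. "CharCount" and the argument layout are placeholders.
import spark._
import SparkContext._

object CharCount {
  def main(args: Array[String]) {
    val sc = new SparkContext(args(0), "CharCount")  // master, job name
    val lines = sc.textFile(args(1))                 // distributed text file
    val lengths = lines.map(s => s.length).cache()   // cache line lengths on the workers
    val total = sc.accumulator(0)                    // accumulator gathers per-task updates
    for (n <- lengths) { total += n }                // runs as foreach tasks on the cluster
    println("Total characters: " + total.value)
  }
}
```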
diff --git a/src/examples/CpuHog.scala b/src/examples/CpuHog.scala new file mode 100644 index 0000000000..f37c6f7824 --- /dev/null +++ b/src/examples/CpuHog.scala @@ -0,0 +1,24 @@ +import spark._ + +object CpuHog { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: CpuHog <master> <tasks> <threads_per_task>"); + System.exit(1) + } + val sc = new SparkContext(args(0), "CPU hog") + val tasks = args(1).toInt + val threads = args(2).toInt + def task { + for (i <- 0 until threads-1) { + new Thread() { + override def run { + while(true) {} + } + }.start() + } + while(true) {} + } + sc.runTasks(Array.make(tasks, () => task)) + } +} diff --git a/src/examples/HdfsTest.scala b/src/examples/HdfsTest.scala new file mode 100644 index 0000000000..e678154aab --- /dev/null +++ b/src/examples/HdfsTest.scala @@ -0,0 +1,16 @@ +import spark._ + +object HdfsTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HdfsTest") + val file = sc.textFile(args(1)) + val mapped = file.map(s => s.length).cache() + for (iter <- 1 to 10) { + val start = System.currentTimeMillis() + for (x <- mapped) { x + 2 } + // println("Processing: " + x) + val end = System.currentTimeMillis() + println("Iteration " + iter + " took " + (end-start) + " ms") + } + } +} diff --git a/src/examples/LocalALS.scala b/src/examples/LocalALS.scala new file mode 100644 index 0000000000..17d67b522b --- /dev/null +++ b/src/examples/LocalALS.scala @@ -0,0 +1,118 @@ +import java.util.Random +import cern.jet.math._ +import cern.colt.matrix._ +import cern.colt.matrix.linalg._ + +object LocalALS { + // Parameters set through command line arguments + var M = 0 // Number of movies + var U = 0 // Number of users + var F = 0 // Number of features + var ITERATIONS = 0 + + val LAMBDA = 0.01 // Regularization coefficient + + // Some COLT objects + val factory2D = DoubleFactory2D.dense + val factory1D = DoubleFactory1D.dense + val algebra = Algebra.DEFAULT + val blas = SeqBlas.seqBlas + + def generateR(): DoubleMatrix2D = { + val mh = factory2D.random(M, F) + val uh = factory2D.random(U, F) + return algebra.mult(mh, algebra.transpose(uh)) + } + + def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], + us: Array[DoubleMatrix1D]): Double = + { + val r = factory2D.make(M, U) + for (i <- 0 until M; j <- 0 until U) { + r.set(i, j, blas.ddot(ms(i), us(j))) + } + //println("R: " + r) + blas.daxpy(-1, targetR, r) + val sumSqs = r.aggregate(Functions.plus, Functions.square) + return Math.sqrt(sumSqs / (M * U)) + } + + def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each user that rated the movie + for (j <- 0 until U) { + val u = us(j) + // Add u * u^t to XtX + blas.dger(1, u, u, XtX) + // Add u * rating to Xty + blas.daxpy(R.get(i, j), u, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each movie that the user rated + for (i <- 0 until M) { + val m = ms(i) + // Add m * m^t to XtX + blas.dger(1, m, m, XtX) + // Add m * 
rating to Xty + blas.daxpy(R.get(i, j), m, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def main(args: Array[String]) { + args match { + case Array(m, u, f, iters) => { + M = m.toInt + U = u.toInt + F = f.toInt + ITERATIONS = iters.toInt + } + case _ => { + System.err.println("Usage: LocalALS <M> <U> <F> <iters>") + System.exit(1) + } + } + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + + val R = generateR() + + // Initialize m and u randomly + var ms = Array.fromFunction(_ => factory1D.random(F))(M) + var us = Array.fromFunction(_ => factory1D.random(F))(U) + + // Iteratively update movies then users + for (iter <- 1 to ITERATIONS) { + println("Iteration " + iter + ":") + ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray + us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray + println("RMSE = " + rmse(R, ms, us)) + println() + } + } +} diff --git a/src/examples/LocalFileLR.scala b/src/examples/LocalFileLR.scala new file mode 100644 index 0000000000..60b4aa8fc4 --- /dev/null +++ b/src/examples/LocalFileLR.scala @@ -0,0 +1,36 @@ +import java.util.Random +import Vector._ + +object LocalFileLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + val nums = line.split(' ').map(_.toDouble) + return DataPoint(new Vector(nums.subArray(1, D+1)), nums(0)) + } + + def main(args: Array[String]) { + val lines = scala.io.Source.fromFile(args(0)).getLines + val points = lines.map(parsePoint _) + val ITERATIONS = args(1).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + var gradient = Vector.zeros(D) + for (p <- points) { + val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient + } + + println("Final w: " + w) + } +} diff --git a/src/examples/LocalLR.scala b/src/examples/LocalLR.scala new file mode 100644 index 0000000000..175907e551 --- /dev/null +++ b/src/examples/LocalLR.scala @@ -0,0 +1,41 @@ +import java.util.Random +import Vector._ + +object LocalLR { + val N = 10000 // Number of data points + val D = 10 // Numer of dimensions + val R = 0.7 // Scaling factor + val ITERATIONS = 5 + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def generateData = { + def generatePoint(i: Int) = { + val y = if(i % 2 == 0) -1 else 1 + val x = Vector(D, _ => rand.nextGaussian + y * R) + DataPoint(x, y) + } + Array.fromFunction(generatePoint _)(N) + } + + def main(args: Array[String]) { + val data = generateData + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + var gradient = Vector.zeros(D) + for (p <- data) { + val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient + } + + println("Final w: " + w) + } +} diff --git a/src/examples/LocalPi.scala b/src/examples/LocalPi.scala new file mode 100644 index 0000000000..c83aeed40b --- /dev/null +++ b/src/examples/LocalPi.scala @@ -0,0 +1,14 @@ 
+import spark._ +import SparkContext._ + +object LocalPi { + def main(args: Array[String]) { + var count = 0 + for (i <- 1 to 100000) { + val x = Math.random * 2 - 1 + val y = Math.random * 2 - 1 + if (x*x + y*y < 1) count += 1 + } + println("Pi is roughly " + 4 * count / 100000.0) + } +}
\ No newline at end of file diff --git a/src/examples/SleepJob.scala b/src/examples/SleepJob.scala new file mode 100644 index 0000000000..a5e0ea0dc2 --- /dev/null +++ b/src/examples/SleepJob.scala @@ -0,0 +1,19 @@ +import spark._ + +object SleepJob { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: SleepJob <master> <tasks> <task_duration>"); + System.exit(1) + } + val sc = new SparkContext(args(0), "Sleep job") + val tasks = args(1).toInt + val duration = args(2).toInt + def task { + val start = System.currentTimeMillis + while (System.currentTimeMillis - start < duration * 1000L) + Thread.sleep(200) + } + sc.runTasks(Array.make(tasks, () => task)) + } +} diff --git a/src/examples/SparkALS.scala b/src/examples/SparkALS.scala new file mode 100644 index 0000000000..2fd58ed3a5 --- /dev/null +++ b/src/examples/SparkALS.scala @@ -0,0 +1,138 @@ +import java.io.Serializable +import java.util.Random +import cern.jet.math._ +import cern.colt.matrix._ +import cern.colt.matrix.linalg._ +import spark._ + +object SparkALS { + // Parameters set through command line arguments + var M = 0 // Number of movies + var U = 0 // Number of users + var F = 0 // Number of features + var ITERATIONS = 0 + + val LAMBDA = 0.01 // Regularization coefficient + + // Some COLT objects + val factory2D = DoubleFactory2D.dense + val factory1D = DoubleFactory1D.dense + val algebra = Algebra.DEFAULT + val blas = SeqBlas.seqBlas + + def generateR(): DoubleMatrix2D = { + val mh = factory2D.random(M, F) + val uh = factory2D.random(U, F) + return algebra.mult(mh, algebra.transpose(uh)) + } + + def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], + us: Array[DoubleMatrix1D]): Double = + { + val r = factory2D.make(M, U) + for (i <- 0 until M; j <- 0 until U) { + r.set(i, j, blas.ddot(ms(i), us(j))) + } + //println("R: " + r) + blas.daxpy(-1, targetR, r) + val sumSqs = r.aggregate(Functions.plus, Functions.square) + return Math.sqrt(sumSqs / (M * U)) + } + + def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val U = us.size + val F = us(0).size + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each user that rated the movie + for (j <- 0 until U) { + val u = us(j) + // Add u * u^t to XtX + blas.dger(1, u, u, XtX) + // Add u * rating to Xty + blas.daxpy(R.get(i, j), u, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val M = ms.size + val F = ms(0).size + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each movie that the user rated + for (i <- 0 until M) { + val m = ms(i) + // Add m * m^t to XtX + blas.dger(1, m, m, XtX) + // Add m * rating to Xty + blas.daxpy(R.get(i, j), m, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def main(args: Array[String]) { + var host = "" + var slices = 0 + args match { + case Array(m, u, f, iters, 
slices_, host_) => { + M = m.toInt + U = u.toInt + F = f.toInt + ITERATIONS = iters.toInt + slices = slices_.toInt + host = host_ + } + case _ => { + System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <host>") + System.exit(1) + } + } + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + val spark = new SparkContext(host, "SparkALS") + + val R = generateR() + + // Initialize m and u randomly + var ms = Array.fromFunction(_ => factory1D.random(F))(M) + var us = Array.fromFunction(_ => factory1D.random(F))(U) + + // Iteratively update movies then users + val Rc = spark.broadcast(R) + var msb = spark.broadcast(ms) + var usb = spark.broadcast(us) + for (iter <- 1 to ITERATIONS) { + println("Iteration " + iter + ":") + ms = spark.parallelize(0 until M, slices) + .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value)) + .toArray + msb = spark.broadcast(ms) // Re-broadcast ms because it was updated + us = spark.parallelize(0 until U, slices) + .map(i => updateUser(i, usb.value(i), msb.value, Rc.value)) + .toArray + usb = spark.broadcast(us) // Re-broadcast us because it was updated + println("RMSE = " + rmse(R, ms, us)) + println() + } + } +} diff --git a/src/examples/SparkHdfsLR.scala b/src/examples/SparkHdfsLR.scala new file mode 100644 index 0000000000..d0400380bd --- /dev/null +++ b/src/examples/SparkHdfsLR.scala @@ -0,0 +1,50 @@ +import java.util.Random +import Vector._ +import spark._ + +object SparkHdfsLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + //val nums = line.split(' ').map(_.toDouble) + //return DataPoint(new Vector(nums.subArray(1, D+1)), nums(0)) + val tok = new java.util.StringTokenizer(line, " ") + var y = tok.nextToken.toDouble + var x = new Array[Double](D) + var i = 0 + while (i < D) { + x(i) = tok.nextToken.toDouble; i += 1 + } + return DataPoint(new Vector(x), y) + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: SparkHdfsLR <master> <file> <iters>") + System.exit(1) + } + val sc = new SparkContext(args(0), "SparkHdfsLR") + val lines = sc.textFile(args(1)) + val points = lines.map(parsePoint _).cache() + val ITERATIONS = args(2).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = sc.accumulator(Vector.zeros(D)) + for (p <- points) { + val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient.value + } + + println("Final w: " + w) + } +} diff --git a/src/examples/SparkLR.scala b/src/examples/SparkLR.scala new file mode 100644 index 0000000000..34574f5640 --- /dev/null +++ b/src/examples/SparkLR.scala @@ -0,0 +1,48 @@ +import java.util.Random +import Vector._ +import spark._ + +object SparkLR { + val N = 10000 // Number of data points + val D = 10 // Numer of dimensions + val R = 0.7 // Scaling factor + val ITERATIONS = 5 + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def generateData = { + def generatePoint(i: Int) = { + val y = if(i % 2 == 0) -1 else 1 + val x = Vector(D, _ => rand.nextGaussian + y * R) + DataPoint(x, y) + } + Array.fromFunction(generatePoint _)(N) + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkLR <host> [<slices>]") + System.exit(1) + } + val sc = new 
SparkContext(args(0), "SparkLR") + val numSlices = if (args.length > 1) args(1).toInt else 2 + val data = generateData + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = sc.accumulator(Vector.zeros(D)) + for (p <- sc.parallelize(data, numSlices)) { + val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient.value + } + + println("Final w: " + w) + } +} diff --git a/src/examples/SparkPi.scala b/src/examples/SparkPi.scala new file mode 100644 index 0000000000..7dbadd1088 --- /dev/null +++ b/src/examples/SparkPi.scala @@ -0,0 +1,20 @@ +import spark._ +import SparkContext._ + +object SparkPi { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkLR <host> [<slices>]") + System.exit(1) + } + val spark = new SparkContext(args(0), "SparkPi") + val slices = if (args.length > 1) args(1).toInt else 2 + var count = spark.accumulator(0) + for (i <- spark.parallelize(1 to 100000, slices)) { + val x = Math.random * 2 - 1 + val y = Math.random * 2 - 1 + if (x*x + y*y < 1) count += 1 + } + println("Pi is roughly " + 4 * count.value / 100000.0) + } +}
\ No newline at end of file diff --git a/src/examples/Vector.scala b/src/examples/Vector.scala new file mode 100644 index 0000000000..0ae2cbc6e8 --- /dev/null +++ b/src/examples/Vector.scala @@ -0,0 +1,63 @@ +@serializable class Vector(val elements: Array[Double]) { + def length = elements.length + + def apply(index: Int) = elements(index) + + def + (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) + other(i)) + } + + def - (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) - other(i)) + } + + def dot(other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + for (i <- 0 until length) + ans += this(i) * other(i) + return ans + } + + def * ( scale: Double): Vector = Vector(length, i => this(i) * scale) + + def unary_- = this * -1 + + def sum = elements.reduceLeft(_ + _) + + override def toString = elements.mkString("(", ", ", ")") + +} + +object Vector { + def apply(elements: Array[Double]) = new Vector(elements) + + def apply(elements: Double*) = new Vector(elements.toArray) + + def apply(length: Int, initializer: Int => Double): Vector = { + val elements = new Array[Double](length) + for (i <- 0 until length) + elements(i) = initializer(i) + return new Vector(elements) + } + + def zeros(length: Int) = new Vector(new Array[Double](length)) + + def ones(length: Int) = Vector(length, _ => 1) + + class Multiplier(num: Double) { + def * (vec: Vector) = vec * num + } + + implicit def doubleToMultiplier(num: Double) = new Multiplier(num) + + implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { + def add(t1: Vector, t2: Vector) = t1 + t2 + def zero(initialValue: Vector) = Vector.zeros(initialValue.length) + } +} diff --git a/src/java/spark/compress/lzf/LZF.java b/src/java/spark/compress/lzf/LZF.java new file mode 100644 index 0000000000..294a0494ec --- /dev/null +++ b/src/java/spark/compress/lzf/LZF.java @@ -0,0 +1,27 @@ +package spark.compress.lzf; + +public class LZF { + private static boolean loaded; + + static { + try { + System.loadLibrary("spark_native"); + loaded = true; + } catch(Throwable t) { + System.out.println("Failed to load native LZF library: " + t.toString()); + loaded = false; + } + } + + public static boolean isLoaded() { + return loaded; + } + + public static native int compress( + byte[] in, int inOff, int inLen, + byte[] out, int outOff, int outLen); + + public static native int decompress( + byte[] in, int inOff, int inLen, + byte[] out, int outOff, int outLen); +} diff --git a/src/java/spark/compress/lzf/LZFInputStream.java b/src/java/spark/compress/lzf/LZFInputStream.java new file mode 100644 index 0000000000..16bc687489 --- /dev/null +++ b/src/java/spark/compress/lzf/LZFInputStream.java @@ -0,0 +1,180 @@ +package spark.compress.lzf; + +import java.io.EOFException; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +public class LZFInputStream extends FilterInputStream { + private static final int MAX_BLOCKSIZE = 1024 * 64 - 1; + private static final int MAX_HDR_SIZE = 7; + + private byte[] inBuf; // Holds data to decompress (including header) + private byte[] outBuf; // Holds decompressed data to output + private int outPos; // Current position in outBuf + private int outSize; // Total amount of data in 
outBuf + + private boolean closed; + private boolean reachedEof; + + private byte[] singleByte = new byte[1]; + + public LZFInputStream(InputStream in) { + super(in); + if (in == null) + throw new NullPointerException(); + inBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE]; + outBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE]; + outPos = 0; + outSize = 0; + } + + private void ensureOpen() throws IOException { + if (closed) throw new IOException("Stream closed"); + } + + @Override + public int read() throws IOException { + ensureOpen(); + int count = read(singleByte, 0, 1); + return (count == -1 ? -1 : singleByte[0] & 0xFF); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + if ((off | len | (off + len) | (b.length - (off + len))) < 0) + throw new IndexOutOfBoundsException(); + + int totalRead = 0; + + // Start with the current block in outBuf, and read and decompress any + // further blocks necessary. Instead of trying to decompress directly to b + // when b is large, we always use outBuf as an intermediate holding space + // in case GetPrimitiveArrayCritical decides to copy arrays instead of + // pinning them, which would cause b to be copied repeatedly into C-land. + while (len > 0) { + if (outPos == outSize) { + readNextBlock(); + if (reachedEof) + return totalRead == 0 ? -1 : totalRead; + } + int amtToCopy = Math.min(outSize - outPos, len); + System.arraycopy(outBuf, outPos, b, off, amtToCopy); + off += amtToCopy; + len -= amtToCopy; + outPos += amtToCopy; + totalRead += amtToCopy; + } + + return totalRead; + } + + // Read len bytes from this.in to a buffer, stopping only if EOF is reached + private int readFully(byte[] b, int off, int len) throws IOException { + int totalRead = 0; + while (len > 0) { + int amt = in.read(b, off, len); + if (amt == -1) + break; + off += amt; + len -= amt; + totalRead += amt; + } + return totalRead; + } + + // Read the next block from the underlying InputStream into outBuf, + // setting outPos and outSize, or set reachedEof if the stream ends. + private void readNextBlock() throws IOException { + // Read first 5 bytes of header + int count = readFully(inBuf, 0, 5); + if (count == 0) { + reachedEof = true; + return; + } else if (count < 5) { + throw new EOFException("Truncated LZF block header"); + } + + // Check magic bytes + if (inBuf[0] != 'Z' || inBuf[1] != 'V') + throw new IOException("Wrong magic bytes in LZF block header"); + + // Read the block + if (inBuf[2] == 0) { + // Uncompressed block - read directly to outBuf + int size = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF); + if (readFully(outBuf, 0, size) != size) + throw new EOFException("EOF inside LZF block"); + outPos = 0; + outSize = size; + } else if (inBuf[2] == 1) { + // Compressed block - read to inBuf and decompress + if (readFully(inBuf, 5, 2) != 2) + throw new EOFException("Truncated LZF block header"); + int csize = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF); + int usize = ((inBuf[5] & 0xFF) << 8) | (inBuf[6] & 0xFF); + if (readFully(inBuf, 7, csize) != csize) + throw new EOFException("Truncated LZF block"); + if (LZF.decompress(inBuf, 7, csize, outBuf, 0, usize) != usize) + throw new IOException("Corrupt LZF data stream"); + outPos = 0; + outSize = usize; + } else { + throw new IOException("Unknown block type in LZF block header"); + } + } + + /** + * Returns 0 after EOF has been reached, otherwise always return 1. 
+ * + * Programs should not count on this method to return the actual number + * of bytes that could be read without blocking. + */ + @Override + public int available() throws IOException { + ensureOpen(); + return reachedEof ? 0 : 1; + } + + // TODO: Skip complete chunks without decompressing them? + @Override + public long skip(long n) throws IOException { + ensureOpen(); + if (n < 0) + throw new IllegalArgumentException("negative skip length"); + byte[] buf = new byte[512]; + long skipped = 0; + while (skipped < n) { + int len = (int) Math.min(n - skipped, buf.length); + len = read(buf, 0, len); + if (len == -1) { + reachedEof = true; + break; + } + skipped += len; + } + return skipped; + } + + @Override + public void close() throws IOException { + if (!closed) { + in.close(); + closed = true; + } + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark(int readLimit) {} + + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } +} diff --git a/src/java/spark/compress/lzf/LZFOutputStream.java b/src/java/spark/compress/lzf/LZFOutputStream.java new file mode 100644 index 0000000000..5f65e95d2a --- /dev/null +++ b/src/java/spark/compress/lzf/LZFOutputStream.java @@ -0,0 +1,85 @@ +package spark.compress.lzf; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +public class LZFOutputStream extends FilterOutputStream { + private static final int BLOCKSIZE = 1024 * 64 - 1; + private static final int MAX_HDR_SIZE = 7; + + private byte[] inBuf; // Holds input data to be compressed + private byte[] outBuf; // Holds compressed data to be written + private int inPos; // Current position in inBuf + + public LZFOutputStream(OutputStream out) { + super(out); + inBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE]; + outBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE]; + inPos = MAX_HDR_SIZE; + } + + @Override + public void write(int b) throws IOException { + inBuf[inPos++] = (byte) b; + if (inPos == inBuf.length) + compressAndSendBlock(); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if ((off | len | (off + len) | (b.length - (off + len))) < 0) + throw new IndexOutOfBoundsException(); + + // If we're given a large array, copy it piece by piece into inBuf and + // write one BLOCKSIZE at a time. This is done to prevent the JNI code + // from copying the whole array repeatedly if GetPrimitiveArrayCritical + // decides to copy instead of pinning. + while (inPos + len >= inBuf.length) { + int amtToCopy = inBuf.length - inPos; + System.arraycopy(b, off, inBuf, inPos, amtToCopy); + inPos += amtToCopy; + compressAndSendBlock(); + off += amtToCopy; + len -= amtToCopy; + } + + // Copy the remaining (incomplete) block into inBuf + System.arraycopy(b, off, inBuf, inPos, len); + inPos += len; + } + + @Override + public void flush() throws IOException { + if (inPos > MAX_HDR_SIZE) + compressAndSendBlock(); + out.flush(); + } + + // Send the data in inBuf, and reset inPos to start writing a new block. + private void compressAndSendBlock() throws IOException { + int us = inPos - MAX_HDR_SIZE; + int maxcs = us > 4 ? 
us - 4 : us; + int cs = LZF.compress(inBuf, MAX_HDR_SIZE, us, outBuf, MAX_HDR_SIZE, maxcs); + if (cs != 0) { + // Compression made the data smaller; use type 1 header + outBuf[0] = 'Z'; + outBuf[1] = 'V'; + outBuf[2] = 1; + outBuf[3] = (byte) (cs >> 8); + outBuf[4] = (byte) (cs & 0xFF); + outBuf[5] = (byte) (us >> 8); + outBuf[6] = (byte) (us & 0xFF); + out.write(outBuf, 0, 7 + cs); + } else { + // Compression didn't help; use type 0 header and uncompressed data + inBuf[2] = 'Z'; + inBuf[3] = 'V'; + inBuf[4] = 0; + inBuf[5] = (byte) (us >> 8); + inBuf[6] = (byte) (us & 0xFF); + out.write(inBuf, 2, 5 + us); + } + inPos = MAX_HDR_SIZE; + } +} diff --git a/src/native/Makefile b/src/native/Makefile new file mode 100644 index 0000000000..331d3e6057 --- /dev/null +++ b/src/native/Makefile @@ -0,0 +1,30 @@ +CC = gcc +#JAVA_HOME = /usr/lib/jvm/java-6-sun +OS_NAME = linux + +CFLAGS = -fPIC -O3 -funroll-all-loops + +SPARK = ../.. + +LZF = $(SPARK)/third_party/liblzf-3.5 + +LIB = libspark_native.so + +all: $(LIB) + +spark_compress_lzf_LZF.h: $(SPARK)/classes/spark/compress/lzf/LZF.class +ifeq ($(JAVA_HOME),) + $(error JAVA_HOME is not set) +else + $(JAVA_HOME)/bin/javah -classpath $(SPARK)/classes spark.compress.lzf.LZF +endif + +$(LIB): spark_compress_lzf_LZF.h spark_compress_lzf_LZF.c + $(CC) $(CFLAGS) -shared -o $@ spark_compress_lzf_LZF.c \ + -I $(JAVA_HOME)/include -I $(JAVA_HOME)/include/$(OS_NAME) \ + -I $(LZF) $(LZF)/lzf_c.c $(LZF)/lzf_d.c + +clean: + rm -f spark_compress_lzf_LZF.h $(LIB) + +.PHONY: all clean diff --git a/src/native/spark_compress_lzf_LZF.c b/src/native/spark_compress_lzf_LZF.c new file mode 100644 index 0000000000..c2a59def3e --- /dev/null +++ b/src/native/spark_compress_lzf_LZF.c @@ -0,0 +1,90 @@ +#include "spark_compress_lzf_LZF.h" +#include <lzf.h> + + +/* Helper function to throw an exception */ +static void throwException(JNIEnv *env, const char* className) { + jclass cls = (*env)->FindClass(env, className); + if (cls != 0) /* If cls is null, an exception was already thrown */ + (*env)->ThrowNew(env, cls, ""); +} + + +/* + * Since LZF.compress() and LZF.decompress() have the same signatures + * and differ only in which lzf_ function they call, implement both in a + * single function and pass it a pointer to the correct lzf_ function. 
+ */ +static jint callCompressionFunction + (unsigned int (*func)(const void *const, unsigned int, void *, unsigned int), + JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, + jbyteArray outArray, jint outOff, jint outLen) +{ + jint inCap; + jint outCap; + jbyte *inData = 0; + jbyte *outData = 0; + jint ret; + jint s; + + if (!inArray || !outArray) { + throwException(env, "java/lang/NullPointerException"); + goto cleanup; + } + + inCap = (*env)->GetArrayLength(env, inArray); + outCap = (*env)->GetArrayLength(env, outArray); + + // Check if any of the offset/length pairs is invalid; we do this by OR'ing + // things we don't want to be negative and seeing if the result is negative + s = inOff | inLen | (inOff + inLen) | (inCap - (inOff + inLen)) | + outOff | outLen | (outOff + outLen) | (outCap - (outOff + outLen)); + if (s < 0) { + throwException(env, "java/lang/IndexOutOfBoundsException"); + goto cleanup; + } + + inData = (*env)->GetPrimitiveArrayCritical(env, inArray, 0); + outData = (*env)->GetPrimitiveArrayCritical(env, outArray, 0); + + if (!inData || !outData) { + // Out of memory - JVM will throw OutOfMemoryError + goto cleanup; + } + + ret = func(inData + inOff, inLen, outData + outOff, outLen); + +cleanup: + if (inData) + (*env)->ReleasePrimitiveArrayCritical(env, inArray, inData, 0); + if (outData) + (*env)->ReleasePrimitiveArrayCritical(env, outArray, outData, 0); + + return ret; +} + +/* + * Class: spark_compress_lzf_LZF + * Method: compress + * Signature: ([B[B)I + */ +JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_compress + (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, + jbyteArray outArray, jint outOff, jint outLen) +{ + return callCompressionFunction(lzf_compress, env, cls, + inArray, inOff, inLen, outArray,outOff, outLen); +} + +/* + * Class: spark_compress_lzf_LZF + * Method: decompress + * Signature: ([B[B)I + */ +JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_decompress + (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, + jbyteArray outArray, jint outOff, jint outLen) +{ + return callCompressionFunction(lzf_decompress, env, cls, + inArray, inOff, inLen, outArray,outOff, outLen); +} diff --git a/src/scala/spark/Accumulators.scala b/src/scala/spark/Accumulators.scala new file mode 100644 index 0000000000..3e4cd4935a --- /dev/null +++ b/src/scala/spark/Accumulators.scala @@ -0,0 +1,71 @@ +package spark + +import java.io._ + +import scala.collection.mutable.Map + +@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T]) +{ + val id = Accumulators.newId + @transient var value_ = initialValue + var deserialized = false + + Accumulators.register(this) + + def += (term: T) { value_ = param.add(value_, term) } + def value = this.value_ + def value_= (t: T) { + if (!deserialized) value_ = t + else throw new UnsupportedOperationException("Can't use value_= in task") + } + + // Called by Java when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject + value_ = param.zero(initialValue) + deserialized = true + Accumulators.register(this) + } + + override def toString = value_.toString +} + +@serializable trait AccumulatorParam[T] { + def add(t1: T, t2: T): T + def zero(initialValue: T): T +} + +// TODO: The multi-thread support in accumulators is kind of lame; check +// if there's a more intuitive way of doing it right +private object Accumulators +{ + // TODO: Use soft references? 
=> need to make readObject work properly then + val accums = Map[(Thread, Long), Accumulator[_]]() + var lastId: Long = 0 + + def newId: Long = synchronized { lastId += 1; return lastId } + + def register(a: Accumulator[_]): Unit = synchronized { + accums((currentThread, a.id)) = a + } + + def clear: Unit = synchronized { + accums.retain((key, accum) => key._1 != currentThread) + } + + def values: Map[Long, Any] = synchronized { + val ret = Map[Long, Any]() + for(((thread, id), accum) <- accums if thread == currentThread) + ret(id) = accum.value + return ret + } + + def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized { + for ((id, value) <- values) { + if (accums.contains((thread, id))) { + val accum = accums((thread, id)) + accum.asInstanceOf[Accumulator[Any]] += value + } + } + } +} diff --git a/src/scala/spark/Cached.scala b/src/scala/spark/Cached.scala new file mode 100644 index 0000000000..8113340e1f --- /dev/null +++ b/src/scala/spark/Cached.scala @@ -0,0 +1,110 @@ +package spark + +import java.io._ +import java.net.URI +import java.util.UUID + +import com.google.common.collect.MapMaker + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} + +import spark.compress.lzf.{LZFInputStream, LZFOutputStream} + +@serializable class Cached[T](@transient var value_ : T, local: Boolean) { + val uuid = UUID.randomUUID() + def value = value_ + + Cache.synchronized { Cache.values.put(uuid, value_) } + + if (!local) writeCacheFile() + + private def writeCacheFile() { + val out = new ObjectOutputStream(Cache.openFileForWriting(uuid)) + out.writeObject(value_) + out.close() + } + + // Called by Java when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject + Cache.synchronized { + val cachedVal = Cache.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + val start = System.nanoTime + val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid)) + value_ = fileIn.readObject().asInstanceOf[T] + Cache.values.put(uuid, value_) + fileIn.close() + val time = (System.nanoTime - start) / 1e9 + println("Reading cached variable " + uuid + " took " + time + " s") + } + } + } + + override def toString = "spark.Cached(" + uuid + ")" +} + +private object Cache { + val values = new MapMaker().softValues().makeMap[UUID, Any]() + + private var initialized = false + private var fileSystem: FileSystem = null + private var workDir: String = null + private var compress: Boolean = false + private var bufferSize: Int = 65536 + + // Will be called by SparkContext or Executor before using cache + def initialize() { + synchronized { + if (!initialized) { + bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val dfs = System.getProperty("spark.dfs", "file:///") + if (!dfs.startsWith("file://")) { + val conf = new Configuration() + conf.setInt("io.file.buffer.size", bufferSize) + val rep = System.getProperty("spark.dfs.replication", "3").toInt + conf.setInt("dfs.replication", rep) + fileSystem = FileSystem.get(new URI(dfs), conf) + } + workDir = System.getProperty("spark.dfs.workdir", "/tmp") + compress = System.getProperty("spark.compress", "false").toBoolean + initialized = true + } + } + } + + private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid) + + def openFileForReading(uuid: UUID): InputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.open(getPath(uuid)) + } else { + // Local filesystem + new 
FileInputStream(getPath(uuid).toString) + } + if (compress) + new LZFInputStream(fileStream) // LZF stream does its own buffering + else if (fileSystem == null) + new BufferedInputStream(fileStream, bufferSize) + else + fileStream // Hadoop streams do their own buffering + } + + def openFileForWriting(uuid: UUID): OutputStream = { + val fileStream = if (fileSystem != null) { + fileSystem.create(getPath(uuid)) + } else { + // Local filesystem + new FileOutputStream(getPath(uuid).toString) + } + if (compress) + new LZFOutputStream(fileStream) // LZF stream does its own buffering + else if (fileSystem == null) + new BufferedOutputStream(fileStream, bufferSize) + else + fileStream // Hadoop streams do their own buffering + } +} diff --git a/src/scala/spark/ClosureCleaner.scala b/src/scala/spark/ClosureCleaner.scala new file mode 100644 index 0000000000..c5663901b3 --- /dev/null +++ b/src/scala/spark/ClosureCleaner.scala @@ -0,0 +1,157 @@ +package spark + +import scala.collection.mutable.Map +import scala.collection.mutable.Set + +import org.objectweb.asm.{ClassReader, MethodVisitor, Type} +import org.objectweb.asm.commons.EmptyVisitor +import org.objectweb.asm.Opcodes._ + + +object ClosureCleaner { + private def getClassReader(cls: Class[_]): ClassReader = new ClassReader( + cls.getResourceAsStream(cls.getName.replaceFirst("^.*\\.", "") + ".class")) + + private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { + f.setAccessible(true) + return f.getType :: getOuterClasses(f.get(obj)) + } + return Nil + } + + private def getOuterObjects(obj: AnyRef): List[AnyRef] = { + for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { + f.setAccessible(true) + return f.get(obj) :: getOuterObjects(f.get(obj)) + } + return Nil + } + + private def getInnerClasses(obj: AnyRef): List[Class[_]] = { + val seen = Set[Class[_]](obj.getClass) + var stack = List[Class[_]](obj.getClass) + while (!stack.isEmpty) { + val cr = getClassReader(stack.head) + stack = stack.tail + val set = Set[Class[_]]() + cr.accept(new InnerClosureFinder(set), 0) + for (cls <- set -- seen) { + seen += cls + stack = cls :: stack + } + } + return (seen - obj.getClass).toList + } + + private def createNullValue(cls: Class[_]): AnyRef = { + if (cls.isPrimitive) + new java.lang.Byte(0: Byte) // Should be convertible to any primitive type + else + null + } + + def clean(func: AnyRef): Unit = { + // TODO: cache outerClasses / innerClasses / accessedFields + val outerClasses = getOuterClasses(func) + val innerClasses = getInnerClasses(func) + val outerObjects = getOuterObjects(func) + + val accessedFields = Map[Class[_], Set[String]]() + for (cls <- outerClasses) + accessedFields(cls) = Set[String]() + for (cls <- func.getClass :: innerClasses) + getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) + + var outer: AnyRef = null + for ((cls, obj) <- (outerClasses zip outerObjects).reverse) { + outer = instantiateClass(cls, outer); + for (fieldName <- accessedFields(cls)) { + val field = cls.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + //println("1: Setting " + fieldName + " on " + cls + " to " + value); + field.set(outer, value) + } + } + + if (outer != null) { + //println("2: Setting $outer on " + func.getClass + " to " + outer); + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + field.set(func, outer) + } + } + + private def instantiateClass(cls: Class[_], outer: AnyRef): 
AnyRef = { + if (spark.repl.Main.interp == null) { + // This is a bona fide closure class, whose constructor has no effects + // other than to set its fields, so use its constructor + val cons = cls.getConstructors()(0) + val params = cons.getParameterTypes.map(createNullValue).toArray + if (outer != null) + params(0) = outer // First param is always outer object + return cons.newInstance(params: _*).asInstanceOf[AnyRef] + } else { + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory(); + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor(); + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef]; + if (outer != null) { + //println("3: Setting $outer on " + cls + " to " + outer); + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, outer) + } + return obj + } + } +} + + +class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + return new EmptyVisitor { + override def visitFieldInsn(op: Int, owner: String, name: String, + desc: String) { + if (op == GETFIELD) + for (cl <- output.keys if cl.getName == owner.replace('/', '.')) + output(cl) += name + } + + override def visitMethodInsn(op: Int, owner: String, name: String, + desc: String) { + if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) + for (cl <- output.keys if cl.getName == owner.replace('/', '.')) + output(cl) += name + } + } + } +} + + +class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { + var myName: String = null + + override def visit(version: Int, access: Int, name: String, sig: String, + superName: String, interfaces: Array[String]) { + myName = name + } + + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + return new EmptyVisitor { + override def visitMethodInsn(op: Int, owner: String, name: String, + desc: String) { + val argTypes = Type.getArgumentTypes(desc) + if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0 + && argTypes(0).toString.startsWith("L") // is it an object? 
+ && argTypes(0).getInternalName == myName) + output += Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + } + } + } +} diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala new file mode 100644 index 0000000000..4cc8f00aa9 --- /dev/null +++ b/src/scala/spark/Executor.scala @@ -0,0 +1,70 @@ +package spark + +import java.util.concurrent.{Executors, ExecutorService} + +import nexus.{ExecutorArgs, TaskDescription, TaskState, TaskStatus} + +object Executor { + def main(args: Array[String]) { + System.loadLibrary("nexus") + + val exec = new nexus.Executor() { + var classLoader: ClassLoader = null + var threadPool: ExecutorService = null + + override def init(args: ExecutorArgs) { + // Read spark.* system properties + val props = Utils.deserialize[Array[(String, String)]](args.getData) + for ((key, value) <- props) + System.setProperty(key, value) + + // Initialize cache (uses some properties read above) + Cache.initialize() + + // If the REPL is in use, create a ClassLoader that will be able to + // read new classes defined by the REPL as the user types code + classLoader = this.getClass.getClassLoader + val classDir = System.getProperty("spark.repl.classdir") + if (classDir != null) { + println("Using REPL classdir: " + classDir) + classLoader = new repl.ExecutorClassLoader(classDir, classLoader) + } + Thread.currentThread.setContextClassLoader(classLoader) + + // Start worker thread pool (they will inherit our context ClassLoader) + threadPool = Executors.newCachedThreadPool() + } + + override def startTask(desc: TaskDescription) { + // Pull taskId and arg out of TaskDescription because it won't be a + // valid pointer after this method call (TODO: fix this in C++/SWIG) + val taskId = desc.getTaskId + val arg = desc.getArg + threadPool.execute(new Runnable() { + def run() = { + println("Running task ID " + taskId) + try { + Accumulators.clear + val task = Utils.deserialize[Task[Any]](arg, classLoader) + val value = task.run + val accumUpdates = Accumulators.values + val result = new TaskResult(value, accumUpdates) + sendStatusUpdate(new TaskStatus( + taskId, TaskState.TASK_FINISHED, Utils.serialize(result))) + println("Finished task ID " + taskId) + } catch { + case e: Exception => { + // TODO: Handle errors in tasks less dramatically + System.err.println("Exception in task ID " + taskId + ":") + e.printStackTrace + System.exit(1) + } + } + } + }) + } + } + + exec.run() + } +} diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala new file mode 100644 index 0000000000..8050683f99 --- /dev/null +++ b/src/scala/spark/HdfsFile.scala @@ -0,0 +1,277 @@ +package spark + +import java.io._ +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.ConcurrentHashMap +import java.util.HashSet + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Map + +import nexus._ + +import com.google.common.collect.MapMaker + +import org.apache.hadoop.io.ObjectWritable +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.FileInputFormat +import org.apache.hadoop.mapred.InputSplit +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapred.RecordReader +import org.apache.hadoop.mapred.Reporter + +@serializable +abstract class DistributedFile[T, Split](@transient sc: SparkContext) { + def splits: Array[Split] + def iterator(split: 
Split): Iterator[T] + def prefers(split: Split, slot: SlaveOffer): Boolean + + def taskStarted(split: Split, slot: SlaveOffer) {} + + def sparkContext = sc + + def foreach(f: T => Unit) { + val cleanF = sc.clean(f) + val tasks = splits.map(s => new ForeachTask(this, s, cleanF)).toArray + sc.runTaskObjects(tasks) + } + + def toArray: Array[T] = { + val tasks = splits.map(s => new GetTask(this, s)) + val results = sc.runTaskObjects(tasks) + Array.concat(results: _*) + } + + def reduce(f: (T, T) => T): T = { + val cleanF = sc.clean(f) + val tasks = splits.map(s => new ReduceTask(this, s, f)) + val results = new ArrayBuffer[T] + for (option <- sc.runTaskObjects(tasks); elem <- option) + results += elem + if (results.size == 0) + throw new UnsupportedOperationException("empty collection") + else + return results.reduceLeft(f) + } + + def take(num: Int): Array[T] = { + if (num == 0) + return new Array[T](0) + val buf = new ArrayBuffer[T] + for (split <- splits; elem <- iterator(split)) { + buf += elem + if (buf.length == num) + return buf.toArray + } + return buf.toArray + } + + def first: T = take(1) match { + case Array(t) => t + case _ => throw new UnsupportedOperationException("empty collection") + } + + def map[U](f: T => U) = new MappedFile(this, sc.clean(f)) + def filter(f: T => Boolean) = new FilteredFile(this, sc.clean(f)) + def cache() = new CachedFile(this) + + def count(): Long = + try { map(x => 1L).reduce(_+_) } + catch { case e: UnsupportedOperationException => 0L } +} + +@serializable +abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split], + val split: Split) +extends Task[U] { + override def prefers(slot: SlaveOffer) = file.prefers(split, slot) + override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) } +} + +class ForeachTask[T, Split](file: DistributedFile[T, Split], + split: Split, func: T => Unit) +extends FileTask[Unit, T, Split](file, split) { + override def run() { + println("Processing " + split) + file.iterator(split).foreach(func) + } +} + +class GetTask[T, Split](file: DistributedFile[T, Split], split: Split) +extends FileTask[Array[T], T, Split](file, split) { + override def run(): Array[T] = { + println("Processing " + split) + file.iterator(split).collect.toArray + } +} + +class ReduceTask[T, Split](file: DistributedFile[T, Split], + split: Split, f: (T, T) => T) +extends FileTask[Option[T], T, Split](file, split) { + override def run(): Option[T] = { + println("Processing " + split) + val iter = file.iterator(split) + if (iter.hasNext) + Some(iter.reduceLeft(f)) + else + None + } +} + +class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U) +extends DistributedFile[U, Split](prev.sparkContext) { + override def splits = prev.splits + override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot) + override def iterator(split: Split) = prev.iterator(split).map(f) + override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot) +} + +class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean) +extends DistributedFile[T, Split](prev.sparkContext) { + override def splits = prev.splits + override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot) + override def iterator(split: Split) = prev.iterator(split).filter(f) + override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot) +} + +class CachedFile[T, Split](prev: DistributedFile[T, Split]) +extends DistributedFile[T, Split](prev.sparkContext) { + val 
id = CachedFile.newId() + @transient val cacheLocs = Map[Split, List[Int]]() + + override def splits = prev.splits + + override def prefers(split: Split, slot: SlaveOffer): Boolean = { + if (cacheLocs.contains(split)) + cacheLocs(split).contains(slot.getSlaveId) + else + prev.prefers(split, slot) + } + + override def iterator(split: Split): Iterator[T] = { + val key = id + "::" + split.toString + val cache = CachedFile.cache + val loading = CachedFile.loading + val cachedVal = cache.get(key) + if (cachedVal != null) { + // Split is in cache, so just return its values + return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]]) + } else { + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + return Iterator.fromArray(cache.get(key).asInstanceOf[Array[T]]) + } else { + loading.add(key) + } + } + // If we got here, we have to load the split + println("Loading and caching " + split) + val array = prev.iterator(split).collect.toArray + cache.put(key, array) + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + return Iterator.fromArray(array) + } + } + + override def taskStarted(split: Split, slot: SlaveOffer) { + val oldList = cacheLocs.getOrElse(split, Nil) + val slaveId = slot.getSlaveId + if (!oldList.contains(slaveId)) + cacheLocs(split) = slaveId :: oldList + } +} + +private object CachedFile { + val nextId = new AtomicLong(0) // Generates IDs for mapped files (on master) + def newId() = nextId.getAndIncrement() + + // Stores map results for various splits locally (on workers) + val cache = new MapMaker().softValues().makeMap[String, AnyRef]() + + // Remembers which splits are currently being loaded (on workers) + val loading = new HashSet[String] +} + +class HdfsSplit(@transient s: InputSplit) +extends SerializableWritable[InputSplit](s) + +class HdfsTextFile(sc: SparkContext, path: String) +extends DistributedFile[String, HdfsSplit](sc) { + @transient val conf = new JobConf() + @transient val inputFormat = new TextInputFormat() + + FileInputFormat.setInputPaths(conf, path) + ConfigureLock.synchronized { inputFormat.configure(conf) } + + @transient val splits_ = + inputFormat.getSplits(conf, 2).map(new HdfsSplit(_)).toArray + + override def splits = splits_ + + override def iterator(split: HdfsSplit) = new Iterator[String] { + var reader: RecordReader[LongWritable, Text] = null + ConfigureLock.synchronized { + val conf = new JobConf() + conf.set("io.file.buffer.size", + System.getProperty("spark.buffer.size", "65536")) + val tif = new TextInputFormat() + tif.configure(conf) + reader = tif.getRecordReader(split.value, conf, Reporter.NULL) + } + val lineNum = new LongWritable() + val text = new Text() + var gotNext = false + var finished = false + + override def hasNext: Boolean = { + if (!gotNext) { + finished = !reader.next(lineNum, text) + gotNext = true + } + !finished + } + + override def next: String = { + if (!gotNext) + finished = !reader.next(lineNum, text) + if (finished) + throw new java.util.NoSuchElementException("end of stream") + gotNext = false + text.toString + } + } + + override def prefers(split: HdfsSplit, slot: SlaveOffer) = + split.value.getLocations().contains(slot.getHost) +} + +object ConfigureLock {} + +@serializable +class SerializableWritable[T <: Writable](@transient var t: T) { + def value = t + override def toString = t.toString + + private def writeObject(out: ObjectOutputStream) { + 
out.defaultWriteObject() + new ObjectWritable(t).write(out) + } + + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + val ow = new ObjectWritable() + ow.setConf(new JobConf()) + ow.readFields(in) + t = ow.get().asInstanceOf[T] + } +} diff --git a/src/scala/spark/LocalScheduler.scala b/src/scala/spark/LocalScheduler.scala new file mode 100644 index 0000000000..35bfdde09f --- /dev/null +++ b/src/scala/spark/LocalScheduler.scala @@ -0,0 +1,65 @@ +package spark + +import java.util.concurrent._ + +import scala.collection.mutable.Map + +// A simple Scheduler implementation that runs tasks locally in a thread pool. +private class LocalScheduler(threads: Int) extends Scheduler { + var threadPool: ExecutorService = + Executors.newFixedThreadPool(threads, DaemonThreadFactory) + + override def start() {} + + override def waitForRegister() {} + + override def runTasks[T](tasks: Array[Task[T]]): Array[T] = { + val futures = new Array[Future[TaskResult[T]]](tasks.length) + + for (i <- 0 until tasks.length) { + futures(i) = threadPool.submit(new Callable[TaskResult[T]]() { + def call(): TaskResult[T] = { + println("Running task " + i) + try { + // Serialize and deserialize the task so that accumulators are + // changed to thread-local ones; this adds a bit of unnecessary + // overhead but matches how the Nexus Executor works + Accumulators.clear + val bytes = Utils.serialize(tasks(i)) + println("Size of task " + i + " is " + bytes.size + " bytes") + val task = Utils.deserialize[Task[T]]( + bytes, currentThread.getContextClassLoader) + val value = task.run + val accumUpdates = Accumulators.values + println("Finished task " + i) + new TaskResult[T](value, accumUpdates) + } catch { + case e: Exception => { + // TODO: Do something nicer here + System.err.println("Exception in task " + i + ":") + e.printStackTrace() + System.exit(1) + null + } + } + } + }) + } + + val taskResults = futures.map(_.get) + for (result <- taskResults) + Accumulators.add(currentThread, result.accumUpdates) + return taskResults.map(_.value).toArray + } + + override def stop() {} +} + +// A ThreadFactory that creates daemon threads +private object DaemonThreadFactory extends ThreadFactory { + override def newThread(r: Runnable): Thread = { + val t = new Thread(r); + t.setDaemon(true) + return t + } +} diff --git a/src/scala/spark/NexusScheduler.scala b/src/scala/spark/NexusScheduler.scala new file mode 100644 index 0000000000..a96fca9350 --- /dev/null +++ b/src/scala/spark/NexusScheduler.scala @@ -0,0 +1,258 @@ +package spark + +import java.io.File +import java.util.concurrent.Semaphore + +import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus} +import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver} +import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap} + +// The main Scheduler implementation, which talks to Nexus. Clients are expected +// to first call start(), then submit tasks through the runTasks method. +// +// This implementation is currently a little quick and dirty. The following +// improvements need to be made to it: +// 1) Fault tolerance should be added - if a task fails, just re-run it anywhere. +// 2) Right now, the scheduler uses a linear scan through the tasks to find a +// local one for a given node. It would be faster to have a separate list of +// pending tasks for each node. +// 3) The Callbacks way of organizing things didn't work out too well, so the +// way the scheduler keeps track of the currently active runTasks operation +// can be made cleaner. 
+private class NexusScheduler( + master: String, frameworkName: String, execArg: Array[Byte]) +extends nexus.Scheduler with spark.Scheduler +{ + // Semaphore used by runTasks to ensure only one thread can be in it + val semaphore = new Semaphore(1) + + // Lock used to wait for scheduler to be registered + var isRegistered = false + val registeredLock = new Object() + + // Trait representing a set of scheduler callbacks + trait Callbacks { + def slotOffer(s: SlaveOffer): Option[TaskDescription] + def taskFinished(t: TaskStatus): Unit + def error(code: Int, message: String): Unit + } + + // Current callback object (may be null) + var callbacks: Callbacks = null + + // Incrementing task ID + var nextTaskId = 0 + + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong + + // Driver for talking to Nexus + var driver: SchedulerDriver = null + + override def start() { + new Thread("Spark scheduler") { + setDaemon(true) + override def run { + val ns = NexusScheduler.this + ns.driver = new NexusSchedulerDriver(ns, master) + ns.driver.run() + } + }.start + } + + override def getFrameworkName(d: SchedulerDriver): String = frameworkName + + override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = + new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg) + + override def runTasks[T](tasks: Array[Task[T]]): Array[T] = { + val results = new Array[T](tasks.length) + if (tasks.length == 0) + return results + + val launched = new Array[Boolean](tasks.length) + + val callingThread = currentThread + + var errorHappened = false + var errorCode = 0 + var errorMessage = "" + + // Wait for scheduler to be registered with Nexus + waitForRegister() + + try { + // Acquire the runTasks semaphore + semaphore.acquire() + + val myCallbacks = new Callbacks { + val firstTaskId = nextTaskId + var tasksLaunched = 0 + var tasksFinished = 0 + var lastPreferredLaunchTime = System.currentTimeMillis + + def slotOffer(slot: SlaveOffer): Option[TaskDescription] = { + try { + if (tasksLaunched < tasks.length) { + // TODO: Add a short wait if no task with location pref is found + // TODO: Figure out why a function is needed around this to + // avoid scala.runtime.NonLocalReturnException + def findTask: Option[TaskDescription] = { + var checkPrefVals: Array[Boolean] = Array(true) + val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime > LOCALITY_WAIT) + checkPrefVals = Array(true, false) // Allow non-preferred tasks + // TODO: Make desiredCpus and desiredMem configurable + val desiredCpus = 1 + val desiredMem = 750L * 1024L * 1024L + if (slot.getParams.get("cpus").toInt < desiredCpus || + slot.getParams.get("mem").toLong < desiredMem) + return None + for (checkPref <- checkPrefVals; + i <- 0 until tasks.length; + if !launched(i) && (!checkPref || tasks(i).prefers(slot))) + { + val taskId = nextTaskId + nextTaskId += 1 + printf("Starting task %d as TID %d on slave %d: %s (%s)\n", + i, taskId, slot.getSlaveId, slot.getHost, + if(checkPref) "preferred" else "non-preferred") + tasks(i).markStarted(slot) + launched(i) = true + tasksLaunched += 1 + if (checkPref) + lastPreferredLaunchTime = time + val params = new StringMap + params.set("cpus", "" + desiredCpus) + params.set("mem", "" + desiredMem) + val serializedTask = Utils.serialize(tasks(i)) + return Some(new TaskDescription(taskId, slot.getSlaveId, + "task_" + taskId, params, serializedTask)) + } + return None + } + return findTask + } else { + 
return None + } + } catch { + case e: Exception => { + e.printStackTrace + System.exit(1) + return None + } + } + } + + def taskFinished(status: TaskStatus) { + println("Finished TID " + status.getTaskId) + // Deserialize task result + val result = Utils.deserialize[TaskResult[T]](status.getData) + results(status.getTaskId - firstTaskId) = result.value + // Update accumulators + Accumulators.add(callingThread, result.accumUpdates) + // Stop if we've finished all the tasks + tasksFinished += 1 + if (tasksFinished == tasks.length) { + NexusScheduler.this.callbacks = null + NexusScheduler.this.notifyAll() + } + } + + def error(code: Int, message: String) { + // Save the error message + errorHappened = true + errorCode = code + errorMessage = message + // Indicate to caller thread that we're done + NexusScheduler.this.callbacks = null + NexusScheduler.this.notifyAll() + } + } + + this.synchronized { + this.callbacks = myCallbacks + } + driver.reviveOffers(); + this.synchronized { + while (this.callbacks != null) this.wait() + } + } finally { + semaphore.release() + } + + if (errorHappened) + throw new SparkException(errorMessage, errorCode) + else + return results + } + + override def registered(d: SchedulerDriver, frameworkId: Int) { + println("Registered as framework ID " + frameworkId) + registeredLock.synchronized { + isRegistered = true + registeredLock.notifyAll() + } + } + + override def waitForRegister() { + registeredLock.synchronized { + while (!isRegistered) registeredLock.wait() + } + } + + override def resourceOffer( + d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) { + synchronized { + val tasks = new TaskDescriptionVector + if (callbacks != null) { + try { + for (i <- 0 until slots.size.toInt) { + callbacks.slotOffer(slots.get(i)) match { + case Some(task) => tasks.add(task) + case None => {} + } + } + } catch { + case e: Exception => e.printStackTrace + } + } + val params = new StringMap + params.set("timeout", "1") + d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout + } + } + + override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { + synchronized { + if (callbacks != null && status.getState == TaskState.TASK_FINISHED) { + try { + callbacks.taskFinished(status) + } catch { + case e: Exception => e.printStackTrace + } + } + } + } + + override def error(d: SchedulerDriver, code: Int, message: String) { + synchronized { + if (callbacks != null) { + try { + callbacks.error(code, message) + } catch { + case e: Exception => e.printStackTrace + } + } else { + val msg = "Nexus error: %s (error code: %d)".format(message, code) + System.err.println(msg) + System.exit(1) + } + } + } + + override def stop() { + if (driver != null) + driver.stop() + } +} diff --git a/src/scala/spark/ParallelArray.scala b/src/scala/spark/ParallelArray.scala new file mode 100644 index 0000000000..90cacf47fc --- /dev/null +++ b/src/scala/spark/ParallelArray.scala @@ -0,0 +1,97 @@ +package spark + +abstract class ParallelArray[T](sc: SparkContext) { + def filter(f: T => Boolean): ParallelArray[T] = { + val cleanF = sc.clean(f) + new FilteredParallelArray[T](sc, this, cleanF) + } + + def foreach(f: T => Unit): Unit + + def map[U](f: T => U): Array[U] +} + +private object ParallelArray { + def slice[T](seq: Seq[T], numSlices: Int): Array[Seq[T]] = { + if (numSlices < 1) + throw new IllegalArgumentException("Positive number of slices required") + seq match { + case r: Range.Inclusive => { + val sign = if (r.step < 0) -1 else 1 + slice(new Range(r.start, r.end + sign, 
r.step).asInstanceOf[Seq[T]], + numSlices) + } + case r: Range => { + (0 until numSlices).map(i => { + val start = ((i * r.length.toLong) / numSlices).toInt + val end = (((i+1) * r.length.toLong) / numSlices).toInt + new SerializableRange( + r.start + start * r.step, r.start + end * r.step, r.step) + }).asInstanceOf[Seq[Seq[T]]].toArray + } + case _ => { + val array = seq.toArray // To prevent O(n^2) operations for List etc + (0 until numSlices).map(i => { + val start = ((i * array.length.toLong) / numSlices).toInt + val end = (((i+1) * array.length.toLong) / numSlices).toInt + array.slice(start, end).toArray + }).toArray + } + } + } +} + +private class SimpleParallelArray[T]( + sc: SparkContext, data: Seq[T], numSlices: Int) +extends ParallelArray[T](sc) { + val slices = ParallelArray.slice(data, numSlices) + + def foreach(f: T => Unit) { + val cleanF = sc.clean(f) + var tasks = for (i <- 0 until numSlices) yield + new ForeachRunner(i, slices(i), cleanF) + sc.runTasks[Unit](tasks.toArray) + } + + def map[U](f: T => U): Array[U] = { + val cleanF = sc.clean(f) + var tasks = for (i <- 0 until numSlices) yield + new MapRunner(i, slices(i), cleanF) + return Array.concat(sc.runTasks[Array[U]](tasks.toArray): _*) + } +} + +@serializable +private class ForeachRunner[T](sliceNum: Int, data: Seq[T], f: T => Unit) +extends Function0[Unit] { + def apply() = { + printf("Running slice %d of parallel foreach\n", sliceNum) + data.foreach(f) + } +} + +@serializable +private class MapRunner[T, U](sliceNum: Int, data: Seq[T], f: T => U) +extends Function0[Array[U]] { + def apply(): Array[U] = { + printf("Running slice %d of parallel map\n", sliceNum) + return data.map(f).toArray + } +} + +private class FilteredParallelArray[T]( + sc: SparkContext, array: ParallelArray[T], predicate: T => Boolean) +extends ParallelArray[T](sc) { + val cleanPred = sc.clean(predicate) + + def foreach(f: T => Unit) { + val cleanF = sc.clean(f) + array.foreach(t => if (cleanPred(t)) cleanF(t)) + } + + def map[U](f: T => U): Array[U] = { + val cleanF = sc.clean(f) + throw new UnsupportedOperationException( + "Map is not yet supported on FilteredParallelArray") + } +} diff --git a/src/scala/spark/Scheduler.scala b/src/scala/spark/Scheduler.scala new file mode 100644 index 0000000000..77446d3e78 --- /dev/null +++ b/src/scala/spark/Scheduler.scala @@ -0,0 +1,9 @@ +package spark + +// Scheduler trait, implemented by both NexusScheduler and LocalScheduler. +private trait Scheduler { + def start() + def waitForRegister() + def runTasks[T](tasks: Array[Task[T]]): Array[T] + def stop() +} diff --git a/src/scala/spark/SerializableRange.scala b/src/scala/spark/SerializableRange.scala new file mode 100644 index 0000000000..5d383a40dc --- /dev/null +++ b/src/scala/spark/SerializableRange.scala @@ -0,0 +1,75 @@ +// This is a copy of Scala 2.7.7's Range class, (c) 2006-2009, LAMP/EPFL. +// The only change here is to make it Serializable, because Ranges aren't. +// This won't be needed in Scala 2.8, where Scala's Range becomes Serializable. + +package spark + +@serializable +private class SerializableRange(val start: Int, val end: Int, val step: Int) +extends RandomAccessSeq.Projection[Int] { + if (step == 0) throw new Predef.IllegalArgumentException + + /** Create a new range with the start and end values of this range and + * a new <code>step</code>. 
+ */ + def by(step: Int): Range = new Range(start, end, step) + + override def foreach(f: Int => Unit) { + if (step > 0) { + var i = this.start + val until = if (inInterval(end)) end + 1 else end + + while (i < until) { + f(i) + i += step + } + } else { + var i = this.start + val until = if (inInterval(end)) end - 1 else end + + while (i > until) { + f(i) + i += step + } + } + } + + lazy val length: Int = { + if (start < end && this.step < 0) 0 + else if (start > end && this.step > 0) 0 + else { + val base = if (start < end) end - start + else start - end + assert(base >= 0) + val step = if (this.step < 0) -this.step else this.step + assert(step >= 0) + base / step + last(base, step) + } + } + + protected def last(base: Int, step: Int): Int = + if (base % step != 0) 1 else 0 + + def apply(idx: Int): Int = { + if (idx < 0 || idx >= length) throw new Predef.IndexOutOfBoundsException + start + (step * idx) + } + + /** a <code>Seq.contains</code>, not a <code>Iterator.contains</code>! */ + def contains(x: Int): Boolean = { + inInterval(x) && (((x - start) % step) == 0) + } + + /** Is the argument inside the interval defined by `start' and `end'? + * Returns true if `x' is inside [start, end). + */ + protected def inInterval(x: Int): Boolean = + if (step > 0) + (x >= start && x < end) + else + (x <= start && x > end) + + //def inclusive = new Range.Inclusive(start,end,step) + + override def toString = "SerializableRange(%d, %d, %d)".format(start, end, step) +} diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala new file mode 100644 index 0000000000..4bfbcb6f21 --- /dev/null +++ b/src/scala/spark/SparkContext.scala @@ -0,0 +1,89 @@ +package spark + +import java.io._ +import java.util.UUID + +import scala.collection.mutable.ArrayBuffer + +class SparkContext(master: String, frameworkName: String) { + Cache.initialize() + + def parallelize[T](seq: Seq[T], numSlices: Int): ParallelArray[T] = + new SimpleParallelArray[T](this, seq, numSlices) + + def parallelize[T](seq: Seq[T]): ParallelArray[T] = parallelize(seq, 2) + + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = + new Accumulator(initialValue, param) + + // TODO: Keep around a weak hash map of values to Cached versions? 
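  // Illustrative use of the API defined so far (an editorial sketch, not code
  // from this commit): parallelize() splits a local Seq into slices, and the
  // resulting ParallelArray supports foreach, map, and filter, running the
  // work as tasks on the scheduler chosen by the master string:
  //
  //   val sc = new SparkContext("local", "ParallelizeExample")
  //   val nums = sc.parallelize(1 to 10, 4)            // 4 slices
  //   val squares: Array[Int] = nums.map(x => x * x)
  //   nums.filter(_ % 2 == 0).foreach(n => println(n))
  //   sc.stop()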
+ def broadcast[T](value: T) = new Cached(value, local) + + def textFile(path: String) = new HdfsTextFile(this, path) + + val LOCAL_REGEX = """local\[([0-9]+)\]""".r + + private var scheduler: Scheduler = master match { + case "local" => new LocalScheduler(1) + case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt) + case _ => { System.loadLibrary("nexus"); + new NexusScheduler(master, frameworkName, createExecArg()) } + } + + private val local = scheduler.isInstanceOf[LocalScheduler] + + scheduler.start() + + private def createExecArg(): Array[Byte] = { + // Our executor arg is an array containing all the spark.* system properties + val props = new ArrayBuffer[(String, String)] + val iter = System.getProperties.entrySet.iterator + while (iter.hasNext) { + val entry = iter.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.")) + props += (key, value) + } + return Utils.serialize(props.toArray) + } + + def runTasks[T](tasks: Array[() => T]): Array[T] = { + runTaskObjects(tasks.map(f => new FunctionTask(f))) + } + + private[spark] def runTaskObjects[T](tasks: Seq[Task[T]]): Array[T] = { + println("Running " + tasks.length + " tasks in parallel") + val start = System.nanoTime + val result = scheduler.runTasks(tasks.toArray) + println("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s") + return result + } + + def stop() { + scheduler.stop() + scheduler = null + } + + def waitForRegister() { + scheduler.waitForRegister() + } + + // Clean a closure to make it ready to serialized and send to tasks + // (removes unreferenced variables in $outer's, updates REPL variables) + private[spark] def clean[F <: AnyRef](f: F): F = { + ClosureCleaner.clean(f) + return f + } +} + +object SparkContext { + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def add(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double) = 0.0 + } + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def add(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int) = 0 + } + // TODO: Add AccumulatorParams for other types, e.g. lists and strings +} diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala new file mode 100644 index 0000000000..7257bf7b0c --- /dev/null +++ b/src/scala/spark/SparkException.scala @@ -0,0 +1,7 @@ +package spark + +class SparkException(message: String) extends Exception(message) { + def this(message: String, errorCode: Int) { + this("%s (error code: %d)".format(message, errorCode)) + } +} diff --git a/src/scala/spark/Task.scala b/src/scala/spark/Task.scala new file mode 100644 index 0000000000..e559996a37 --- /dev/null +++ b/src/scala/spark/Task.scala @@ -0,0 +1,16 @@ +package spark + +import nexus._ + +@serializable +trait Task[T] { + def run: T + def prefers(slot: SlaveOffer): Boolean = true + def markStarted(slot: SlaveOffer) {} +} + +@serializable +class FunctionTask[T](body: () => T) +extends Task[T] { + def run: T = body() +} diff --git a/src/scala/spark/TaskResult.scala b/src/scala/spark/TaskResult.scala new file mode 100644 index 0000000000..db33c9ff44 --- /dev/null +++ b/src/scala/spark/TaskResult.scala @@ -0,0 +1,9 @@ +package spark + +import scala.collection.mutable.Map + +// Task result. Also contains updates to accumulator variables. 
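// value is the task's return value; accumUpdates is keyed by accumulator ID
// and holds the updates made while the task ran, which the scheduler merges
// back into the caller's accumulators via Accumulators.add.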
+// TODO: Use of distributed cache to return result is a hack to get around +// what seems to be a bug with messages over 60KB in libprocess; fix it +@serializable +private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala new file mode 100644 index 0000000000..52bcb89f00 --- /dev/null +++ b/src/scala/spark/Utils.scala @@ -0,0 +1,28 @@ +package spark + +import java.io._ + +private object Utils { + def serialize[T](o: T): Array[Byte] = { + val bos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(bos) + oos.writeObject(o) + oos.close + return bos.toByteArray + } + + def deserialize[T](bytes: Array[Byte]): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) + return ois.readObject.asInstanceOf[T] + } + + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + return ois.readObject.asInstanceOf[T] + } +} diff --git a/src/scala/spark/repl/ExecutorClassLoader.scala b/src/scala/spark/repl/ExecutorClassLoader.scala new file mode 100644 index 0000000000..7d91b20e79 --- /dev/null +++ b/src/scala/spark/repl/ExecutorClassLoader.scala @@ -0,0 +1,86 @@ +package spark.repl + +import java.io.{ByteArrayOutputStream, InputStream} +import java.net.{URI, URL, URLClassLoader} +import java.util.concurrent.{Executors, ExecutorService} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.objectweb.asm._ +import org.objectweb.asm.commons.EmptyVisitor +import org.objectweb.asm.Opcodes._ + + +// A ClassLoader that reads classes from a Hadoop FileSystem URL, used to load +// classes defined by the interpreter when the REPL is in use +class ExecutorClassLoader(classDir: String, parent: ClassLoader) +extends ClassLoader(parent) { + val fileSystem = FileSystem.get(new URI(classDir), new Configuration()) + val directory = new URI(classDir).getPath + + override def findClass(name: String): Class[_] = { + try { + //println("repl.ExecutorClassLoader resolving " + name) + val path = new Path(directory, name.replace('.', '/') + ".class") + val bytes = readAndTransformClass(name, fileSystem.open(path)) + return defineClass(name, bytes, 0, bytes.length) + } catch { + case e: Exception => throw new ClassNotFoundException(name, e) + } + } + + def readAndTransformClass(name: String, in: InputStream): Array[Byte] = { + if (name.startsWith("line") && name.endsWith("$iw$")) { + // Class seems to be an interpreter "wrapper" object storing a val or var. + // Replace its constructor with a dummy one that does not run the + // initialization code placed there by the REPL. The val or var will + // be initialized later through reflection when it is used in a task. 
+ val cr = new ClassReader(in) + val cw = new ClassWriter( + ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS) + val cleaner = new ConstructorCleaner(name, cw) + cr.accept(cleaner, 0) + return cw.toByteArray + } else { + // Pass the class through unmodified + val bos = new ByteArrayOutputStream + val bytes = new Array[Byte](4096) + var done = false + while (!done) { + val num = in.read(bytes) + if (num >= 0) + bos.write(bytes, 0, num) + else + done = true + } + return bos.toByteArray + } + } +} + +class ConstructorCleaner(className: String, cv: ClassVisitor) +extends ClassAdapter(cv) { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + val mv = cv.visitMethod(access, name, desc, sig, exceptions) + if (name == "<init>" && (access & ACC_STATIC) == 0) { + // This is the constructor, time to clean it; just output some new + // instructions to mv that create the object and set the static MODULE$ + // field in the class to point to it, but do nothing otherwise. + //println("Cleaning constructor of " + className) + mv.visitCode() + mv.visitVarInsn(ALOAD, 0) // load this + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V") + mv.visitVarInsn(ALOAD, 0) // load this + //val classType = className.replace('.', '/') + //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") + mv.visitInsn(RETURN) + mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed + mv.visitEnd() + return null + } else { + return mv + } + } +} diff --git a/src/scala/spark/repl/Main.scala b/src/scala/spark/repl/Main.scala new file mode 100644 index 0000000000..f00df5aa58 --- /dev/null +++ b/src/scala/spark/repl/Main.scala @@ -0,0 +1,16 @@ +package spark.repl + +import scala.collection.mutable.Set + +object Main { + private var _interp: SparkInterpreterLoop = null + + def interp = _interp + + private[repl] def interp_=(i: SparkInterpreterLoop) { _interp = i } + + def main(args: Array[String]) { + _interp = new SparkInterpreterLoop + _interp.main(args) + } +} diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala new file mode 100644 index 0000000000..2377f0c7d6 --- /dev/null +++ b/src/scala/spark/repl/SparkInterpreter.scala @@ -0,0 +1,1004 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2009 LAMP/EPFL + * @author Martin Odersky + */ +// $Id: Interpreter.scala 17013 2009-02-02 11:59:53Z washburn $ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ + +import java.io.{File, IOException, PrintWriter, StringWriter, Writer} +import java.lang.{Class, ClassLoader} +import java.net.{MalformedURLException, URL, URLClassLoader} +import java.util.UUID + +import scala.collection.immutable.ListSet +import scala.collection.mutable +import scala.collection.mutable.{ListBuffer, HashSet, ArrayBuffer} + +//import ast.parser.SyntaxAnalyzer +import io.{PlainFile, VirtualDirectory} +import reporters.{ConsoleReporter, Reporter} +import symtab.Flags +import util.{SourceFile,BatchSourceFile,ClassPath,NameTransformer} +import nsc.{InterpreterResults=>IR} +import scala.tools.nsc.interpreter._ + +/** <p> + * An interpreter for Scala code. + * </p> + * <p> + * The main public entry points are <code>compile()</code>, + * <code>interpret()</code>, and <code>bind()</code>. + * The <code>compile()</code> method loads a + * complete Scala file. The <code>interpret()</code> method executes one + * line of Scala code at the request of the user. 
The <code>bind()</code> + * method binds an object to a variable that can then be used by later + * interpreted code. + * </p> + * <p> + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * </p> + * <p> + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports a single member named "result". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * </p> + * <p> + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * </p> + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + */ +class SparkInterpreter(val settings: Settings, out: PrintWriter) { + import symtab.Names + + /* If the interpreter is running on pre-jvm-1.5 JVM, + it is necessary to force the target setting to jvm-1.4 */ + private val major = System.getProperty("java.class.version").split("\\.")(0) + if (major.toInt < 49) { + this.settings.target.value = "jvm-1.4" + } + + /** directory to save .class files to */ + //val virtualDirectory = new VirtualDirectory("(memory)", None) + val virtualDirectory = { + val tmpDir = new File(System.getProperty("java.io.tmpdir")) + var attempts = 0 + val maxAttempts = 10 + var outputDir: File = null + while (outputDir == null) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException("Failed to create a temp directory " + + "after " + maxAttempts + " attempts!") + } + try { + outputDir = new File(tmpDir, "spark-" + UUID.randomUUID.toString) + if (outputDir.exists() || !outputDir.mkdirs()) + outputDir = null + } catch { case e: IOException => ; } + } + System.setProperty("spark.repl.classdir", + "file://" + outputDir.getAbsolutePath + "/") + //println("Output dir: " + outputDir) + new PlainFile(outputDir) + } + + /** the compiler to compile expressions with */ + val compiler: scala.tools.nsc.Global = newCompiler(settings, reporter) + + import compiler.Traverser + import compiler.{Tree, TermTree, + ValOrDefDef, ValDef, DefDef, Assign, + ClassDef, ModuleDef, Ident, Select, TypeDef, + Import, MemberDef, DocDef} + import compiler.CompilationUnit + import compiler.{Symbol,Name,Type} + import compiler.nme + import compiler.newTermName + import compiler.newTypeName + import compiler.nme.{INTERPRETER_VAR_PREFIX, INTERPRETER_SYNTHVAR_PREFIX} + import Interpreter.string2code + + /** construct an interpreter that reports to Console */ + def this(settings: Settings) = + this(settings, + new NewLinePrintWriter(new ConsoleWriter, true)) + + /** whether to print out result lines */ + private var printResults: Boolean = true + + /** Be quiet. Do not print out the results of each + * submitted command unless an exception is thrown. 
*/ + def beQuiet = { printResults = false } + + /** Temporarily be quiet */ + def beQuietDuring[T](operation: => T): T = { + val wasPrinting = printResults + try { + printResults = false + operation + } finally { + printResults = wasPrinting + } + } + + /** interpreter settings */ + lazy val isettings = new InterpreterSettings + + object reporter extends ConsoleReporter(settings, null, out) { + //override def printMessage(msg: String) { out.println(clean(msg)) } + override def printMessage(msg: String) { out.print(clean(msg) + "\n"); out.flush() } + } + + /** Instantiate a compiler. Subclasses can override this to + * change the compiler class used by this interpreter. */ + protected def newCompiler(settings: Settings, reporter: Reporter) = { + val comp = new scala.tools.nsc.Global(settings, reporter) + comp.genJVM.outputDir = virtualDirectory + comp + } + + + /** the compiler's classpath, as URL's */ + val compilerClasspath: List[URL] = { + val classpathPart = + (ClassPath.expandPath(compiler.settings.classpath.value). + map(s => new File(s).toURL)) + def parseURL(s: String): Option[URL] = + try { Some(new URL(s)) } + catch { case _:MalformedURLException => None } + val codebasePart = (compiler.settings.Xcodebase.value.split(" ")).toList.flatMap(parseURL) + classpathPart ::: codebasePart + } + + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. 
+ */ + /** class loader used to load compiled code */ + private val classLoader = { + val parent = + if (parentClassLoader == null) + new URLClassLoader(compilerClasspath.toArray) + else + new URLClassLoader(compilerClasspath.toArray, + parentClassLoader) + val virtualDirUrl = new URL("file://" + virtualDirectory.path + "/") + new URLClassLoader(Array(virtualDirUrl), parent) + //new InterpreterClassLoader(Array(virtualDirUrl), parent) + //new AbstractFileClassLoader(virtualDirectory, parent) + } + + /** Set the current Java "context" class loader to this + * interpreter's class loader */ + def setContextClassLoader() { + Thread.currentThread.setContextClassLoader(classLoader) + } + + protected def parentClassLoader: ClassLoader = this.getClass.getClassLoader() + + /** the previous requests this interpreter has processed */ + private val prevRequests = new ArrayBuffer[Request]() + + /** next line number to use */ + private var nextLineNo = 0 + + /** allocate a fresh line name */ + private def newLineName = { + val num = nextLineNo + nextLineNo += 1 + compiler.nme.INTERPRETER_LINE_PREFIX + num + } + + /** next result variable number to use */ + private var nextVarNameNo = 0 + + /** allocate a fresh variable name */ + private def newVarName() = { + val num = nextVarNameNo + nextVarNameNo += 1 + INTERPRETER_VAR_PREFIX + num + } + + /** next internal variable number to use */ + private var nextInternalVarNo = 0 + + /** allocate a fresh internal variable name */ + private def newInternalVarName() = { + val num = nextVarNameNo + nextVarNameNo += 1 + INTERPRETER_SYNTHVAR_PREFIX + num + } + + + /** Check if a name looks like it was generated by newVarName */ + private def isGeneratedVarName(name: String): Boolean = + name.startsWith(INTERPRETER_VAR_PREFIX) && { + val suffix = name.drop(INTERPRETER_VAR_PREFIX.length) + suffix.forall(_.isDigit) + } + + + /** generate a string using a routine that wants to write on a stream */ + private def stringFrom(writer: PrintWriter => Unit): String = { + val stringWriter = new StringWriter() + val stream = new NewLinePrintWriter(stringWriter) + writer(stream) + stream.close + stringWriter.toString + } + + /** Truncate a string if it is longer than settings.maxPrintString */ + private def truncPrintString(str: String): String = { + val maxpr = isettings.maxPrintString + + if (maxpr <= 0) + return str + + if (str.length <= maxpr) + return str + + val trailer = "..." + if (maxpr >= trailer.length+1) + return str.substring(0, maxpr-3) + trailer + + str.substring(0, maxpr) + } + + /** Clean up a string for output */ + private def clean(str: String) = + truncPrintString(Interpreter.stripWrapperGunk(str)) + + /** Indent some code by the width of the scala> prompt. + * This way, compiler error messages read beettr. + */ + def indentCode(code: String) = { + val spaces = " " + + stringFrom(str => + for (line <- code.lines) { + str.print(spaces) + str.print(line + "\n") + str.flush() + }) + } + + implicit def name2string(name: Name) = name.toString + + /** Compute imports that allow definitions from previous + * requests to be visible in a new request. Returns + * three pieces of related code: + * + * 1. An initial code fragment that should go before + * the code of the new request. + * + * 2. A code fragment that should go after the code + * of the new request. + * + * 3. An access path which can be traverested to access + * any bindings inside code wrapped by #1 and #2 . + * + * The argument is a set of Names that need to be imported. 
+ * + * Limitations: This method is not as precise as it could be. + * (1) It does not process wildcard imports to see what exactly + * they import. + * (2) If it imports any names from a request, it imports all + * of them, which is not really necessary. + * (3) It imports multiple same-named implicits, but only the + * last one imported is actually usable. + */ + private def importsCode(wanted: Set[Name]): (String, String, String) = { + /** Narrow down the list of requests from which imports + * should be taken. Removes requests which cannot contribute + * useful imports for the specified set of wanted names. + */ + def reqsToUse: List[(Request,MemberHandler)] = { + /** Loop through a list of MemberHandlers and select + * which ones to keep. 'wanted' is the set of + * names that need to be imported, and + * 'shadowed' is the list of names useless to import + * because a later request will re-import it anyway. + */ + def select(reqs: List[(Request,MemberHandler)], wanted: Set[Name]): + List[(Request,MemberHandler)] = { + reqs match { + case Nil => Nil + + case (req,handler)::rest => + val keepit = + (handler.definesImplicit || + handler.importsWildcard || + handler.importedNames.exists(wanted.contains(_)) || + handler.boundNames.exists(wanted.contains(_))) + + val newWanted = + if (keepit) { + (wanted + ++ handler.usedNames + -- handler.boundNames + -- handler.importedNames) + } else { + wanted + } + + val restToKeep = select(rest, newWanted) + + if(keepit) + (req,handler) :: restToKeep + else + restToKeep + } + } + + val rhpairs = for { + req <- prevRequests.toList.reverse + handler <- req.handlers + } yield (req, handler) + + select(rhpairs, wanted).reverse + } + + val code = new StringBuffer + val trailingLines = new ArrayBuffer[String] + val accessPath = new StringBuffer + val impname = compiler.nme.INTERPRETER_IMPORT_WRAPPER + val currentImps = mutable.Set.empty[Name] + + // add code for a new object to hold some imports + /*def addWrapper() { + code.append("object " + impname + "{\n") + trailingLines.append("}\n") + accessPath.append("." + impname) + currentImps.clear + }*/ + def addWrapper() { + code.append("@serializable class " + impname + "C {\n") + trailingLines.append("}\nval " + impname + " = new " + impname + "C;\n") + accessPath.append("." + impname) + currentImps.clear + } + + addWrapper() + + // loop through previous requests, adding imports + // for each one + for ((req,handler) <- reqsToUse) { + // If the user entered an import, then just use it + + // add an import wrapping level if the import might + // conflict with some other import + if(handler.importsWildcard || + currentImps.exists(handler.importedNames.contains)) + if(!currentImps.isEmpty) + addWrapper() + + if (handler.member.isInstanceOf[Import]) + code.append(handler.member.toString + ";\n") + + // give wildcard imports a import wrapper all to their own + if(handler.importsWildcard) + addWrapper() + else + currentImps ++= handler.importedNames + + // For other requests, import each bound variable. + // import them explicitly instead of with _, so that + // ambiguity errors will not be generated. Also, quote + // the name of the variable, so that we don't need to + // handle quoting keywords separately. 
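      // For example, if an earlier request bound a value `x`, the lines
      // appended below look roughly like this (the exact names depend on the
      // generated line and wrapper names):
      //   val line1$object$VAL = line1$object.INSTANCE;
      //   import line1$object$VAL.`x`;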
+ for (imv <- handler.boundNames) { + if (currentImps.contains(imv)) + addWrapper() + code.append("val " + req.objectName + "$VAL = " + req.objectName + ".INSTANCE;\n") + code.append("import ") + code.append(req.objectName + "$VAL" + req.accessPath + ".`" + imv + "`;\n") + // The code below is less likely to pull in bad variables, but prevents use of vars & classes + //code.append("val `" + imv + "` = " + req.objectName + ".INSTANCE" + req.accessPath + ".`" + imv + "`;\n") + currentImps += imv + } + } + + addWrapper() // Add one extra wrapper, to prevent warnings + // in the frequent case of redefining + // the value bound in the last interpreter + // request. + + (code.toString, trailingLines.reverse.mkString, accessPath.toString) + } + + /** Parse a line into a sequence of trees. Returns None if the input + * is incomplete. */ + private def parse(line: String): Option[List[Tree]] = { + var justNeedsMore = false + reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) { + // simple parse: just parse it, nothing else + def simpleParse(code: String): List[Tree] = { + reporter.reset + val unit = + new CompilationUnit( + new BatchSourceFile("<console>", code.toCharArray())) + val scanner = new compiler.syntaxAnalyzer.UnitParser(unit); + val xxx = scanner.templateStatSeq(false); + (xxx._2) + } + val (trees) = simpleParse(line) + if (reporter.hasErrors) { + Some(Nil) // the result did not parse, so stop + } else if (justNeedsMore) { + None + } else { + Some(trees) + } + } + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false othrewise. + */ + def compileSources(sources: List[SourceFile]): Boolean = { + val cr = new compiler.Run + reporter.reset + cr.compileSources(sources) + !reporter.hasErrors + } + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(List(new BatchSourceFile("<script>", code.toCharArray))) + + /** Build a request from the user. <code>trees</code> is <code>line</code> + * after being parsed. + */ + private def buildRequest(trees: List[Tree], line: String, lineName: String): Request = + new Request(line, lineName) + + private def chooseHandler(member: Tree): Option[MemberHandler] = + member match { + case member: DefDef => + Some(new DefHandler(member)) + case member: ValDef => + Some(new ValHandler(member)) + case member@Assign(Ident(_), _) => Some(new AssignHandler(member)) + case member: ModuleDef => Some(new ModuleHandler(member)) + case member: ClassDef => Some(new ClassHandler(member)) + case member: TypeDef => Some(new TypeAliasHandler(member)) + case member: Import => Some(new ImportHandler(member)) + case DocDef(_, documented) => chooseHandler(documented) + case member => Some(new GenericHandler(member)) + } + + /** <p> + * Interpret one line of input. All feedback, including parse errors + * and evaluation results, are printed via the supplied compiler's + * reporter. Values defined are available for future interpreted + * strings. + * </p> + * <p> + * The return value is whether the line was interpreter successfully, + * e.g. that there were no parse errors. + * </p> + * + * @param line ... + * @return ... 
+ */ + def interpret(line: String): IR.Result = { + if (prevRequests.isEmpty) + new compiler.Run // initialize the compiler + + // parse + val trees = parse(indentCode(line)) match { + case None => return IR.Incomplete + case (Some(Nil)) => return IR.Error // parse error or empty input + case Some(trees) => trees + } + + trees match { + case List(_:Assign) => () + + case List(_:TermTree) | List(_:Ident) | List(_:Select) => + // Treat a single bare expression specially. + // This is necessary due to it being hard to modify + // code at a textual level, and it being hard to + // submit an AST to the compiler. + return interpret("val "+newVarName()+" = \n"+line) + + case _ => () + } + + val lineName = newLineName + + // figure out what kind of request + val req = buildRequest(trees, line, lineName) + if (req eq null) return IR.Error // a disallowed statement type + + if (!req.compile) + return IR.Error // an error happened during compilation, e.g. a type error + + val (interpreterResultString, succeeded) = req.loadAndRun + + if (printResults || !succeeded) { + // print the result + out.print(clean(interpreterResultString)) + } + + // book-keeping + if (succeeded) + prevRequests += req + + if (succeeded) IR.Success else IR.Error + } + + /** A counter used for numbering objects created by <code>bind()</code>. */ + private var binderNum = 0 + + /** Bind a specified name to a specified value. The name may + * later be used by expressions passed to interpret. + * + * @param name the variable name to bind + * @param boundType the type of the variable, as a string + * @param value the object value to bind to it + * @return an indication of whether the binding succeeded + */ + def bind(name: String, boundType: String, value: Any): IR.Result = { + val binderName = "binder" + binderNum + binderNum += 1 + + compileString( + "object " + binderName + + "{ var value: " + boundType + " = _; " + + " def set(x: Any) = value=x.asInstanceOf[" + boundType + "]; }") + + val binderObject = + Class.forName(binderName, true, classLoader) + val setterMethod = + (binderObject + .getDeclaredMethods + .toList + .find(meth => meth.getName == "set") + .get) + var argsHolder: Array[Any] = null // this roundabout approach is to try and + // make sure the value is boxed + argsHolder = List(value).toArray + setterMethod.invoke(null, argsHolder.asInstanceOf[Array[AnyRef]]: _*) + + interpret("val " + name + " = " + binderName + ".value") + } + + + /** <p> + * This instance is no longer needed, so release any resources + * it is using. The reporter's output gets flushed. + * </p> + */ + def close() { + reporter.flush() + } + + /** A traverser that finds all mentioned identifiers, i.e. things + * that need to be imported. + * It might return extra names. + */ + private class ImportVarsTraverser(definedVars: List[Name]) extends Traverser { + val importVars = new HashSet[Name]() + + override def traverse(ast: Tree) { + ast match { + case Ident(name) => importVars += name + case _ => super.traverse(ast) + } + } + } + + + /** Class to handle one member among all the members included + * in a single interpreter request. 
+ */ + private sealed abstract class MemberHandler(val member: Tree) { + val usedNames: List[Name] = { + val ivt = new ImportVarsTraverser(boundNames) + ivt.traverseTrees(List(member)) + ivt.importVars.toList + } + def boundNames: List[Name] = Nil + def valAndVarNames: List[Name] = Nil + def defNames: List[Name] = Nil + val importsWildcard = false + val importedNames: Seq[Name] = Nil + val definesImplicit = + member match { + case tree:MemberDef => + tree.mods.hasFlag(symtab.Flags.IMPLICIT) + case _ => false + } + + def extraCodeToEvaluate(req: Request, code: PrintWriter) { } + def resultExtractionCode(req: Request, code: PrintWriter) { } + } + + private class GenericHandler(member: Tree) extends MemberHandler(member) + + private class ValHandler(member: ValDef) extends MemberHandler(member) { + override lazy val boundNames = List(member.name) + override def valAndVarNames = boundNames + + override def resultExtractionCode(req: Request, code: PrintWriter) { + val vname = member.name + if (member.mods.isPublic && + !(isGeneratedVarName(vname) && + req.typeOf(compiler.encode(vname)) == "Unit")) + { + val prettyName = NameTransformer.decode(vname) + code.print(" + \"" + prettyName + ": " + + string2code(req.typeOf(vname)) + + " = \" + " + + " { val tmp = scala.runtime.ScalaRunTime.stringOf(" + + req.fullPath(vname) + + "); " + + " (if(tmp.toSeq.contains('\\n')) \"\\n\" else \"\") + tmp + \"\\n\"} ") + } + } + } + + private class DefHandler(defDef: DefDef) extends MemberHandler(defDef) { + override lazy val boundNames = List(defDef.name) + override def defNames = boundNames + + override def resultExtractionCode(req: Request, code: PrintWriter) { + if (defDef.mods.isPublic) + code.print("+\""+string2code(defDef.name)+": "+ + string2code(req.typeOf(defDef.name))+"\\n\"") + } + } + + private class AssignHandler(member: Assign) extends MemberHandler(member) { + val lhs = member. 
lhs.asInstanceOf[Ident] // an unfortunate limitation + + val helperName = newTermName(newInternalVarName()) + override val valAndVarNames = List(helperName) + + override def extraCodeToEvaluate(req: Request, code: PrintWriter) { + code.println("val "+helperName+" = "+member.lhs+";") + } + + /** Print out lhs instead of the generated varName */ + override def resultExtractionCode(req: Request, code: PrintWriter) { + code.print(" + \"" + lhs + ": " + + string2code(req.typeOf(compiler.encode(helperName))) + + " = \" + " + + string2code(req.fullPath(helperName)) + + " + \"\\n\"") + } + } + + private class ModuleHandler(module: ModuleDef) extends MemberHandler(module) { + override lazy val boundNames = List(module.name) + + override def resultExtractionCode(req: Request, code: PrintWriter) { + code.println(" + \"defined module " + + string2code(module.name) + + "\\n\"") + } + } + + private class ClassHandler(classdef: ClassDef) + extends MemberHandler(classdef) + { + override lazy val boundNames = + List(classdef.name) ::: + (if (classdef.mods.hasFlag(Flags.CASE)) + List(classdef.name.toTermName) + else + Nil) + + // TODO: MemberDef.keyword does not include "trait"; + // otherwise it could be used here + def keyword: String = + if (classdef.mods.isTrait) "trait" else "class" + + override def resultExtractionCode(req: Request, code: PrintWriter) { + code.print( + " + \"defined " + + keyword + + " " + + string2code(classdef.name) + + "\\n\"") + } + } + + private class TypeAliasHandler(typeDef: TypeDef) + extends MemberHandler(typeDef) + { + override lazy val boundNames = + if (typeDef.mods.isPublic && compiler.treeInfo.isAliasTypeDef(typeDef)) + List(typeDef.name) + else + Nil + + override def resultExtractionCode(req: Request, code: PrintWriter) { + code.println(" + \"defined type alias " + + string2code(typeDef.name) + "\\n\"") + } + } + + private class ImportHandler(imp: Import) extends MemberHandler(imp) { + override def resultExtractionCode(req: Request, code: PrintWriter) { + code.println("+ \"" + imp.toString + "\\n\"") + } + + /** Whether this import includes a wildcard import */ + override val importsWildcard = + imp.selectors.map(_._1).contains(nme.USCOREkw) + + /** The individual names imported by this statement */ + override val importedNames: Seq[Name] = + for { + val (_,sel) <- imp.selectors + sel != null + sel != nme.USCOREkw + val name <- List(sel.toTypeName, sel.toTermName) + } + yield name + } + + /** One line of code submitted by the user for interpretation */ + private class Request(val line: String, val lineName: String) { + val trees = parse(line) match { + case Some(ts) => ts + case None => Nil + } + + /** name to use for the object that will compute "line" */ + def objectName = lineName + compiler.nme.INTERPRETER_WRAPPER_SUFFIX + + /** name of the object that retrieves the result from the above object */ + def resultObjectName = "RequestResult$" + objectName + + val handlers: List[MemberHandler] = trees.flatMap(chooseHandler(_)) + + /** all (public) names defined by these statements */ + val boundNames = (ListSet() ++ handlers.flatMap(_.boundNames)).toList + + /** list of names used by this expression */ + val usedNames: List[Name] = handlers.flatMap(_.usedNames) + + def myImportsCode = importsCode(Set.empty ++ usedNames) + + /** Code to append to objectName to access anything that + * the request binds. 
*/ + val accessPath = myImportsCode._3 + + + /** Code to access a variable with the specified name */ + def fullPath(vname: String): String = + objectName + ".INSTANCE" + accessPath + ".`" + vname + "`" + + /** Code to access a variable with the specified name */ + def fullPath(vname: Name): String = fullPath(vname.toString) + + /** the line of code to compute */ + def toCompute = line + + /** generate the source code for the object that computes this request */ + def objectSourceCode: String = { + val src = stringFrom { code => + // header for the wrapper object + code.println("@serializable class " + objectName + " {") + + val (importsPreamble, importsTrailer, _) = myImportsCode + + code.print(importsPreamble) + + code.println(indentCode(toCompute)) + + handlers.foreach(_.extraCodeToEvaluate(this,code)) + + code.println(importsTrailer) + + //end the wrapper object + code.println(";}") + + //create an object + code.println("object " + objectName + " {") + code.println(" val INSTANCE = new " + objectName + "();") + code.println("}") + } + //println(src) + src + } + + /** Types of variables defined by this request. They are computed + after compilation of the main object */ + var typeOf: Map[Name, String] = _ + + /** generate source code for the object that retrieves the result + from objectSourceCode */ + def resultObjectSourceCode: String = + stringFrom(code => { + code.println("object " + resultObjectName) + code.println("{ val result: String = {") + code.println(objectName + ".INSTANCE" + accessPath + ";") // evaluate the object, to make sure its constructor is run + code.print("(\"\"") // print an initial empty string, so later code can + // uniformly be: + morestuff + handlers.foreach(_.resultExtractionCode(this, code)) + code.println("\n)}") + code.println(";}") + }) + + + /** Compile the object file. Returns whether the compilation succeeded. + * If all goes well, the "types" map is computed. */ + def compile(): Boolean = { + reporter.reset // without this, error counting is not correct, + // and the interpreter sometimes overlooks compile failures! + + // compile the main object + val objRun = new compiler.Run() + //println("source: "+objectSourceCode) //DEBUG + objRun.compileSources( + List(new BatchSourceFile("<console>", objectSourceCode.toCharArray)) + ) + if (reporter.hasErrors) return false + + + // extract and remember types + typeOf = findTypes(objRun) + + // compile the result-extraction object + new compiler.Run().compileSources( + List(new BatchSourceFile("<console>", resultObjectSourceCode.toCharArray)) + ) + + // success + !reporter.hasErrors + } + + /** Dig the types of all bound variables out of the compiler run. + * + * @param objRun ... + * @return ... + */ + def findTypes(objRun: compiler.Run): Map[Name, String] = { + def valAndVarNames = handlers.flatMap(_.valAndVarNames) + def defNames = handlers.flatMap(_.defNames) + + def getTypes(names: List[Name], nameMap: Name=>Name): Map[Name, String] = { + /** the outermost wrapper object */ + val outerResObjSym: Symbol = + compiler.definitions.getMember(compiler.definitions.EmptyPackage, + newTermName(objectName).toTypeName) // MATEI: added toTypeName + + /** the innermost object inside the wrapper, found by + * following accessPath into the outer one. 
*/ + val resObjSym = + (accessPath.split("\\.")).foldLeft(outerResObjSym)((sym,name) => + if(name == "") sym else + compiler.atPhase(objRun.typerPhase.next) { + sym.info.member(newTermName(name)) }) + + names.foldLeft(Map.empty[Name,String])((map, name) => { + val rawType = + compiler.atPhase(objRun.typerPhase.next) { + resObjSym.info.member(name).tpe + } + + // the types are all =>T; remove the => + val cleanedType= rawType match { + case compiler.PolyType(Nil, rt) => rt + case rawType => rawType + } + + map + (name -> compiler.atPhase(objRun.typerPhase.next) { cleanedType.toString }) + }) + } + + val names1 = getTypes(valAndVarNames, n => compiler.nme.getterToLocal(n)) + val names2 = getTypes(defNames, identity) + names1 ++ names2 + } + + /** load and run the code using reflection */ + def loadAndRun: (String, Boolean) = { + val interpreterResultObject: Class[_] = + Class.forName(resultObjectName, true, classLoader) + val resultValMethod: java.lang.reflect.Method = + interpreterResultObject.getMethod("result") + try { + (resultValMethod.invoke(interpreterResultObject).toString(), + true) + } catch { + case e => + def caus(e: Throwable): Throwable = + if (e.getCause eq null) e else caus(e.getCause) + val orig = caus(e) + (stringFrom(str => orig.printStackTrace(str)), false) + } + } + } +} + +/** Utility methods for the Interpreter. */ +object Interpreter { + /** Delete a directory tree recursively. Use with care! + */ + def deleteRecursively(path: File) { + path match { + case _ if !path.exists => + () + case _ if path.isDirectory => + for (p <- path.listFiles) + deleteRecursively(p) + path.delete + case _ => + path.delete + } + } + + /** Heuristically strip interpreter wrapper prefixes + * from an interpreter output string. + */ + def stripWrapperGunk(str: String): String = { + //val wrapregex = "(line[0-9]+\\$object[$.])?(\\$iw[$.])*" + //str.replaceAll(wrapregex, "") + str + } + + /** Convert a string into code that can recreate the string. + * This requires replacing all special characters by escape + * codes. It does not add the surrounding " marks. */ + def string2code(str: String): String = { + /** Convert a character to a backslash-u escape */ + def char2uescape(c: Char): String = { + var rest = c.toInt + val buf = new StringBuilder + for (i <- 1 to 4) { + buf ++= (rest % 16).toHexString + rest = rest / 16 + } + "\\" + "u" + buf.toString.reverse + } + + + val res = new StringBuilder + for (c <- str) { + if ("'\"\\" contains c) { + res += '\\' + res += c + } else if (!c.isControl) { + res += c + } else { + res ++= char2uescape(c) + } + } + res.toString + } +} diff --git a/src/scala/spark/repl/SparkInterpreterLoop.scala b/src/scala/spark/repl/SparkInterpreterLoop.scala new file mode 100644 index 0000000000..4aab60fd11 --- /dev/null +++ b/src/scala/spark/repl/SparkInterpreterLoop.scala @@ -0,0 +1,366 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2009 LAMP/EPFL + * @author Alexander Spoon + */ +// $Id: InterpreterLoop.scala 16881 2009-01-09 16:28:11Z cunei $ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ + +import java.io.{BufferedReader, File, FileReader, PrintWriter} +import java.io.IOException +import java.lang.{ClassLoader, System} + +import scala.tools.nsc.{InterpreterResults => IR} +import scala.tools.nsc.interpreter._ + +import spark.SparkContext + +/** The + * <a href="http://scala-lang.org/" target="_top">Scala</a> + * interactive shell. It provides a read-eval-print loop around + * the Interpreter class. 
+ * After instantiation, clients should call the <code>main()</code> method. + * + * <p>If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, + master: Option[String]) +{ + def this(in0: BufferedReader, out: PrintWriter, master: String) = + this(Some(in0), out, Some(master)) + + def this(in0: BufferedReader, out: PrintWriter) = + this(Some(in0), out, None) + + def this() = this(None, new PrintWriter(Console.out), None) + + /** The input stream from which interpreter commands come */ + var in: InteractiveReader = _ //set by main() + + /** The context class loader at the time this object was created */ + protected val originalClassLoader = + Thread.currentThread.getContextClassLoader + + var settings: Settings = _ // set by main() + var interpreter: SparkInterpreter = null // set by createInterpreter() + def isettings = interpreter.isettings + + /** A reverse list of commands to replay if the user + * requests a :replay */ + var replayCommandsRev: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandsRev.reverse + + /** Record a command for replay should the user requset a :replay */ + def addReplay(cmd: String) = + replayCommandsRev = cmd :: replayCommandsRev + + /** Close the interpreter, if there is one, and set + * interpreter to <code>null</code>. */ + def closeInterpreter() { + if (interpreter ne null) { + interpreter.close + interpreter = null + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } + + /** Create a new interpreter. Close the old one, if there + * is one. */ + def createInterpreter() { + //closeInterpreter() + + interpreter = new SparkInterpreter(settings, out) { + override protected def parentClassLoader = + classOf[SparkInterpreterLoop].getClassLoader + } + interpreter.setContextClassLoader() + } + + /** Bind the settings so that evaluated code can modiy them */ + def bindSettings() { + interpreter.beQuietDuring { + interpreter.compileString(InterpreterSettings.sourceCodeForClass) + + interpreter.bind( + "settings", + "scala.tools.nsc.InterpreterSettings", + isettings) + } + } + + + /** print a friendly help message */ + def printHelp { + //printWelcome + out.println("This is Scala " + Properties.versionString + " (" + + System.getProperty("java.vm.name") + ", Java " + System.getProperty("java.version") + ")." ) + out.println("Type in expressions to have them evaluated.") + out.println("Type :load followed by a filename to load a Scala file.") + //out.println("Type :replay to reset execution and replay all previous commands.") + out.println("Type :quit to exit the interpreter.") + } + + /** Print a welcome message */ + def printWelcome() { + out.println("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version 0.0 + /_/ +""") + + out.println("Using Scala " + Properties.versionString + " (" + + System.getProperty("java.vm.name") + ", Java " + + System.getProperty("java.version") + ")." 
) + out.flush() + } + + def createSparkContext(): SparkContext = { + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + new SparkContext(master, "Spark REPL") + } + + /** Prompt to print when awaiting input */ + val prompt = Properties.shellPromptString + + /** The main read-eval-print loop for the interpreter. It calls + * <code>command()</code> for each line of input, and stops when + * <code>command()</code> returns <code>false</code>. + */ + def repl() { + out.println("Intializing...") + out.flush() + interpreter.beQuietDuring { + command(""" + spark.repl.Main.interp.out.println("Registering with Nexus..."); + @transient val sc = spark.repl.Main.interp.createSparkContext(); + sc.waitForRegister(); + spark.repl.Main.interp.out.println("Spark context available as sc.") + """) + command("import spark.SparkContext._"); + } + out.println("Type in expressions to have them evaluated.") + out.println("Type :help for more information.") + out.flush() + + var first = true + while (true) { + out.flush() + + val line = + if (first) { + /* For some reason, the first interpreted command always takes + * a second or two. So, wait until the welcome message + * has been printed before calling bindSettings. That way, + * the user can read the welcome message while this + * command executes. + */ + val futLine = scala.concurrent.ops.future(in.readLine(prompt)) + + bindSettings() + first = false + + futLine() + } else { + in.readLine(prompt) + } + + if (line eq null) + return () // assumes null means EOF + + val (keepGoing, finalLineMaybe) = command(line) + + if (!keepGoing) + return + + finalLineMaybe match { + case Some(finalLine) => addReplay(finalLine) + case None => () + } + } + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(filename: String) { + val fileIn = try { + new FileReader(filename) + } catch { + case _:IOException => + out.println("Error opening file: " + filename) + return + } + val oldIn = in + val oldReplay = replayCommandsRev + try { + val inFile = new BufferedReader(fileIn) + in = new SimpleReader(inFile, out, false) + out.println("Loading " + filename + "...") + out.flush + repl + } finally { + in = oldIn + replayCommandsRev = oldReplay + fileIn.close + } + } + + /** create a new interpreter and replay all commands so far */ + def replay() { + closeInterpreter() + createInterpreter() + for (cmd <- replayCommands) { + out.println("Replaying: " + cmd) + out.flush() // because maybe cmd will have its own output + command(cmd) + out.println + } + } + + /** Run one command submitted by the user. Three values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. */ + def command(line: String): (Boolean, Option[String]) = { + def withFile(command: String)(action: String => Unit) { + val spaceIdx = command.indexOf(' ') + if (spaceIdx <= 0) { + out.println("That command requires a filename to be specified.") + return () + } + val filename = command.substring(spaceIdx).trim + if (! new File(filename).exists) { + out.println("That file does not exist") + return () + } + action(filename) + } + + val helpRegexp = ":h(e(l(p)?)?)?" + val quitRegexp = ":q(u(i(t)?)?)?" 
+ val loadRegexp = ":l(o(a(d)?)?)?.*" + //val replayRegexp = ":r(e(p(l(a(y)?)?)?)?)?.*" + + var shouldReplay: Option[String] = None + + if (line.matches(helpRegexp)) + printHelp + else if (line.matches(quitRegexp)) + return (false, None) + else if (line.matches(loadRegexp)) { + withFile(line)(f => { + interpretAllFrom(f) + shouldReplay = Some(line) + }) + } + //else if (line matches replayRegexp) + // replay + else if (line startsWith ":") + out.println("Unknown command. Type :help for help.") + else + shouldReplay = interpretStartingWith(line) + + (true, shouldReplay) + } + + /** Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + interpreter.interpret(code) match { + case IR.Success => Some(code) + case IR.Error => None + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + out.println("You typed two blank lines. Starting a new command.") + None + } else { + val nextLine = in.readLine(" | ") + if (nextLine == null) + None // end of file + else + interpretStartingWith(code + "\n" + nextLine) + } + } + } + + def loadFiles(settings: Settings) { + settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + replayCommandsRev = cmd :: replayCommandsRev + out.println() + } + case _ => + } + } + + def main(settings: Settings) { + this.settings = settings + + in = + in0 match { + case Some(in0) => + new SimpleReader(in0, out, true) + + case None => + val emacsShell = System.getProperty("env.emacs", "") != "" + //println("emacsShell="+emacsShell) //debug + if (settings.Xnojline.value || emacsShell) + new SimpleReader() + else + InteractiveReader.createDefault() + } + + createInterpreter() + + loadFiles(settings) + + try { + if (interpreter.reporter.hasErrors) { + return // it is broken on startup; go ahead and exit + } + printWelcome() + repl() + } finally { + closeInterpreter() + } + } + + /** process command-line arguments and do as they request */ + def main(args: Array[String]) { + def error1(msg: String) { out.println("scala: " + msg) } + val command = new InterpreterCommand(List.fromArray(args), error1) + + if (!command.ok || command.settings.help.value || command.settings.Xhelp.value) { + // either the command line is wrong, or the user + // explicitly requested a help listing + if (command.settings.help.value) out.println(command.usageMsg) + if (command.settings.Xhelp.value) out.println(command.xusageMsg) + out.flush + } + else + main(command.settings) + } +} diff --git a/src/scala/ubiquifs/Header.scala b/src/scala/ubiquifs/Header.scala new file mode 100644 index 0000000000..bdca83a2d5 --- /dev/null +++ b/src/scala/ubiquifs/Header.scala @@ -0,0 +1,21 @@ +package ubiquifs + +import java.io.{DataInputStream, DataOutputStream} + +object RequestType { + val READ = 0 + val WRITE = 1 +} + +class Header(val requestType: Int, val path: String) { + def write(out: DataOutputStream) { + out.write(requestType) + out.writeUTF(path) + } +} + +object Header { + def read(in: DataInputStream): Header = { + new Header(in.read(), in.readUTF()) + } +} diff --git a/src/scala/ubiquifs/Master.scala b/src/scala/ubiquifs/Master.scala new file mode 100644 index 0000000000..6854acd6a5 --- /dev/null +++ 
b/src/scala/ubiquifs/Master.scala @@ -0,0 +1,49 @@ +package ubiquifs + +import scala.actors.Actor +import scala.actors.Actor._ +import scala.actors.remote.RemoteActor +import scala.actors.remote.RemoteActor._ +import scala.actors.remote.Node +import scala.collection.mutable.{ArrayBuffer, Map, Set} + +class Master(port: Int) extends Actor { + case class SlaveInfo(host: String, port: Int) + + val files = Set[String]() + val slaves = new ArrayBuffer[SlaveInfo]() + + def act() { + alive(port) + register('UbiquiFS, self) + println("Created UbiquiFS Master on port " + port) + + loop { + react { + case RegisterSlave(host, port) => + slaves += SlaveInfo(host, port) + sender ! RegisterSucceeded() + + case Create(path) => + if (files.contains(path)) { + sender ! CreateFailed("File already exists") + } else if (slaves.isEmpty) { + sender ! CreateFailed("No slaves registered") + } else { + files += path + sender ! CreateSucceeded(slaves(0).host, slaves(0).port) + } + + case m: Any => + println("Unknown message: " + m) + } + } + } +} + +object MasterMain { + def main(args: Array[String]) { + val port = args(0).toInt + new Master(port).start() + } +} diff --git a/src/scala/ubiquifs/Message.scala b/src/scala/ubiquifs/Message.scala new file mode 100644 index 0000000000..153542f8de --- /dev/null +++ b/src/scala/ubiquifs/Message.scala @@ -0,0 +1,14 @@ +package ubiquifs + +sealed case class Message() + +case class RegisterSlave(host: String, port: Int) extends Message +case class RegisterSucceeded() extends Message + +case class Create(path: String) extends Message +case class CreateSucceeded(host: String, port: Int) extends Message +case class CreateFailed(message: String) extends Message + +case class Read(path: String) extends Message +case class ReadSucceeded(host: String, port: Int) extends Message +case class ReadFailed(message: String) extends Message diff --git a/src/scala/ubiquifs/Slave.scala b/src/scala/ubiquifs/Slave.scala new file mode 100644 index 0000000000..328b73c828 --- /dev/null +++ b/src/scala/ubiquifs/Slave.scala @@ -0,0 +1,141 @@ +package ubiquifs + +import java.io.{DataInputStream, DataOutputStream, IOException} +import java.net.{InetAddress, Socket, ServerSocket} +import java.util.concurrent.locks.ReentrantLock + +import scala.actors.Actor +import scala.actors.Actor._ +import scala.actors.remote.RemoteActor +import scala.actors.remote.RemoteActor._ +import scala.actors.remote.Node +import scala.collection.mutable.{ArrayBuffer, Map, Set} + +class Slave(myPort: Int, master: String) extends Thread("UbiquiFS slave") { + val CHUNK_SIZE = 1024 * 1024 + + val buffers = Map[String, Buffer]() + + override def run() { + // Create server socket + val socket = new ServerSocket(myPort) + + // Register with master + val (masterHost, masterPort) = Utils.parseHostPort(master) + val masterActor = select(Node(masterHost, masterPort), 'UbiquiFS) + val myHost = InetAddress.getLocalHost.getHostName + val reply = masterActor !? 
RegisterSlave(myHost, myPort) + println("Registered with master, reply = " + reply) + + while (true) { + val conn = socket.accept() + new ConnectionHandler(conn).start() + } + } + + class ConnectionHandler(conn: Socket) extends Thread("ConnectionHandler") { + try { + val in = new DataInputStream(conn.getInputStream) + val out = new DataOutputStream(conn.getOutputStream) + val header = Header.read(in) + header.requestType match { + case RequestType.READ => + performRead(header.path, out) + case RequestType.WRITE => + performWrite(header.path, in) + case other => + throw new IOException("Invalid header type " + other) + } + println("hi") + } catch { + case e: Exception => e.printStackTrace() + } finally { + conn.close() + } + } + + def performWrite(path: String, in: DataInputStream) { + var buffer = new Buffer() + synchronized { + if (buffers.contains(path)) + throw new IllegalArgumentException("Path " + path + " already exists") + buffers(path) = buffer + } + var chunk = new Array[Byte](CHUNK_SIZE) + var pos = 0 + while (true) { + var numRead = in.read(chunk, pos, chunk.size - pos) + if (numRead == -1) { + buffer.addChunk(chunk.subArray(0, pos), true) + return + } else { + pos += numRead + if (pos == chunk.size) { + buffer.addChunk(chunk, false) + chunk = new Array[Byte](CHUNK_SIZE) + pos = 0 + } + } + } + // TODO: launch a thread to write the data to disk, and when this finishes, + // remove the hard reference to buffer + } + + def performRead(path: String, out: DataOutputStream) { + var buffer: Buffer = null + synchronized { + if (!buffers.contains(path)) + throw new IllegalArgumentException("Path " + path + " doesn't exist") + buffer = buffers(path) + } + for (chunk <- buffer.iterator) { + out.write(chunk, 0, chunk.size) + } + } + + class Buffer { + val chunks = new ArrayBuffer[Array[Byte]] + var finished = false + val mutex = new ReentrantLock + val chunksAvailable = mutex.newCondition() + + def addChunk(chunk: Array[Byte], finish: Boolean) { + mutex.lock() + chunks += chunk + finished = finish + chunksAvailable.signalAll() + mutex.unlock() + } + + def iterator = new Iterator[Array[Byte]] { + var index = 0 + + def hasNext: Boolean = { + mutex.lock() + while (index >= chunks.size && !finished) + chunksAvailable.await() + val ret = (index < chunks.size) + mutex.unlock() + return ret + } + + def next: Array[Byte] = { + mutex.lock() + if (!hasNext) + throw new NoSuchElementException("End of file") + val ret = chunks(index) // hasNext ensures we advance past index + index += 1 + mutex.unlock() + return ret + } + } + } +} + +object SlaveMain { + def main(args: Array[String]) { + val port = args(0).toInt + val master = args(1) + new Slave(port, master).start() + } +} diff --git a/src/scala/ubiquifs/UbiquiFS.scala b/src/scala/ubiquifs/UbiquiFS.scala new file mode 100644 index 0000000000..9ce0fd4f44 --- /dev/null +++ b/src/scala/ubiquifs/UbiquiFS.scala @@ -0,0 +1,11 @@ +package ubiquifs + +import java.io.{InputStream, OutputStream} + +class UbiquiFS(master: String) { + private val (masterHost, masterPort) = Utils.parseHostPort(master) + + def create(path: String): OutputStream = null + + def open(path: String): InputStream = null +} diff --git a/src/scala/ubiquifs/Utils.scala b/src/scala/ubiquifs/Utils.scala new file mode 100644 index 0000000000..d6fd3f0181 --- /dev/null +++ b/src/scala/ubiquifs/Utils.scala @@ -0,0 +1,12 @@ +package ubiquifs + +private[ubiquifs] object Utils { + private val HOST_PORT_RE = "([a-zA-Z0-9.-]+):([0-9]+)".r + + def parseHostPort(string: String): (String, Int) = { + 
string match { + case HOST_PORT_RE(host, port) => (host, port.toInt) + case _ => throw new IllegalArgumentException(string) + } + } +} diff --git a/src/test/spark/ParallelArraySplitSuite.scala b/src/test/spark/ParallelArraySplitSuite.scala new file mode 100644 index 0000000000..a1787bd6dd --- /dev/null +++ b/src/test/spark/ParallelArraySplitSuite.scala @@ -0,0 +1,161 @@ +package spark + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +class ParallelArraySplitSuite extends FunSuite with Checkers { + test("one element per slice") { + val data = Array(1, 2, 3) + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1") + assert(slices(1).mkString(",") === "2") + assert(slices(2).mkString(",") === "3") + } + + test("one slice") { + val data = Array(1, 2, 3) + val slices = ParallelArray.slice(data, 1) + assert(slices.size === 1) + assert(slices(0).mkString(",") === "1,2,3") + } + + test("equal slices") { + val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1,2,3") + assert(slices(1).mkString(",") === "4,5,6") + assert(slices(2).mkString(",") === "7,8,9") + } + + test("non-equal slices") { + val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1,2,3") + assert(slices(1).mkString(",") === "4,5,6") + assert(slices(2).mkString(",") === "7,8,9,10") + } + + test("splitting exclusive range") { + val data = 0 until 100 + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === (0 to 32).mkString(",")) + assert(slices(1).mkString(",") === (33 to 65).mkString(",")) + assert(slices(2).mkString(",") === (66 to 99).mkString(",")) + } + + test("splitting inclusive range") { + val data = 0 to 100 + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === (0 to 32).mkString(",")) + assert(slices(1).mkString(",") === (33 to 66).mkString(",")) + assert(slices(2).mkString(",") === (67 to 100).mkString(",")) + } + + test("empty data") { + val data = new Array[Int](0) + val slices = ParallelArray.slice(data, 5) + assert(slices.size === 5) + for (slice <- slices) assert(slice.size === 0) + } + + test("zero slices") { + val data = Array(1, 2, 3) + intercept[IllegalArgumentException] { ParallelArray.slice(data, 0) } + } + + test("negative number of slices") { + val data = Array(1, 2, 3) + intercept[IllegalArgumentException] { ParallelArray.slice(data, -5) } + } + + test("exclusive ranges sliced into ranges") { + val data = 1 until 100 + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[SerializableRange])) + } + + test("inclusive ranges sliced into ranges") { + val data = 1 to 100 + val slices = ParallelArray.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[SerializableRange])) + } + + test("large ranges don't overflow") { + val N = 100 * 1000 * 1000 + val data = 0 until N + val slices = ParallelArray.slice(data, 40) + assert(slices.size === 40) + for (i <- 0 until 40) { + assert(slices(i).isInstanceOf[SerializableRange]) + val range = 
slices(i).asInstanceOf[SerializableRange] + assert(range.start === i * (N / 40), "slice " + i + " start") + assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.step === 1, "slice " + i + " step") + } + } + + test("random array tests") { + val gen = for { + d <- arbitrary[List[Int]] + n <- Gen.choose(1, 100) + } yield (d, n) + val prop = forAll(gen) { + (tuple: (List[Int], Int)) => + val d = tuple._1 + val n = tuple._2 + val slices = ParallelArray.slice(d, n) + ("n slices" |: slices.size == n) && + ("concat to d" |: Array.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("random exclusive range tests") { + val gen = for { + a <- Gen.choose(-100, 100) + b <- Gen.choose(-100, 100) + step <- Gen.choose(-5, 5) suchThat (_ != 0) + n <- Gen.choose(1, 100) + } yield (a until b by step, n) + val prop = forAll(gen) { + case (d: Range, n: Int) => + val slices = ParallelArray.slice(d, n) + ("n slices" |: slices.size == n) && + ("all ranges" |: slices.forall(_.isInstanceOf[SerializableRange])) && + ("concat to d" |: Array.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("random inclusive range tests") { + val gen = for { + a <- Gen.choose(-100, 100) + b <- Gen.choose(-100, 100) + step <- Gen.choose(-5, 5) suchThat (_ != 0) + n <- Gen.choose(1, 100) + } yield (a to b by step, n) + val prop = forAll(gen) { + case (d: Range, n: Int) => + val slices = ParallelArray.slice(d, n) + ("n slices" |: slices.size == n) && + ("all ranges" |: slices.forall(_.isInstanceOf[SerializableRange])) && + ("concat to d" |: Array.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } +} diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala new file mode 100644 index 0000000000..d71fe20a94 --- /dev/null +++ b/src/test/spark/repl/ReplSuite.scala @@ -0,0 +1,124 @@ +package spark.repl + +import java.io._ + +import org.scalatest.FunSuite + +class ReplSuite extends FunSuite { + def runInterpreter(master: String, input: String): String = { + val in = new BufferedReader(new StringReader(input + "\n")) + val out = new StringWriter() + val interp = new SparkInterpreterLoop(in, new PrintWriter(out), master) + spark.repl.Main.interp = interp + interp.main(new Array[String](0)) + spark.repl.Main.interp = null + return out.toString + } + + def assertContains(message: String, output: String) { + assert(output contains message, + "Interpreter output did not contain '" + message + "':\n" + output) + } + + def assertDoesNotContain(message: String, output: String) { + assert(!(output contains message), + "Interpreter output contained '" + message + "':\n" + output) + } + + test ("simple foreach with accumulator") { + val output = runInterpreter("local", """ + val accum = sc.accumulator(0) + sc.parallelize(1 to 10).foreach(x => accum += x) + accum.value + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res1: Int = 55", output) + } + + test ("external vars") { + val output = runInterpreter("local", """ + var v = 7 + sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + 
assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res2: Int = 100", output) + } + + test ("external classes") { + val output = runInterpreter("local", """ + class C { + def foo = 5 + } + sc.parallelize(1 to 10).map(x => (new C).foo).toArray.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 50", output) + } + + test ("external functions") { + val output = runInterpreter("local", """ + def double(x: Int) = x + x + sc.parallelize(1 to 10).map(x => double(x)).toArray.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 110", output) + } + + test ("external functions that access vars") { + val output = runInterpreter("local", """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res2: Int = 100", output) + } + + test ("cached vars") { + // Test that the value that a cached var had when it was created is used, + // even if that cached var is then modified in the driver program + val output = runInterpreter("local", """ + var array = new Array[Int](5) + val cachedArray = sc.cache(array) + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray + array(0) = 5 + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) + } + + test ("running on Nexus") { + val output = runInterpreter("localquiet", """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) + var array = new Array[Int](5) + val cachedArray = sc.cache(array) + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray + array(0) = 5 + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res2: Int = 100", output) + assertContains("res3: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res5: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } +} |