author    Matei Zaharia <matei@eecs.berkeley.edu>  2010-03-29 16:17:55 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2010-03-29 16:17:55 -0700
commit    df29d0ea4c8b7137fdd1844219c7d489e3b0d9c9 (patch)
tree      3f925c0d109b789ce845762a9e09d24329749eb8 /src
Initial commit
Diffstat (limited to 'src')
-rw-r--r-- | src/examples/CpuHog.scala | 24
-rw-r--r-- | src/examples/HdfsTest.scala | 16
-rw-r--r-- | src/examples/LocalALS.scala | 118
-rw-r--r-- | src/examples/LocalFileLR.scala | 36
-rw-r--r-- | src/examples/LocalLR.scala | 41
-rw-r--r-- | src/examples/LocalPi.scala | 14
-rw-r--r-- | src/examples/SleepJob.scala | 19
-rw-r--r-- | src/examples/SparkALS.scala | 138
-rw-r--r-- | src/examples/SparkHdfsLR.scala | 50
-rw-r--r-- | src/examples/SparkLR.scala | 48
-rw-r--r-- | src/examples/SparkPi.scala | 20
-rw-r--r-- | src/examples/Vector.scala | 63
-rw-r--r-- | src/java/spark/compress/lzf/LZF.java | 27
-rw-r--r-- | src/java/spark/compress/lzf/LZFInputStream.java | 180
-rw-r--r-- | src/java/spark/compress/lzf/LZFOutputStream.java | 85
-rw-r--r-- | src/native/Makefile | 30
-rw-r--r-- | src/native/spark_compress_lzf_LZF.c | 90
-rw-r--r-- | src/scala/spark/Accumulators.scala | 71
-rw-r--r-- | src/scala/spark/Cached.scala | 110
-rw-r--r-- | src/scala/spark/ClosureCleaner.scala | 157
-rw-r--r-- | src/scala/spark/Executor.scala | 70
-rw-r--r-- | src/scala/spark/HdfsFile.scala | 277
-rw-r--r-- | src/scala/spark/LocalScheduler.scala | 65
-rw-r--r-- | src/scala/spark/NexusScheduler.scala | 258
-rw-r--r-- | src/scala/spark/ParallelArray.scala | 97
-rw-r--r-- | src/scala/spark/Scheduler.scala | 9
-rw-r--r-- | src/scala/spark/SerializableRange.scala | 75
-rw-r--r-- | src/scala/spark/SparkContext.scala | 89
-rw-r--r-- | src/scala/spark/SparkException.scala | 7
-rw-r--r-- | src/scala/spark/Task.scala | 16
-rw-r--r-- | src/scala/spark/TaskResult.scala | 9
-rw-r--r-- | src/scala/spark/Utils.scala | 28
-rw-r--r-- | src/scala/spark/repl/ExecutorClassLoader.scala | 86
-rw-r--r-- | src/scala/spark/repl/Main.scala | 16
-rw-r--r-- | src/scala/spark/repl/SparkInterpreter.scala | 1004
-rw-r--r-- | src/scala/spark/repl/SparkInterpreterLoop.scala | 366
-rw-r--r-- | src/scala/ubiquifs/Header.scala | 21
-rw-r--r-- | src/scala/ubiquifs/Master.scala | 49
-rw-r--r-- | src/scala/ubiquifs/Message.scala | 14
-rw-r--r-- | src/scala/ubiquifs/Slave.scala | 141
-rw-r--r-- | src/scala/ubiquifs/UbiquiFS.scala | 11
-rw-r--r-- | src/scala/ubiquifs/Utils.scala | 12
-rw-r--r-- | src/test/spark/ParallelArraySplitSuite.scala | 161
-rw-r--r-- | src/test/spark/repl/ReplSuite.scala | 124
44 files changed, 4342 insertions, 0 deletions
diff --git a/src/examples/CpuHog.scala b/src/examples/CpuHog.scala
new file mode 100644
index 0000000000..f37c6f7824
--- /dev/null
+++ b/src/examples/CpuHog.scala
@@ -0,0 +1,24 @@
+import spark._
+
+object CpuHog {
+ def main(args: Array[String]) {
+ if (args.length != 3) {
+ System.err.println("Usage: CpuHog <master> <tasks> <threads_per_task>");
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "CPU hog")
+ val tasks = args(1).toInt
+ val threads = args(2).toInt
+ def task {
+ for (i <- 0 until threads-1) {
+ new Thread() {
+ override def run {
+ while(true) {}
+ }
+ }.start()
+ }
+ while(true) {}
+ }
+ sc.runTasks(Array.make(tasks, () => task))
+ }
+}
diff --git a/src/examples/HdfsTest.scala b/src/examples/HdfsTest.scala
new file mode 100644
index 0000000000..e678154aab
--- /dev/null
+++ b/src/examples/HdfsTest.scala
@@ -0,0 +1,16 @@
+import spark._
+
+object HdfsTest {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "HdfsTest")
+ val file = sc.textFile(args(1))
+ val mapped = file.map(s => s.length).cache()
+ for (iter <- 1 to 10) {
+ val start = System.currentTimeMillis()
+ for (x <- mapped) { x + 2 }
+ // println("Processing: " + x)
+ val end = System.currentTimeMillis()
+ println("Iteration " + iter + " took " + (end-start) + " ms")
+ }
+ }
+}
diff --git a/src/examples/LocalALS.scala b/src/examples/LocalALS.scala
new file mode 100644
index 0000000000..17d67b522b
--- /dev/null
+++ b/src/examples/LocalALS.scala
@@ -0,0 +1,118 @@
+import java.util.Random
+import cern.jet.math._
+import cern.colt.matrix._
+import cern.colt.matrix.linalg._
+
+object LocalALS {
+ // Parameters set through command line arguments
+ var M = 0 // Number of movies
+ var U = 0 // Number of users
+ var F = 0 // Number of features
+ var ITERATIONS = 0
+
+ val LAMBDA = 0.01 // Regularization coefficient
+
+ // Some COLT objects
+ val factory2D = DoubleFactory2D.dense
+ val factory1D = DoubleFactory1D.dense
+ val algebra = Algebra.DEFAULT
+ val blas = SeqBlas.seqBlas
+
+ def generateR(): DoubleMatrix2D = {
+ val mh = factory2D.random(M, F)
+ val uh = factory2D.random(U, F)
+ return algebra.mult(mh, algebra.transpose(uh))
+ }
+
+ def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
+ us: Array[DoubleMatrix1D]): Double =
+ {
+ val r = factory2D.make(M, U)
+ for (i <- 0 until M; j <- 0 until U) {
+ r.set(i, j, blas.ddot(ms(i), us(j)))
+ }
+ //println("R: " + r)
+ blas.daxpy(-1, targetR, r)
+ val sumSqs = r.aggregate(Functions.plus, Functions.square)
+ return Math.sqrt(sumSqs / (M * U))
+ }
+
+ def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+ R: DoubleMatrix2D) : DoubleMatrix1D =
+ {
+ val XtX = factory2D.make(F, F)
+ val Xty = factory1D.make(F)
+ // For each user that rated the movie
+ for (j <- 0 until U) {
+ val u = us(j)
+ // Add u * u^t to XtX
+ blas.dger(1, u, u, XtX)
+ // Add u * rating to Xty
+ blas.daxpy(R.get(i, j), u, Xty)
+ }
+ // Add regularization coefs to diagonal terms
+ for (d <- 0 until F) {
+ XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ }
+ // Solve it with Cholesky
+ val ch = new CholeskyDecomposition(XtX)
+ val Xty2D = factory2D.make(Xty.toArray, F)
+ val solved2D = ch.solve(Xty2D)
+ return solved2D.viewColumn(0)
+ }
+
+ def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
+ R: DoubleMatrix2D) : DoubleMatrix1D =
+ {
+ val XtX = factory2D.make(F, F)
+ val Xty = factory1D.make(F)
+ // For each movie that the user rated
+ for (i <- 0 until M) {
+ val m = ms(i)
+ // Add m * m^t to XtX
+ blas.dger(1, m, m, XtX)
+ // Add m * rating to Xty
+ blas.daxpy(R.get(i, j), m, Xty)
+ }
+ // Add regularization coefs to diagonal terms
+ for (d <- 0 until F) {
+ XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
+ }
+ // Solve it with Cholesky
+ val ch = new CholeskyDecomposition(XtX)
+ val Xty2D = factory2D.make(Xty.toArray, F)
+ val solved2D = ch.solve(Xty2D)
+ return solved2D.viewColumn(0)
+ }
+
+ def main(args: Array[String]) {
+ args match {
+ case Array(m, u, f, iters) => {
+ M = m.toInt
+ U = u.toInt
+ F = f.toInt
+ ITERATIONS = iters.toInt
+ }
+ case _ => {
+ System.err.println("Usage: LocalALS <M> <U> <F> <iters>")
+ System.exit(1)
+ }
+ }
+ printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
+
+ val R = generateR()
+
+ // Initialize m and u randomly
+ var ms = Array.fromFunction(_ => factory1D.random(F))(M)
+ var us = Array.fromFunction(_ => factory1D.random(F))(U)
+
+ // Iteratively update movies then users
+ for (iter <- 1 to ITERATIONS) {
+ println("Iteration " + iter + ":")
+ ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray
+ us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray
+ println("RMSE = " + rmse(R, ms, us))
+ println()
+ }
+ }
+}
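
The update functions above are ridge-regression solves in disguise: with the user vectors held fixed, the new movie vector m_i satisfies the normal equation

    (sum_j u_j * u_j^T + LAMBDA * U * I) * m_i = sum_j R(i, j) * u_j

which is exactly the XtX / Xty system the code assembles and then factors with Cholesky. updateUser is the same computation with the roles of movies and users swapped, and alternating the two updates is what makes this alternating least squares.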
diff --git a/src/examples/LocalFileLR.scala b/src/examples/LocalFileLR.scala
new file mode 100644
index 0000000000..60b4aa8fc4
--- /dev/null
+++ b/src/examples/LocalFileLR.scala
@@ -0,0 +1,36 @@
+import java.util.Random
+import Vector._
+
+object LocalFileLR {
+ val D = 10 // Number of dimensions
+ val rand = new Random(42)
+
+ case class DataPoint(x: Vector, y: Double)
+
+ def parsePoint(line: String): DataPoint = {
+ val nums = line.split(' ').map(_.toDouble)
+ return DataPoint(new Vector(nums.subArray(1, D+1)), nums(0))
+ }
+
+ def main(args: Array[String]) {
+ val lines = scala.io.Source.fromFile(args(0)).getLines
+ val points = lines.map(parsePoint _)
+ val ITERATIONS = args(1).toInt
+
+ // Initialize w to a random value
+ var w = Vector(D, _ => 2 * rand.nextDouble - 1)
+ println("Initial w: " + w)
+
+ for (i <- 1 to ITERATIONS) {
+ println("On iteration " + i)
+ var gradient = Vector.zeros(D)
+ for (p <- points) {
+ val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
+ gradient += scale * p.x
+ }
+ w -= gradient
+ }
+
+ println("Final w: " + w)
+ }
+}
diff --git a/src/examples/LocalLR.scala b/src/examples/LocalLR.scala
new file mode 100644
index 0000000000..175907e551
--- /dev/null
+++ b/src/examples/LocalLR.scala
@@ -0,0 +1,41 @@
+import java.util.Random
+import Vector._
+
+object LocalLR {
+ val N = 10000 // Number of data points
+ val D = 10 // Number of dimensions
+ val R = 0.7 // Scaling factor
+ val ITERATIONS = 5
+ val rand = new Random(42)
+
+ case class DataPoint(x: Vector, y: Double)
+
+ def generateData = {
+ def generatePoint(i: Int) = {
+ val y = if(i % 2 == 0) -1 else 1
+ val x = Vector(D, _ => rand.nextGaussian + y * R)
+ DataPoint(x, y)
+ }
+ Array.fromFunction(generatePoint _)(N)
+ }
+
+ def main(args: Array[String]) {
+ val data = generateData
+
+ // Initialize w to a random value
+ var w = Vector(D, _ => 2 * rand.nextDouble - 1)
+ println("Initial w: " + w)
+
+ for (i <- 1 to ITERATIONS) {
+ println("On iteration " + i)
+ var gradient = Vector.zeros(D)
+ for (p <- data) {
+ val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
+ gradient += scale * p.x
+ }
+ w -= gradient
+ }
+
+ println("Final w: " + w)
+ }
+}
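
Both local LR examples (and the Spark variants further down) run plain batch gradient descent on the logistic loss. For a point (x, y) with y in {-1, +1}, the gradient of log(1 + exp(-y * (w dot x))) with respect to w is

    (1 / (1 + exp(-y * (w dot x))) - 1) * y * x

which is the scale * p.x term accumulated in the inner loop; w -= gradient then takes a step of fixed size 1, fine for a demo but normally a tunable learning rate.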
diff --git a/src/examples/LocalPi.scala b/src/examples/LocalPi.scala
new file mode 100644
index 0000000000..c83aeed40b
--- /dev/null
+++ b/src/examples/LocalPi.scala
@@ -0,0 +1,14 @@
+import spark._
+import SparkContext._
+
+object LocalPi {
+ def main(args: Array[String]) {
+ var count = 0
+ for (i <- 1 to 100000) {
+ val x = Math.random * 2 - 1
+ val y = Math.random * 2 - 1
+ if (x*x + y*y < 1) count += 1
+ }
+ println("Pi is roughly " + 4 * count / 100000.0)
+ }
+}
\ No newline at end of file
diff --git a/src/examples/SleepJob.scala b/src/examples/SleepJob.scala
new file mode 100644
index 0000000000..a5e0ea0dc2
--- /dev/null
+++ b/src/examples/SleepJob.scala
@@ -0,0 +1,19 @@
+import spark._
+
+object SleepJob {
+ def main(args: Array[String]) {
+ if (args.length != 3) {
+ System.err.println("Usage: SleepJob <master> <tasks> <task_duration>");
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "Sleep job")
+ val tasks = args(1).toInt
+ val duration = args(2).toInt
+ def task {
+ val start = System.currentTimeMillis
+ while (System.currentTimeMillis - start < duration * 1000L)
+ Thread.sleep(200)
+ }
+ sc.runTasks(Array.make(tasks, () => task))
+ }
+}
diff --git a/src/examples/SparkALS.scala b/src/examples/SparkALS.scala
new file mode 100644
index 0000000000..2fd58ed3a5
--- /dev/null
+++ b/src/examples/SparkALS.scala
@@ -0,0 +1,138 @@
+import java.io.Serializable
+import java.util.Random
+import cern.jet.math._
+import cern.colt.matrix._
+import cern.colt.matrix.linalg._
+import spark._
+
+object SparkALS {
+ // Parameters set through command line arguments
+ var M = 0 // Number of movies
+ var U = 0 // Number of users
+ var F = 0 // Number of features
+ var ITERATIONS = 0
+
+ val LAMBDA = 0.01 // Regularization coefficient
+
+ // Some COLT objects
+ val factory2D = DoubleFactory2D.dense
+ val factory1D = DoubleFactory1D.dense
+ val algebra = Algebra.DEFAULT
+ val blas = SeqBlas.seqBlas
+
+ def generateR(): DoubleMatrix2D = {
+ val mh = factory2D.random(M, F)
+ val uh = factory2D.random(U, F)
+ return algebra.mult(mh, algebra.transpose(uh))
+ }
+
+ def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
+ us: Array[DoubleMatrix1D]): Double =
+ {
+ val r = factory2D.make(M, U)
+ for (i <- 0 until M; j <- 0 until U) {
+ r.set(i, j, blas.ddot(ms(i), us(j)))
+ }
+ //println("R: " + r)
+ blas.daxpy(-1, targetR, r)
+ val sumSqs = r.aggregate(Functions.plus, Functions.square)
+ return Math.sqrt(sumSqs / (M * U))
+ }
+
+ def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+ R: DoubleMatrix2D) : DoubleMatrix1D =
+ {
+ val U = us.size
+ val F = us(0).size
+ val XtX = factory2D.make(F, F)
+ val Xty = factory1D.make(F)
+ // For each user that rated the movie
+ for (j <- 0 until U) {
+ val u = us(j)
+ // Add u * u^t to XtX
+ blas.dger(1, u, u, XtX)
+ // Add u * rating to Xty
+ blas.daxpy(R.get(i, j), u, Xty)
+ }
+ // Add regularization coefs to diagonal terms
+ for (d <- 0 until F) {
+ XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ }
+ // Solve it with Cholesky
+ val ch = new CholeskyDecomposition(XtX)
+ val Xty2D = factory2D.make(Xty.toArray, F)
+ val solved2D = ch.solve(Xty2D)
+ return solved2D.viewColumn(0)
+ }
+
+ def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
+ R: DoubleMatrix2D) : DoubleMatrix1D =
+ {
+ val M = ms.size
+ val F = ms(0).size
+ val XtX = factory2D.make(F, F)
+ val Xty = factory1D.make(F)
+ // For each movie that the user rated
+ for (i <- 0 until M) {
+ val m = ms(i)
+ // Add m * m^t to XtX
+ blas.dger(1, m, m, XtX)
+ // Add m * rating to Xty
+ blas.daxpy(R.get(i, j), m, Xty)
+ }
+ // Add regularization coefs to diagonal terms
+ for (d <- 0 until F) {
+ XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
+ }
+ // Solve it with Cholesky
+ val ch = new CholeskyDecomposition(XtX)
+ val Xty2D = factory2D.make(Xty.toArray, F)
+ val solved2D = ch.solve(Xty2D)
+ return solved2D.viewColumn(0)
+ }
+
+ def main(args: Array[String]) {
+ var host = ""
+ var slices = 0
+ args match {
+ case Array(m, u, f, iters, slices_, host_) => {
+ M = m.toInt
+ U = u.toInt
+ F = f.toInt
+ ITERATIONS = iters.toInt
+ slices = slices_.toInt
+ host = host_
+ }
+ case _ => {
+ System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <host>")
+ System.exit(1)
+ }
+ }
+ printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
+ val spark = new SparkContext(host, "SparkALS")
+
+ val R = generateR()
+
+ // Initialize m and u randomly
+ var ms = Array.fromFunction(_ => factory1D.random(F))(M)
+ var us = Array.fromFunction(_ => factory1D.random(F))(U)
+
+ // Iteratively update movies then users
+ val Rc = spark.broadcast(R)
+ var msb = spark.broadcast(ms)
+ var usb = spark.broadcast(us)
+ for (iter <- 1 to ITERATIONS) {
+ println("Iteration " + iter + ":")
+ ms = spark.parallelize(0 until M, slices)
+ .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value))
+ .toArray
+ msb = spark.broadcast(ms) // Re-broadcast ms because it was updated
+ us = spark.parallelize(0 until U, slices)
+ .map(i => updateUser(i, usb.value(i), msb.value, Rc.value))
+ .toArray
+ usb = spark.broadcast(us) // Re-broadcast us because it was updated
+ println("RMSE = " + rmse(R, ms, us))
+ println()
+ }
+ }
+}
diff --git a/src/examples/SparkHdfsLR.scala b/src/examples/SparkHdfsLR.scala
new file mode 100644
index 0000000000..d0400380bd
--- /dev/null
+++ b/src/examples/SparkHdfsLR.scala
@@ -0,0 +1,50 @@
+import java.util.Random
+import Vector._
+import spark._
+
+object SparkHdfsLR {
+ val D = 10 // Number of dimensions
+ val rand = new Random(42)
+
+ case class DataPoint(x: Vector, y: Double)
+
+ def parsePoint(line: String): DataPoint = {
+ //val nums = line.split(' ').map(_.toDouble)
+ //return DataPoint(new Vector(nums.subArray(1, D+1)), nums(0))
+ val tok = new java.util.StringTokenizer(line, " ")
+ var y = tok.nextToken.toDouble
+ var x = new Array[Double](D)
+ var i = 0
+ while (i < D) {
+ x(i) = tok.nextToken.toDouble; i += 1
+ }
+ return DataPoint(new Vector(x), y)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: SparkHdfsLR <master> <file> <iters>")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "SparkHdfsLR")
+ val lines = sc.textFile(args(1))
+ val points = lines.map(parsePoint _).cache()
+ val ITERATIONS = args(2).toInt
+
+ // Initialize w to a random value
+ var w = Vector(D, _ => 2 * rand.nextDouble - 1)
+ println("Initial w: " + w)
+
+ for (i <- 1 to ITERATIONS) {
+ println("On iteration " + i)
+ val gradient = sc.accumulator(Vector.zeros(D))
+ for (p <- points) {
+ val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
+ gradient += scale * p.x
+ }
+ w -= gradient.value
+ }
+
+ println("Final w: " + w)
+ }
+}
diff --git a/src/examples/SparkLR.scala b/src/examples/SparkLR.scala
new file mode 100644
index 0000000000..34574f5640
--- /dev/null
+++ b/src/examples/SparkLR.scala
@@ -0,0 +1,48 @@
+import java.util.Random
+import Vector._
+import spark._
+
+object SparkLR {
+ val N = 10000 // Number of data points
+ val D = 10 // Number of dimensions
+ val R = 0.7 // Scaling factor
+ val ITERATIONS = 5
+ val rand = new Random(42)
+
+ case class DataPoint(x: Vector, y: Double)
+
+ def generateData = {
+ def generatePoint(i: Int) = {
+ val y = if(i % 2 == 0) -1 else 1
+ val x = Vector(D, _ => rand.nextGaussian + y * R)
+ DataPoint(x, y)
+ }
+ Array.fromFunction(generatePoint _)(N)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length == 0) {
+ System.err.println("Usage: SparkLR <host> [<slices>]")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "SparkLR")
+ val numSlices = if (args.length > 1) args(1).toInt else 2
+ val data = generateData
+
+ // Initialize w to a random value
+ var w = Vector(D, _ => 2 * rand.nextDouble - 1)
+ println("Initial w: " + w)
+
+ for (i <- 1 to ITERATIONS) {
+ println("On iteration " + i)
+ val gradient = sc.accumulator(Vector.zeros(D))
+ for (p <- sc.parallelize(data, numSlices)) {
+ val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
+ gradient += scale * p.x
+ }
+ w -= gradient.value
+ }
+
+ println("Final w: " + w)
+ }
+}
diff --git a/src/examples/SparkPi.scala b/src/examples/SparkPi.scala
new file mode 100644
index 0000000000..7dbadd1088
--- /dev/null
+++ b/src/examples/SparkPi.scala
@@ -0,0 +1,20 @@
+import spark._
+import SparkContext._
+
+object SparkPi {
+ def main(args: Array[String]) {
+ if (args.length == 0) {
+ System.err.println("Usage: SparkPi <host> [<slices>]")
+ System.exit(1)
+ }
+ val spark = new SparkContext(args(0), "SparkPi")
+ val slices = if (args.length > 1) args(1).toInt else 2
+ var count = spark.accumulator(0)
+ for (i <- spark.parallelize(1 to 100000, slices)) {
+ val x = Math.random * 2 - 1
+ val y = Math.random * 2 - 1
+ if (x*x + y*y < 1) count += 1
+ }
+ println("Pi is roughly " + 4 * count.value / 100000.0)
+ }
+}
\ No newline at end of file
diff --git a/src/examples/Vector.scala b/src/examples/Vector.scala
new file mode 100644
index 0000000000..0ae2cbc6e8
--- /dev/null
+++ b/src/examples/Vector.scala
@@ -0,0 +1,63 @@
+@serializable class Vector(val elements: Array[Double]) {
+ def length = elements.length
+
+ def apply(index: Int) = elements(index)
+
+ def + (other: Vector): Vector = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ return Vector(length, i => this(i) + other(i))
+ }
+
+ def - (other: Vector): Vector = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ return Vector(length, i => this(i) - other(i))
+ }
+
+ def dot(other: Vector): Double = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ for (i <- 0 until length)
+ ans += this(i) * other(i)
+ return ans
+ }
+
+ def * ( scale: Double): Vector = Vector(length, i => this(i) * scale)
+
+ def unary_- = this * -1
+
+ def sum = elements.reduceLeft(_ + _)
+
+ override def toString = elements.mkString("(", ", ", ")")
+
+}
+
+object Vector {
+ def apply(elements: Array[Double]) = new Vector(elements)
+
+ def apply(elements: Double*) = new Vector(elements.toArray)
+
+ def apply(length: Int, initializer: Int => Double): Vector = {
+ val elements = new Array[Double](length)
+ for (i <- 0 until length)
+ elements(i) = initializer(i)
+ return new Vector(elements)
+ }
+
+ def zeros(length: Int) = new Vector(new Array[Double](length))
+
+ def ones(length: Int) = Vector(length, _ => 1)
+
+ class Multiplier(num: Double) {
+ def * (vec: Vector) = vec * num
+ }
+
+ implicit def doubleToMultiplier(num: Double) = new Multiplier(num)
+
+ implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] {
+ def add(t1: Vector, t2: Vector) = t1 + t2
+ def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
+ }
+}
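
Vector.scala gives the examples a small dense-vector algebra plus an implicit AccumulatorParam, so Vector-valued accumulators work out of the box. A minimal sketch of the API defined above (the concrete numbers are made up):

    import Vector._

    val a = Vector(1.0, 2.0, 3.0)
    val b = Vector.ones(3)
    val c = a + 2.0 * b            // doubleToMultiplier lets scalars appear on the left
    println(c dot a)               // inner product: 3*1 + 4*2 + 5*3 = 26.0
    println(-c)                    // unary minus is just c * -1
    println(Vector.zeros(3).sum)   // 0.0

The same implicit VectorAccumParam is what lets SparkHdfsLR and SparkLR call sc.accumulator(Vector.zeros(D)) and then += vectors from inside tasks.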
diff --git a/src/java/spark/compress/lzf/LZF.java b/src/java/spark/compress/lzf/LZF.java
new file mode 100644
index 0000000000..294a0494ec
--- /dev/null
+++ b/src/java/spark/compress/lzf/LZF.java
@@ -0,0 +1,27 @@
+package spark.compress.lzf;
+
+public class LZF {
+ private static boolean loaded;
+
+ static {
+ try {
+ System.loadLibrary("spark_native");
+ loaded = true;
+ } catch(Throwable t) {
+ System.out.println("Failed to load native LZF library: " + t.toString());
+ loaded = false;
+ }
+ }
+
+ public static boolean isLoaded() {
+ return loaded;
+ }
+
+ public static native int compress(
+ byte[] in, int inOff, int inLen,
+ byte[] out, int outOff, int outLen);
+
+ public static native int decompress(
+ byte[] in, int inOff, int inLen,
+ byte[] out, int outOff, int outLen);
+}
diff --git a/src/java/spark/compress/lzf/LZFInputStream.java b/src/java/spark/compress/lzf/LZFInputStream.java
new file mode 100644
index 0000000000..16bc687489
--- /dev/null
+++ b/src/java/spark/compress/lzf/LZFInputStream.java
@@ -0,0 +1,180 @@
+package spark.compress.lzf;
+
+import java.io.EOFException;
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+public class LZFInputStream extends FilterInputStream {
+ private static final int MAX_BLOCKSIZE = 1024 * 64 - 1;
+ private static final int MAX_HDR_SIZE = 7;
+
+ private byte[] inBuf; // Holds data to decompress (including header)
+ private byte[] outBuf; // Holds decompressed data to output
+ private int outPos; // Current position in outBuf
+ private int outSize; // Total amount of data in outBuf
+
+ private boolean closed;
+ private boolean reachedEof;
+
+ private byte[] singleByte = new byte[1];
+
+ public LZFInputStream(InputStream in) {
+ super(in);
+ if (in == null)
+ throw new NullPointerException();
+ inBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE];
+ outBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE];
+ outPos = 0;
+ outSize = 0;
+ }
+
+ private void ensureOpen() throws IOException {
+ if (closed) throw new IOException("Stream closed");
+ }
+
+ @Override
+ public int read() throws IOException {
+ ensureOpen();
+ int count = read(singleByte, 0, 1);
+ return (count == -1 ? -1 : singleByte[0] & 0xFF);
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ ensureOpen();
+ if ((off | len | (off + len) | (b.length - (off + len))) < 0)
+ throw new IndexOutOfBoundsException();
+
+ int totalRead = 0;
+
+ // Start with the current block in outBuf, and read and decompress any
+ // further blocks necessary. Instead of trying to decompress directly to b
+ // when b is large, we always use outBuf as an intermediate holding space
+ // in case GetPrimitiveArrayCritical decides to copy arrays instead of
+ // pinning them, which would cause b to be copied repeatedly into C-land.
+ while (len > 0) {
+ if (outPos == outSize) {
+ readNextBlock();
+ if (reachedEof)
+ return totalRead == 0 ? -1 : totalRead;
+ }
+ int amtToCopy = Math.min(outSize - outPos, len);
+ System.arraycopy(outBuf, outPos, b, off, amtToCopy);
+ off += amtToCopy;
+ len -= amtToCopy;
+ outPos += amtToCopy;
+ totalRead += amtToCopy;
+ }
+
+ return totalRead;
+ }
+
+ // Read len bytes from this.in to a buffer, stopping only if EOF is reached
+ private int readFully(byte[] b, int off, int len) throws IOException {
+ int totalRead = 0;
+ while (len > 0) {
+ int amt = in.read(b, off, len);
+ if (amt == -1)
+ break;
+ off += amt;
+ len -= amt;
+ totalRead += amt;
+ }
+ return totalRead;
+ }
+
+ // Read the next block from the underlying InputStream into outBuf,
+ // setting outPos and outSize, or set reachedEof if the stream ends.
+ private void readNextBlock() throws IOException {
+ // Read first 5 bytes of header
+ int count = readFully(inBuf, 0, 5);
+ if (count == 0) {
+ reachedEof = true;
+ return;
+ } else if (count < 5) {
+ throw new EOFException("Truncated LZF block header");
+ }
+
+ // Check magic bytes
+ if (inBuf[0] != 'Z' || inBuf[1] != 'V')
+ throw new IOException("Wrong magic bytes in LZF block header");
+
+ // Read the block
+ if (inBuf[2] == 0) {
+ // Uncompressed block - read directly to outBuf
+ int size = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF);
+ if (readFully(outBuf, 0, size) != size)
+ throw new EOFException("EOF inside LZF block");
+ outPos = 0;
+ outSize = size;
+ } else if (inBuf[2] == 1) {
+ // Compressed block - read to inBuf and decompress
+ if (readFully(inBuf, 5, 2) != 2)
+ throw new EOFException("Truncated LZF block header");
+ int csize = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF);
+ int usize = ((inBuf[5] & 0xFF) << 8) | (inBuf[6] & 0xFF);
+ if (readFully(inBuf, 7, csize) != csize)
+ throw new EOFException("Truncated LZF block");
+ if (LZF.decompress(inBuf, 7, csize, outBuf, 0, usize) != usize)
+ throw new IOException("Corrupt LZF data stream");
+ outPos = 0;
+ outSize = usize;
+ } else {
+ throw new IOException("Unknown block type in LZF block header");
+ }
+ }
+
+ /**
+ * Returns 0 after EOF has been reached; otherwise it always returns 1.
+ *
+ * Programs should not count on this method to return the actual number
+ * of bytes that could be read without blocking.
+ */
+ @Override
+ public int available() throws IOException {
+ ensureOpen();
+ return reachedEof ? 0 : 1;
+ }
+
+ // TODO: Skip complete chunks without decompressing them?
+ @Override
+ public long skip(long n) throws IOException {
+ ensureOpen();
+ if (n < 0)
+ throw new IllegalArgumentException("negative skip length");
+ byte[] buf = new byte[512];
+ long skipped = 0;
+ while (skipped < n) {
+ int len = (int) Math.min(n - skipped, buf.length);
+ len = read(buf, 0, len);
+ if (len == -1) {
+ reachedEof = true;
+ break;
+ }
+ skipped += len;
+ }
+ return skipped;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (!closed) {
+ in.close();
+ closed = true;
+ }
+ }
+
+ @Override
+ public boolean markSupported() {
+ return false;
+ }
+
+ @Override
+ public void mark(int readLimit) {}
+
+ @Override
+ public void reset() throws IOException {
+ throw new IOException("mark/reset not supported");
+ }
+}
diff --git a/src/java/spark/compress/lzf/LZFOutputStream.java b/src/java/spark/compress/lzf/LZFOutputStream.java
new file mode 100644
index 0000000000..5f65e95d2a
--- /dev/null
+++ b/src/java/spark/compress/lzf/LZFOutputStream.java
@@ -0,0 +1,85 @@
+package spark.compress.lzf;
+
+import java.io.FilterOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+public class LZFOutputStream extends FilterOutputStream {
+ private static final int BLOCKSIZE = 1024 * 64 - 1;
+ private static final int MAX_HDR_SIZE = 7;
+
+ private byte[] inBuf; // Holds input data to be compressed
+ private byte[] outBuf; // Holds compressed data to be written
+ private int inPos; // Current position in inBuf
+
+ public LZFOutputStream(OutputStream out) {
+ super(out);
+ inBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE];
+ outBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE];
+ inPos = MAX_HDR_SIZE;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ inBuf[inPos++] = (byte) b;
+ if (inPos == inBuf.length)
+ compressAndSendBlock();
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ if ((off | len | (off + len) | (b.length - (off + len))) < 0)
+ throw new IndexOutOfBoundsException();
+
+ // If we're given a large array, copy it piece by piece into inBuf and
+ // write one BLOCKSIZE at a time. This is done to prevent the JNI code
+ // from copying the whole array repeatedly if GetPrimitiveArrayCritical
+ // decides to copy instead of pinning.
+ while (inPos + len >= inBuf.length) {
+ int amtToCopy = inBuf.length - inPos;
+ System.arraycopy(b, off, inBuf, inPos, amtToCopy);
+ inPos += amtToCopy;
+ compressAndSendBlock();
+ off += amtToCopy;
+ len -= amtToCopy;
+ }
+
+ // Copy the remaining (incomplete) block into inBuf
+ System.arraycopy(b, off, inBuf, inPos, len);
+ inPos += len;
+ }
+
+ @Override
+ public void flush() throws IOException {
+ if (inPos > MAX_HDR_SIZE)
+ compressAndSendBlock();
+ out.flush();
+ }
+
+ // Send the data in inBuf, and reset inPos to start writing a new block.
+ private void compressAndSendBlock() throws IOException {
+ int us = inPos - MAX_HDR_SIZE;
+ int maxcs = us > 4 ? us - 4 : us;
+ int cs = LZF.compress(inBuf, MAX_HDR_SIZE, us, outBuf, MAX_HDR_SIZE, maxcs);
+ if (cs != 0) {
+ // Compression made the data smaller; use type 1 header
+ outBuf[0] = 'Z';
+ outBuf[1] = 'V';
+ outBuf[2] = 1;
+ outBuf[3] = (byte) (cs >> 8);
+ outBuf[4] = (byte) (cs & 0xFF);
+ outBuf[5] = (byte) (us >> 8);
+ outBuf[6] = (byte) (us & 0xFF);
+ out.write(outBuf, 0, 7 + cs);
+ } else {
+ // Compression didn't help; use type 0 header and uncompressed data
+ inBuf[2] = 'Z';
+ inBuf[3] = 'V';
+ inBuf[4] = 0;
+ inBuf[5] = (byte) (us >> 8);
+ inBuf[6] = (byte) (us & 0xFF);
+ out.write(inBuf, 2, 5 + us);
+ }
+ inPos = MAX_HDR_SIZE;
+ }
+}
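
LZFOutputStream and LZFInputStream are a matched pair: the writer chops data into blocks of at most 64 KB and emits a 'ZV' header with a type byte (1 = compressed, 0 = stored raw when compression does not help), and the reader undoes that framing. A small round-trip sketch in Scala, assuming the libspark_native library built by the Makefile below is on java.library.path (otherwise the native LZF.compress call will fail to link):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import spark.compress.lzf.{LZFInputStream, LZFOutputStream}

    val sink = new ByteArrayOutputStream()
    val out = new LZFOutputStream(sink)
    val data = "hello hello hello hello".getBytes("UTF-8")
    out.write(data, 0, data.length)
    out.flush()                                     // frames and writes the pending block

    val in = new LZFInputStream(new ByteArrayInputStream(sink.toByteArray))
    val buf = new Array[Byte](1024)
    val n = in.read(buf, 0, buf.length)
    println(new String(buf, 0, n, "UTF-8"))         // prints the original text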
diff --git a/src/native/Makefile b/src/native/Makefile
new file mode 100644
index 0000000000..331d3e6057
--- /dev/null
+++ b/src/native/Makefile
@@ -0,0 +1,30 @@
+CC = gcc
+#JAVA_HOME = /usr/lib/jvm/java-6-sun
+OS_NAME = linux
+
+CFLAGS = -fPIC -O3 -funroll-all-loops
+
+SPARK = ../..
+
+LZF = $(SPARK)/third_party/liblzf-3.5
+
+LIB = libspark_native.so
+
+all: $(LIB)
+
+spark_compress_lzf_LZF.h: $(SPARK)/classes/spark/compress/lzf/LZF.class
+ifeq ($(JAVA_HOME),)
+ $(error JAVA_HOME is not set)
+else
+ $(JAVA_HOME)/bin/javah -classpath $(SPARK)/classes spark.compress.lzf.LZF
+endif
+
+$(LIB): spark_compress_lzf_LZF.h spark_compress_lzf_LZF.c
+ $(CC) $(CFLAGS) -shared -o $@ spark_compress_lzf_LZF.c \
+ -I $(JAVA_HOME)/include -I $(JAVA_HOME)/include/$(OS_NAME) \
+ -I $(LZF) $(LZF)/lzf_c.c $(LZF)/lzf_d.c
+
+clean:
+ rm -f spark_compress_lzf_LZF.h $(LIB)
+
+.PHONY: all clean
diff --git a/src/native/spark_compress_lzf_LZF.c b/src/native/spark_compress_lzf_LZF.c
new file mode 100644
index 0000000000..c2a59def3e
--- /dev/null
+++ b/src/native/spark_compress_lzf_LZF.c
@@ -0,0 +1,90 @@
+#include "spark_compress_lzf_LZF.h"
+#include <lzf.h>
+
+
+/* Helper function to throw an exception */
+static void throwException(JNIEnv *env, const char* className) {
+ jclass cls = (*env)->FindClass(env, className);
+ if (cls != 0) /* If cls is null, an exception was already thrown */
+ (*env)->ThrowNew(env, cls, "");
+}
+
+
+/*
+ * Since LZF.compress() and LZF.decompress() have the same signatures
+ * and differ only in which lzf_ function they call, implement both in a
+ * single function and pass it a pointer to the correct lzf_ function.
+ */
+static jint callCompressionFunction
+ (unsigned int (*func)(const void *const, unsigned int, void *, unsigned int),
+ JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
+ jbyteArray outArray, jint outOff, jint outLen)
+{
+ jint inCap;
+ jint outCap;
+ jbyte *inData = 0;
+ jbyte *outData = 0;
+ jint ret = 0; /* initialize so early exits via cleanup return a defined value */
+ jint s;
+
+ if (!inArray || !outArray) {
+ throwException(env, "java/lang/NullPointerException");
+ goto cleanup;
+ }
+
+ inCap = (*env)->GetArrayLength(env, inArray);
+ outCap = (*env)->GetArrayLength(env, outArray);
+
+ // Check if any of the offset/length pairs is invalid; we do this by OR'ing
+ // things we don't want to be negative and seeing if the result is negative
+ s = inOff | inLen | (inOff + inLen) | (inCap - (inOff + inLen)) |
+ outOff | outLen | (outOff + outLen) | (outCap - (outOff + outLen));
+ if (s < 0) {
+ throwException(env, "java/lang/IndexOutOfBoundsException");
+ goto cleanup;
+ }
+
+ inData = (*env)->GetPrimitiveArrayCritical(env, inArray, 0);
+ outData = (*env)->GetPrimitiveArrayCritical(env, outArray, 0);
+
+ if (!inData || !outData) {
+ // Out of memory - JVM will throw OutOfMemoryError
+ goto cleanup;
+ }
+
+ ret = func(inData + inOff, inLen, outData + outOff, outLen);
+
+cleanup:
+ if (inData)
+ (*env)->ReleasePrimitiveArrayCritical(env, inArray, inData, 0);
+ if (outData)
+ (*env)->ReleasePrimitiveArrayCritical(env, outArray, outData, 0);
+
+ return ret;
+}
+
+/*
+ * Class: spark_compress_lzf_LZF
+ * Method: compress
+ * Signature: ([B[B)I
+ */
+JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_compress
+ (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
+ jbyteArray outArray, jint outOff, jint outLen)
+{
+ return callCompressionFunction(lzf_compress, env, cls,
+ inArray, inOff, inLen, outArray, outOff, outLen);
+}
+
+/*
+ * Class: spark_compress_lzf_LZF
+ * Method: decompress
+ * Signature: ([B[B)I
+ */
+JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_decompress
+ (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
+ jbyteArray outArray, jint outOff, jint outLen)
+{
+ return callCompressionFunction(lzf_decompress, env, cls,
+ inArray, inOff, inLen, outArray, outOff, outLen);
+}
diff --git a/src/scala/spark/Accumulators.scala b/src/scala/spark/Accumulators.scala
new file mode 100644
index 0000000000..3e4cd4935a
--- /dev/null
+++ b/src/scala/spark/Accumulators.scala
@@ -0,0 +1,71 @@
+package spark
+
+import java.io._
+
+import scala.collection.mutable.Map
+
+@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T])
+{
+ val id = Accumulators.newId
+ @transient var value_ = initialValue
+ var deserialized = false
+
+ Accumulators.register(this)
+
+ def += (term: T) { value_ = param.add(value_, term) }
+ def value = this.value_
+ def value_= (t: T) {
+ if (!deserialized) value_ = t
+ else throw new UnsupportedOperationException("Can't use value_= in task")
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject
+ value_ = param.zero(initialValue)
+ deserialized = true
+ Accumulators.register(this)
+ }
+
+ override def toString = value_.toString
+}
+
+@serializable trait AccumulatorParam[T] {
+ def add(t1: T, t2: T): T
+ def zero(initialValue: T): T
+}
+
+// TODO: The multi-thread support in accumulators is kind of lame; check
+// if there's a more intuitive way of doing it right
+private object Accumulators
+{
+ // TODO: Use soft references? => need to make readObject work properly then
+ val accums = Map[(Thread, Long), Accumulator[_]]()
+ var lastId: Long = 0
+
+ def newId: Long = synchronized { lastId += 1; return lastId }
+
+ def register(a: Accumulator[_]): Unit = synchronized {
+ accums((currentThread, a.id)) = a
+ }
+
+ def clear: Unit = synchronized {
+ accums.retain((key, accum) => key._1 != currentThread)
+ }
+
+ def values: Map[Long, Any] = synchronized {
+ val ret = Map[Long, Any]()
+ for(((thread, id), accum) <- accums if thread == currentThread)
+ ret(id) = accum.value
+ return ret
+ }
+
+ def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
+ for ((id, value) <- values) {
+ if (accums.contains((thread, id))) {
+ val accum = accums((thread, id))
+ accum.asInstanceOf[Accumulator[Any]] += value
+ }
+ }
+ }
+}
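
Accumulators are write-only from inside tasks: workers just call +=, the thread-local updates gathered by Accumulators.values travel back in the TaskResult, and the scheduler folds them in with Accumulators.add. Defining a new accumulator type only requires an AccumulatorParam; a sketch along the lines of VectorAccumParam in Vector.scala (MaxParam here is illustrative, not part of this commit):

    import spark._

    // Hypothetical parameter object: combining two values keeps the larger one
    implicit object MaxParam extends AccumulatorParam[Int] {
      def add(t1: Int, t2: Int): Int = if (t1 > t2) t1 else t2
      def zero(initialValue: Int): Int = Int.MinValue
    }

    val longestLine = new Accumulator(0, MaxParam)
    longestLine += 42
    longestLine += 7
    println(longestLine.value)   // 42; on a deserialized copy, value_= would throw instead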
diff --git a/src/scala/spark/Cached.scala b/src/scala/spark/Cached.scala
new file mode 100644
index 0000000000..8113340e1f
--- /dev/null
+++ b/src/scala/spark/Cached.scala
@@ -0,0 +1,110 @@
+package spark
+
+import java.io._
+import java.net.URI
+import java.util.UUID
+
+import com.google.common.collect.MapMaker
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+@serializable class Cached[T](@transient var value_ : T, local: Boolean) {
+ val uuid = UUID.randomUUID()
+ def value = value_
+
+ Cache.synchronized { Cache.values.put(uuid, value_) }
+
+ if (!local) writeCacheFile()
+
+ private def writeCacheFile() {
+ val out = new ObjectOutputStream(Cache.openFileForWriting(uuid))
+ out.writeObject(value_)
+ out.close()
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject
+ Cache.synchronized {
+ val cachedVal = Cache.values.get(uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ val start = System.nanoTime
+ val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid))
+ value_ = fileIn.readObject().asInstanceOf[T]
+ Cache.values.put(uuid, value_)
+ fileIn.close()
+ val time = (System.nanoTime - start) / 1e9
+ println("Reading cached variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ override def toString = "spark.Cached(" + uuid + ")"
+}
+
+private object Cache {
+ val values = new MapMaker().softValues().makeMap[UUID, Any]()
+
+ private var initialized = false
+ private var fileSystem: FileSystem = null
+ private var workDir: String = null
+ private var compress: Boolean = false
+ private var bufferSize: Int = 65536
+
+ // Will be called by SparkContext or Executor before using cache
+ def initialize() {
+ synchronized {
+ if (!initialized) {
+ bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val dfs = System.getProperty("spark.dfs", "file:///")
+ if (!dfs.startsWith("file://")) {
+ val conf = new Configuration()
+ conf.setInt("io.file.buffer.size", bufferSize)
+ val rep = System.getProperty("spark.dfs.replication", "3").toInt
+ conf.setInt("dfs.replication", rep)
+ fileSystem = FileSystem.get(new URI(dfs), conf)
+ }
+ workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+ compress = System.getProperty("spark.compress", "false").toBoolean
+ initialized = true
+ }
+ }
+ }
+
+ private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid)
+
+ def openFileForReading(uuid: UUID): InputStream = {
+ val fileStream = if (fileSystem != null) {
+ fileSystem.open(getPath(uuid))
+ } else {
+ // Local filesystem
+ new FileInputStream(getPath(uuid).toString)
+ }
+ if (compress)
+ new LZFInputStream(fileStream) // LZF stream does its own buffering
+ else if (fileSystem == null)
+ new BufferedInputStream(fileStream, bufferSize)
+ else
+ fileStream // Hadoop streams do their own buffering
+ }
+
+ def openFileForWriting(uuid: UUID): OutputStream = {
+ val fileStream = if (fileSystem != null) {
+ fileSystem.create(getPath(uuid))
+ } else {
+ // Local filesystem
+ new FileOutputStream(getPath(uuid).toString)
+ }
+ if (compress)
+ new LZFOutputStream(fileStream) // LZF stream does its own buffering
+ else if (fileSystem == null)
+ new BufferedOutputStream(fileStream, bufferSize)
+ else
+ fileStream // Hadoop streams do their own buffering
+ }
+}
diff --git a/src/scala/spark/ClosureCleaner.scala b/src/scala/spark/ClosureCleaner.scala
new file mode 100644
index 0000000000..c5663901b3
--- /dev/null
+++ b/src/scala/spark/ClosureCleaner.scala
@@ -0,0 +1,157 @@
+package spark
+
+import scala.collection.mutable.Map
+import scala.collection.mutable.Set
+
+import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
+import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.Opcodes._
+
+
+object ClosureCleaner {
+ private def getClassReader(cls: Class[_]): ClassReader = new ClassReader(
+ cls.getResourceAsStream(cls.getName.replaceFirst("^.*\\.", "") + ".class"))
+
+ private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+ for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
+ f.setAccessible(true)
+ return f.getType :: getOuterClasses(f.get(obj))
+ }
+ return Nil
+ }
+
+ private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
+ for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
+ f.setAccessible(true)
+ return f.get(obj) :: getOuterObjects(f.get(obj))
+ }
+ return Nil
+ }
+
+ private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
+ val seen = Set[Class[_]](obj.getClass)
+ var stack = List[Class[_]](obj.getClass)
+ while (!stack.isEmpty) {
+ val cr = getClassReader(stack.head)
+ stack = stack.tail
+ val set = Set[Class[_]]()
+ cr.accept(new InnerClosureFinder(set), 0)
+ for (cls <- set -- seen) {
+ seen += cls
+ stack = cls :: stack
+ }
+ }
+ return (seen - obj.getClass).toList
+ }
+
+ private def createNullValue(cls: Class[_]): AnyRef = {
+ if (cls.isPrimitive)
+ new java.lang.Byte(0: Byte) // Should be convertible to any primitive type
+ else
+ null
+ }
+
+ def clean(func: AnyRef): Unit = {
+ // TODO: cache outerClasses / innerClasses / accessedFields
+ val outerClasses = getOuterClasses(func)
+ val innerClasses = getInnerClasses(func)
+ val outerObjects = getOuterObjects(func)
+
+ val accessedFields = Map[Class[_], Set[String]]()
+ for (cls <- outerClasses)
+ accessedFields(cls) = Set[String]()
+ for (cls <- func.getClass :: innerClasses)
+ getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
+
+ var outer: AnyRef = null
+ for ((cls, obj) <- (outerClasses zip outerObjects).reverse) {
+ outer = instantiateClass(cls, outer);
+ for (fieldName <- accessedFields(cls)) {
+ val field = cls.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ val value = field.get(obj)
+ //println("1: Setting " + fieldName + " on " + cls + " to " + value);
+ field.set(outer, value)
+ }
+ }
+
+ if (outer != null) {
+ //println("2: Setting $outer on " + func.getClass + " to " + outer);
+ val field = func.getClass.getDeclaredField("$outer")
+ field.setAccessible(true)
+ field.set(func, outer)
+ }
+ }
+
+ private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = {
+ if (spark.repl.Main.interp == null) {
+ // This is a bona fide closure class, whose constructor has no effects
+ // other than to set its fields, so use its constructor
+ val cons = cls.getConstructors()(0)
+ val params = cons.getParameterTypes.map(createNullValue).toArray
+ if (outer != null)
+ params(0) = outer // First param is always outer object
+ return cons.newInstance(params: _*).asInstanceOf[AnyRef]
+ } else {
+ // Use reflection to instantiate object without calling constructor
+ val rf = sun.reflect.ReflectionFactory.getReflectionFactory();
+ val parentCtor = classOf[java.lang.Object].getDeclaredConstructor();
+ val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
+ val obj = newCtor.newInstance().asInstanceOf[AnyRef];
+ if (outer != null) {
+ //println("3: Setting $outer on " + cls + " to " + outer);
+ val field = cls.getDeclaredField("$outer")
+ field.setAccessible(true)
+ field.set(obj, outer)
+ }
+ return obj
+ }
+ }
+}
+
+
+class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ return new EmptyVisitor {
+ override def visitFieldInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ if (op == GETFIELD)
+ for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
+ output(cl) += name
+ }
+
+ override def visitMethodInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer"))
+ for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
+ output(cl) += name
+ }
+ }
+ }
+}
+
+
+class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
+ var myName: String = null
+
+ override def visit(version: Int, access: Int, name: String, sig: String,
+ superName: String, interfaces: Array[String]) {
+ myName = name
+ }
+
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ return new EmptyVisitor {
+ override def visitMethodInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ val argTypes = Type.getArgumentTypes(desc)
+ if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
+ && argTypes(0).toString.startsWith("L") // is it an object?
+ && argTypes(0).getInternalName == myName)
+ output += Class.forName(owner.replace('/', '.'), false,
+ Thread.currentThread.getContextClassLoader)
+ }
+ }
+ }
+}
diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala
new file mode 100644
index 0000000000..4cc8f00aa9
--- /dev/null
+++ b/src/scala/spark/Executor.scala
@@ -0,0 +1,70 @@
+package spark
+
+import java.util.concurrent.{Executors, ExecutorService}
+
+import nexus.{ExecutorArgs, TaskDescription, TaskState, TaskStatus}
+
+object Executor {
+ def main(args: Array[String]) {
+ System.loadLibrary("nexus")
+
+ val exec = new nexus.Executor() {
+ var classLoader: ClassLoader = null
+ var threadPool: ExecutorService = null
+
+ override def init(args: ExecutorArgs) {
+ // Read spark.* system properties
+ val props = Utils.deserialize[Array[(String, String)]](args.getData)
+ for ((key, value) <- props)
+ System.setProperty(key, value)
+
+ // Initialize cache (uses some properties read above)
+ Cache.initialize()
+
+ // If the REPL is in use, create a ClassLoader that will be able to
+ // read new classes defined by the REPL as the user types code
+ classLoader = this.getClass.getClassLoader
+ val classDir = System.getProperty("spark.repl.classdir")
+ if (classDir != null) {
+ println("Using REPL classdir: " + classDir)
+ classLoader = new repl.ExecutorClassLoader(classDir, classLoader)
+ }
+ Thread.currentThread.setContextClassLoader(classLoader)
+
+ // Start worker thread pool (they will inherit our context ClassLoader)
+ threadPool = Executors.newCachedThreadPool()
+ }
+
+ override def startTask(desc: TaskDescription) {
+ // Pull taskId and arg out of TaskDescription because it won't be a
+ // valid pointer after this method call (TODO: fix this in C++/SWIG)
+ val taskId = desc.getTaskId
+ val arg = desc.getArg
+ threadPool.execute(new Runnable() {
+ def run() = {
+ println("Running task ID " + taskId)
+ try {
+ Accumulators.clear
+ val task = Utils.deserialize[Task[Any]](arg, classLoader)
+ val value = task.run
+ val accumUpdates = Accumulators.values
+ val result = new TaskResult(value, accumUpdates)
+ sendStatusUpdate(new TaskStatus(
+ taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
+ println("Finished task ID " + taskId)
+ } catch {
+ case e: Exception => {
+ // TODO: Handle errors in tasks less dramatically
+ System.err.println("Exception in task ID " + taskId + ":")
+ e.printStackTrace
+ System.exit(1)
+ }
+ }
+ }
+ })
+ }
+ }
+
+ exec.run()
+ }
+}
diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala
new file mode 100644
index 0000000000..8050683f99
--- /dev/null
+++ b/src/scala/spark/HdfsFile.scala
@@ -0,0 +1,277 @@
+package spark
+
+import java.io._
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.ConcurrentHashMap
+import java.util.HashSet
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.Map
+
+import nexus._
+
+import com.google.common.collect.MapMaker
+
+import org.apache.hadoop.io.ObjectWritable
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputSplit
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapred.RecordReader
+import org.apache.hadoop.mapred.Reporter
+
+@serializable
+abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
+ def splits: Array[Split]
+ def iterator(split: Split): Iterator[T]
+ def prefers(split: Split, slot: SlaveOffer): Boolean
+
+ def taskStarted(split: Split, slot: SlaveOffer) {}
+
+ def sparkContext = sc
+
+ def foreach(f: T => Unit) {
+ val cleanF = sc.clean(f)
+ val tasks = splits.map(s => new ForeachTask(this, s, cleanF)).toArray
+ sc.runTaskObjects(tasks)
+ }
+
+ def toArray: Array[T] = {
+ val tasks = splits.map(s => new GetTask(this, s))
+ val results = sc.runTaskObjects(tasks)
+ Array.concat(results: _*)
+ }
+
+ def reduce(f: (T, T) => T): T = {
+ val cleanF = sc.clean(f)
+ val tasks = splits.map(s => new ReduceTask(this, s, f))
+ val results = new ArrayBuffer[T]
+ for (option <- sc.runTaskObjects(tasks); elem <- option)
+ results += elem
+ if (results.size == 0)
+ throw new UnsupportedOperationException("empty collection")
+ else
+ return results.reduceLeft(f)
+ }
+
+ def take(num: Int): Array[T] = {
+ if (num == 0)
+ return new Array[T](0)
+ val buf = new ArrayBuffer[T]
+ for (split <- splits; elem <- iterator(split)) {
+ buf += elem
+ if (buf.length == num)
+ return buf.toArray
+ }
+ return buf.toArray
+ }
+
+ def first: T = take(1) match {
+ case Array(t) => t
+ case _ => throw new UnsupportedOperationException("empty collection")
+ }
+
+ def map[U](f: T => U) = new MappedFile(this, sc.clean(f))
+ def filter(f: T => Boolean) = new FilteredFile(this, sc.clean(f))
+ def cache() = new CachedFile(this)
+
+ def count(): Long =
+ try { map(x => 1L).reduce(_+_) }
+ catch { case e: UnsupportedOperationException => 0L }
+}
+
+@serializable
+abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split],
+ val split: Split)
+extends Task[U] {
+ override def prefers(slot: SlaveOffer) = file.prefers(split, slot)
+ override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) }
+}
+
+class ForeachTask[T, Split](file: DistributedFile[T, Split],
+ split: Split, func: T => Unit)
+extends FileTask[Unit, T, Split](file, split) {
+ override def run() {
+ println("Processing " + split)
+ file.iterator(split).foreach(func)
+ }
+}
+
+class GetTask[T, Split](file: DistributedFile[T, Split], split: Split)
+extends FileTask[Array[T], T, Split](file, split) {
+ override def run(): Array[T] = {
+ println("Processing " + split)
+ file.iterator(split).collect.toArray
+ }
+}
+
+class ReduceTask[T, Split](file: DistributedFile[T, Split],
+ split: Split, f: (T, T) => T)
+extends FileTask[Option[T], T, Split](file, split) {
+ override def run(): Option[T] = {
+ println("Processing " + split)
+ val iter = file.iterator(split)
+ if (iter.hasNext)
+ Some(iter.reduceLeft(f))
+ else
+ None
+ }
+}
+
+class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U)
+extends DistributedFile[U, Split](prev.sparkContext) {
+ override def splits = prev.splits
+ override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
+ override def iterator(split: Split) = prev.iterator(split).map(f)
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
+class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean)
+extends DistributedFile[T, Split](prev.sparkContext) {
+ override def splits = prev.splits
+ override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
+ override def iterator(split: Split) = prev.iterator(split).filter(f)
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
+class CachedFile[T, Split](prev: DistributedFile[T, Split])
+extends DistributedFile[T, Split](prev.sparkContext) {
+ val id = CachedFile.newId()
+ @transient val cacheLocs = Map[Split, List[Int]]()
+
+ override def splits = prev.splits
+
+ override def prefers(split: Split, slot: SlaveOffer): Boolean = {
+ if (cacheLocs.contains(split))
+ cacheLocs(split).contains(slot.getSlaveId)
+ else
+ prev.prefers(split, slot)
+ }
+
+ override def iterator(split: Split): Iterator[T] = {
+ val key = id + "::" + split.toString
+ val cache = CachedFile.cache
+ val loading = CachedFile.loading
+ val cachedVal = cache.get(key)
+ if (cachedVal != null) {
+ // Split is in cache, so just return its values
+ return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
+ } else {
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ return Iterator.fromArray(cache.get(key).asInstanceOf[Array[T]])
+ } else {
+ loading.add(key)
+ }
+ }
+ // If we got here, we have to load the split
+ println("Loading and caching " + split)
+ val array = prev.iterator(split).collect.toArray
+ cache.put(key, array)
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ return Iterator.fromArray(array)
+ }
+ }
+
+ override def taskStarted(split: Split, slot: SlaveOffer) {
+ val oldList = cacheLocs.getOrElse(split, Nil)
+ val slaveId = slot.getSlaveId
+ if (!oldList.contains(slaveId))
+ cacheLocs(split) = slaveId :: oldList
+ }
+}
+
+private object CachedFile {
+ val nextId = new AtomicLong(0) // Generates IDs for mapped files (on master)
+ def newId() = nextId.getAndIncrement()
+
+ // Stores map results for various splits locally (on workers)
+ val cache = new MapMaker().softValues().makeMap[String, AnyRef]()
+
+ // Remembers which splits are currently being loaded (on workers)
+ val loading = new HashSet[String]
+}
+
+class HdfsSplit(@transient s: InputSplit)
+extends SerializableWritable[InputSplit](s)
+
+class HdfsTextFile(sc: SparkContext, path: String)
+extends DistributedFile[String, HdfsSplit](sc) {
+ @transient val conf = new JobConf()
+ @transient val inputFormat = new TextInputFormat()
+
+ FileInputFormat.setInputPaths(conf, path)
+ ConfigureLock.synchronized { inputFormat.configure(conf) }
+
+ @transient val splits_ =
+ inputFormat.getSplits(conf, 2).map(new HdfsSplit(_)).toArray
+
+ override def splits = splits_
+
+ override def iterator(split: HdfsSplit) = new Iterator[String] {
+ var reader: RecordReader[LongWritable, Text] = null
+ ConfigureLock.synchronized {
+ val conf = new JobConf()
+ conf.set("io.file.buffer.size",
+ System.getProperty("spark.buffer.size", "65536"))
+ val tif = new TextInputFormat()
+ tif.configure(conf)
+ reader = tif.getRecordReader(split.value, conf, Reporter.NULL)
+ }
+ val lineNum = new LongWritable()
+ val text = new Text()
+ var gotNext = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!gotNext) {
+ finished = !reader.next(lineNum, text)
+ gotNext = true
+ }
+ !finished
+ }
+
+ override def next: String = {
+ if (!gotNext)
+ finished = !reader.next(lineNum, text)
+ if (finished)
+ throw new java.util.NoSuchElementException("end of stream")
+ gotNext = false
+ text.toString
+ }
+ }
+
+ override def prefers(split: HdfsSplit, slot: SlaveOffer) =
+ split.value.getLocations().contains(slot.getHost)
+}
+
+object ConfigureLock {}
+
+@serializable
+class SerializableWritable[T <: Writable](@transient var t: T) {
+ def value = t
+ override def toString = t.toString
+
+ private def writeObject(out: ObjectOutputStream) {
+ out.defaultWriteObject()
+ new ObjectWritable(t).write(out)
+ }
+
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ val ow = new ObjectWritable()
+ ow.setConf(new JobConf())
+ ow.readFields(in)
+ t = ow.get().asInstanceOf[T]
+ }
+}
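
DistributedFile is the seed of the later RDD interface: map, filter and cache build new lazy file objects, while foreach, toArray, reduce, take, first and count turn splits into tasks through sc.runTaskObjects. A usage sketch in the style of HdfsTest above; the "local" master string and the HDFS path are placeholders, not taken from this commit:

    import spark._

    val sc = new SparkContext("local", "WordLengths")               // assumed local master
    val lines = sc.textFile("hdfs://namenode:9000/data/words.txt")  // placeholder path
    val lengths = lines.map(_.length).cache()                       // lazy, cached per split

    println("lines:   " + lengths.count())
    println("longest: " + lengths.reduce((a, b) => if (a > b) a else b))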
diff --git a/src/scala/spark/LocalScheduler.scala b/src/scala/spark/LocalScheduler.scala
new file mode 100644
index 0000000000..35bfdde09f
--- /dev/null
+++ b/src/scala/spark/LocalScheduler.scala
@@ -0,0 +1,65 @@
+package spark
+
+import java.util.concurrent._
+
+import scala.collection.mutable.Map
+
+// A simple Scheduler implementation that runs tasks locally in a thread pool.
+private class LocalScheduler(threads: Int) extends Scheduler {
+ var threadPool: ExecutorService =
+ Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+
+ override def start() {}
+
+ override def waitForRegister() {}
+
+ override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
+ val futures = new Array[Future[TaskResult[T]]](tasks.length)
+
+ for (i <- 0 until tasks.length) {
+ futures(i) = threadPool.submit(new Callable[TaskResult[T]]() {
+ def call(): TaskResult[T] = {
+ println("Running task " + i)
+ try {
+ // Serialize and deserialize the task so that accumulators are
+ // changed to thread-local ones; this adds a bit of unnecessary
+ // overhead but matches how the Nexus Executor works
+ Accumulators.clear
+ val bytes = Utils.serialize(tasks(i))
+ println("Size of task " + i + " is " + bytes.size + " bytes")
+ val task = Utils.deserialize[Task[T]](
+ bytes, currentThread.getContextClassLoader)
+ val value = task.run
+ val accumUpdates = Accumulators.values
+ println("Finished task " + i)
+ new TaskResult[T](value, accumUpdates)
+ } catch {
+ case e: Exception => {
+ // TODO: Do something nicer here
+ System.err.println("Exception in task " + i + ":")
+ e.printStackTrace()
+ System.exit(1)
+ null
+ }
+ }
+ }
+ })
+ }
+
+ val taskResults = futures.map(_.get)
+ for (result <- taskResults)
+ Accumulators.add(currentThread, result.accumUpdates)
+ return taskResults.map(_.value).toArray
+ }
+
+ override def stop() {}
+}
+
+// A ThreadFactory that creates daemon threads
+private object DaemonThreadFactory extends ThreadFactory {
+ override def newThread(r: Runnable): Thread = {
+ val t = new Thread(r);
+ t.setDaemon(true)
+ return t
+ }
+}
diff --git a/src/scala/spark/NexusScheduler.scala b/src/scala/spark/NexusScheduler.scala
new file mode 100644
index 0000000000..a96fca9350
--- /dev/null
+++ b/src/scala/spark/NexusScheduler.scala
@@ -0,0 +1,258 @@
+package spark
+
+import java.io.File
+import java.util.concurrent.Semaphore
+
+import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus}
+import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver}
+import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
+
+// The main Scheduler implementation, which talks to Nexus. Clients are expected
+// to first call start(), then submit tasks through the runTasks method.
+//
+// This implementation is currently a little quick and dirty. The following
+// improvements need to be made to it:
+// 1) Fault tolerance should be added - if a task fails, just re-run it anywhere.
+// 2) Right now, the scheduler uses a linear scan through the tasks to find a
+// local one for a given node. It would be faster to have a separate list of
+// pending tasks for each node.
+// 3) The Callbacks way of organizing things didn't work out too well, so the
+// way the scheduler keeps track of the currently active runTasks operation
+// can be made cleaner.
+private class NexusScheduler(
+ master: String, frameworkName: String, execArg: Array[Byte])
+extends nexus.Scheduler with spark.Scheduler
+{
+ // Semaphore used by runTasks to ensure only one thread can be in it
+ val semaphore = new Semaphore(1)
+
+ // Lock used to wait for scheduler to be registered
+ var isRegistered = false
+ val registeredLock = new Object()
+
+ // Trait representing a set of scheduler callbacks
+ trait Callbacks {
+ def slotOffer(s: SlaveOffer): Option[TaskDescription]
+ def taskFinished(t: TaskStatus): Unit
+ def error(code: Int, message: String): Unit
+ }
+
+ // Current callback object (may be null)
+ var callbacks: Callbacks = null
+
+ // Incrementing task ID
+ var nextTaskId = 0
+
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
+
+ // Driver for talking to Nexus
+ var driver: SchedulerDriver = null
+
+ override def start() {
+ new Thread("Spark scheduler") {
+ setDaemon(true)
+ override def run {
+ val ns = NexusScheduler.this
+ ns.driver = new NexusSchedulerDriver(ns, master)
+ ns.driver.run()
+ }
+ }.start
+ }
+
+ override def getFrameworkName(d: SchedulerDriver): String = frameworkName
+
+ override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
+ new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
+
+ override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
+ val results = new Array[T](tasks.length)
+ if (tasks.length == 0)
+ return results
+
+ val launched = new Array[Boolean](tasks.length)
+
+ val callingThread = currentThread
+
+ var errorHappened = false
+ var errorCode = 0
+ var errorMessage = ""
+
+ // Wait for scheduler to be registered with Nexus
+ waitForRegister()
+
+ try {
+ // Acquire the runTasks semaphore
+ semaphore.acquire()
+
+ val myCallbacks = new Callbacks {
+ val firstTaskId = nextTaskId
+ var tasksLaunched = 0
+ var tasksFinished = 0
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+ def slotOffer(slot: SlaveOffer): Option[TaskDescription] = {
+ try {
+ if (tasksLaunched < tasks.length) {
+ // TODO: Add a short wait if no task with location pref is found
+ // TODO: Figure out why a function is needed around this to
+ // avoid scala.runtime.NonLocalReturnException
+ def findTask: Option[TaskDescription] = {
+ var checkPrefVals: Array[Boolean] = Array(true)
+ val time = System.currentTimeMillis
+ if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
+ checkPrefVals = Array(true, false) // Allow non-preferred tasks
+ // TODO: Make desiredCpus and desiredMem configurable
+ val desiredCpus = 1
+ val desiredMem = 750L * 1024L * 1024L
+ if (slot.getParams.get("cpus").toInt < desiredCpus ||
+ slot.getParams.get("mem").toLong < desiredMem)
+ return None
+ for (checkPref <- checkPrefVals;
+ i <- 0 until tasks.length;
+ if !launched(i) && (!checkPref || tasks(i).prefers(slot)))
+ {
+ val taskId = nextTaskId
+ nextTaskId += 1
+ printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
+ i, taskId, slot.getSlaveId, slot.getHost,
+ if(checkPref) "preferred" else "non-preferred")
+ tasks(i).markStarted(slot)
+ launched(i) = true
+ tasksLaunched += 1
+ if (checkPref)
+ lastPreferredLaunchTime = time
+ val params = new StringMap
+ params.set("cpus", "" + desiredCpus)
+ params.set("mem", "" + desiredMem)
+ val serializedTask = Utils.serialize(tasks(i))
+ return Some(new TaskDescription(taskId, slot.getSlaveId,
+ "task_" + taskId, params, serializedTask))
+ }
+ return None
+ }
+ return findTask
+ } else {
+ return None
+ }
+ } catch {
+ case e: Exception => {
+ e.printStackTrace
+ System.exit(1)
+ return None
+ }
+ }
+ }
+
+ def taskFinished(status: TaskStatus) {
+ println("Finished TID " + status.getTaskId)
+ // Deserialize task result
+ val result = Utils.deserialize[TaskResult[T]](status.getData)
+ results(status.getTaskId - firstTaskId) = result.value
+ // Update accumulators
+ Accumulators.add(callingThread, result.accumUpdates)
+ // Stop if we've finished all the tasks
+ tasksFinished += 1
+ if (tasksFinished == tasks.length) {
+ NexusScheduler.this.callbacks = null
+ NexusScheduler.this.notifyAll()
+ }
+ }
+
+ def error(code: Int, message: String) {
+ // Save the error message
+ errorHappened = true
+ errorCode = code
+ errorMessage = message
+ // Indicate to caller thread that we're done
+ NexusScheduler.this.callbacks = null
+ NexusScheduler.this.notifyAll()
+ }
+ }
+
+ this.synchronized {
+ this.callbacks = myCallbacks
+ }
+ driver.reviveOffers();
+ this.synchronized {
+ while (this.callbacks != null) this.wait()
+ }
+ } finally {
+ semaphore.release()
+ }
+
+ if (errorHappened)
+ throw new SparkException(errorMessage, errorCode)
+ else
+ return results
+ }
+
+ override def registered(d: SchedulerDriver, frameworkId: Int) {
+ println("Registered as framework ID " + frameworkId)
+ registeredLock.synchronized {
+ isRegistered = true
+ registeredLock.notifyAll()
+ }
+ }
+
+ override def waitForRegister() {
+ registeredLock.synchronized {
+ while (!isRegistered) registeredLock.wait()
+ }
+ }
+
+ override def resourceOffer(
+ d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) {
+ synchronized {
+ val tasks = new TaskDescriptionVector
+ if (callbacks != null) {
+ try {
+ for (i <- 0 until slots.size.toInt) {
+ callbacks.slotOffer(slots.get(i)) match {
+ case Some(task) => tasks.add(task)
+ case None => {}
+ }
+ }
+ } catch {
+ case e: Exception => e.printStackTrace
+ }
+ }
+ val params = new StringMap
+ params.set("timeout", "1")
+ d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
+ }
+ }
+
+ override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+ synchronized {
+ if (callbacks != null && status.getState == TaskState.TASK_FINISHED) {
+ try {
+ callbacks.taskFinished(status)
+ } catch {
+ case e: Exception => e.printStackTrace
+ }
+ }
+ }
+ }
+
+ override def error(d: SchedulerDriver, code: Int, message: String) {
+ synchronized {
+ if (callbacks != null) {
+ try {
+ callbacks.error(code, message)
+ } catch {
+ case e: Exception => e.printStackTrace
+ }
+ } else {
+ val msg = "Nexus error: %s (error code: %d)".format(message, code)
+ System.err.println(msg)
+ System.exit(1)
+ }
+ }
+ }
+
+ override def stop() {
+ if (driver != null)
+ driver.stop()
+ }
+}
diff --git a/src/scala/spark/ParallelArray.scala b/src/scala/spark/ParallelArray.scala
new file mode 100644
index 0000000000..90cacf47fc
--- /dev/null
+++ b/src/scala/spark/ParallelArray.scala
@@ -0,0 +1,97 @@
+package spark
+
+abstract class ParallelArray[T](sc: SparkContext) {
+ def filter(f: T => Boolean): ParallelArray[T] = {
+ val cleanF = sc.clean(f)
+ new FilteredParallelArray[T](sc, this, cleanF)
+ }
+
+ def foreach(f: T => Unit): Unit
+
+ def map[U](f: T => U): Array[U]
+}
+
+private object ParallelArray {
+ def slice[T](seq: Seq[T], numSlices: Int): Array[Seq[T]] = {
+ if (numSlices < 1)
+ throw new IllegalArgumentException("Positive number of slices required")
+ seq match {
+ case r: Range.Inclusive => {
+ val sign = if (r.step < 0) -1 else 1
+ slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]],
+ numSlices)
+ }
+ case r: Range => {
+ (0 until numSlices).map(i => {
+ val start = ((i * r.length.toLong) / numSlices).toInt
+ val end = (((i+1) * r.length.toLong) / numSlices).toInt
+ new SerializableRange(
+ r.start + start * r.step, r.start + end * r.step, r.step)
+ }).asInstanceOf[Seq[Seq[T]]].toArray
+ }
+ case _ => {
+ val array = seq.toArray // To prevent O(n^2) operations for List etc
+ (0 until numSlices).map(i => {
+ val start = ((i * array.length.toLong) / numSlices).toInt
+ val end = (((i+1) * array.length.toLong) / numSlices).toInt
+ array.slice(start, end).toArray
+ }).toArray
+ }
+ }
+ }
+}
+
+private class SimpleParallelArray[T](
+ sc: SparkContext, data: Seq[T], numSlices: Int)
+extends ParallelArray[T](sc) {
+ val slices = ParallelArray.slice(data, numSlices)
+
+ def foreach(f: T => Unit) {
+ val cleanF = sc.clean(f)
+ var tasks = for (i <- 0 until numSlices) yield
+ new ForeachRunner(i, slices(i), cleanF)
+ sc.runTasks[Unit](tasks.toArray)
+ }
+
+ def map[U](f: T => U): Array[U] = {
+ val cleanF = sc.clean(f)
+ var tasks = for (i <- 0 until numSlices) yield
+ new MapRunner(i, slices(i), cleanF)
+ return Array.concat(sc.runTasks[Array[U]](tasks.toArray): _*)
+ }
+}
+
+@serializable
+private class ForeachRunner[T](sliceNum: Int, data: Seq[T], f: T => Unit)
+extends Function0[Unit] {
+ def apply() = {
+ printf("Running slice %d of parallel foreach\n", sliceNum)
+ data.foreach(f)
+ }
+}
+
+@serializable
+private class MapRunner[T, U](sliceNum: Int, data: Seq[T], f: T => U)
+extends Function0[Array[U]] {
+ def apply(): Array[U] = {
+ printf("Running slice %d of parallel map\n", sliceNum)
+ return data.map(f).toArray
+ }
+}
+
+private class FilteredParallelArray[T](
+ sc: SparkContext, array: ParallelArray[T], predicate: T => Boolean)
+extends ParallelArray[T](sc) {
+ val cleanPred = sc.clean(predicate)
+
+ def foreach(f: T => Unit) {
+ val cleanF = sc.clean(f)
+ array.foreach(t => if (cleanPred(t)) cleanF(t))
+ }
+
+ def map[U](f: T => U): Array[U] = {
+ val cleanF = sc.clean(f)
+ throw new UnsupportedOperationException(
+ "Map is not yet supported on FilteredParallelArray")
+ }
+}
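+
+// Worked example of the slicing logic above (illustrative): slicing 1 to 10
+// into three pieces turns the inclusive range into Range(1, 11, 1) and then
+// cuts it proportionally, so the pieces have sizes 3, 3 and 4:
+//
+// ParallelArray.slice(1 to 10, 3)
+// => Array(SerializableRange(1, 4, 1), // 1, 2, 3
+// SerializableRange(4, 7, 1), // 4, 5, 6
+// SerializableRange(7, 11, 1)) // 7, 8, 9, 10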
diff --git a/src/scala/spark/Scheduler.scala b/src/scala/spark/Scheduler.scala
new file mode 100644
index 0000000000..77446d3e78
--- /dev/null
+++ b/src/scala/spark/Scheduler.scala
@@ -0,0 +1,9 @@
+package spark
+
+// Scheduler trait, implemented by both NexusScheduler and LocalScheduler.
+private trait Scheduler {
+ def start()
+ def waitForRegister()
+ def runTasks[T](tasks: Array[Task[T]]): Array[T]
+ def stop()
+}
diff --git a/src/scala/spark/SerializableRange.scala b/src/scala/spark/SerializableRange.scala
new file mode 100644
index 0000000000..5d383a40dc
--- /dev/null
+++ b/src/scala/spark/SerializableRange.scala
@@ -0,0 +1,75 @@
+// This is a copy of Scala 2.7.7's Range class, (c) 2006-2009, LAMP/EPFL.
+// The only change here is to make it Serializable, because Ranges aren't.
+// This won't be needed in Scala 2.8, where Scala's Range becomes Serializable.
+
+package spark
+
+@serializable
+private class SerializableRange(val start: Int, val end: Int, val step: Int)
+extends RandomAccessSeq.Projection[Int] {
+ if (step == 0) throw new Predef.IllegalArgumentException
+
+ /** Create a new range with the start and end values of this range and
+ * a new <code>step</code>.
+ */
+ def by(step: Int): Range = new Range(start, end, step)
+
+ override def foreach(f: Int => Unit) {
+ if (step > 0) {
+ var i = this.start
+ val until = if (inInterval(end)) end + 1 else end
+
+ while (i < until) {
+ f(i)
+ i += step
+ }
+ } else {
+ var i = this.start
+ val until = if (inInterval(end)) end - 1 else end
+
+ while (i > until) {
+ f(i)
+ i += step
+ }
+ }
+ }
+
+ lazy val length: Int = {
+ if (start < end && this.step < 0) 0
+ else if (start > end && this.step > 0) 0
+ else {
+ val base = if (start < end) end - start
+ else start - end
+ assert(base >= 0)
+ val step = if (this.step < 0) -this.step else this.step
+ assert(step >= 0)
+ base / step + last(base, step)
+ }
+ }
+
+ protected def last(base: Int, step: Int): Int =
+ if (base % step != 0) 1 else 0
+
+ def apply(idx: Int): Int = {
+ if (idx < 0 || idx >= length) throw new Predef.IndexOutOfBoundsException
+ start + (step * idx)
+ }
+
+ /** a <code>Seq.contains</code>, not an <code>Iterator.contains</code>! */
+ def contains(x: Int): Boolean = {
+ inInterval(x) && (((x - start) % step) == 0)
+ }
+
+ /** Is the argument inside the interval defined by `start' and `end'?
+ * Returns true if `x' is inside [start, end).
+ */
+ protected def inInterval(x: Int): Boolean =
+ if (step > 0)
+ (x >= start && x < end)
+ else
+ (x <= start && x > end)
+
+ //def inclusive = new Range.Inclusive(start,end,step)
+
+ override def toString = "SerializableRange(%d, %d, %d)".format(start, end, step)
+}
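+
+// Example of the semantics above (illustrative): new SerializableRange(1, 11, 1)
+// behaves like 1 until 11 -- length == 10, contains(10) is true and
+// contains(11) is false, since the interval is [start, end).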
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
new file mode 100644
index 0000000000..4bfbcb6f21
--- /dev/null
+++ b/src/scala/spark/SparkContext.scala
@@ -0,0 +1,89 @@
+package spark
+
+import java.io._
+import java.util.UUID
+
+import scala.collection.mutable.ArrayBuffer
+
+class SparkContext(master: String, frameworkName: String) {
+ Cache.initialize()
+
+ def parallelize[T](seq: Seq[T], numSlices: Int): ParallelArray[T] =
+ new SimpleParallelArray[T](this, seq, numSlices)
+
+ def parallelize[T](seq: Seq[T]): ParallelArray[T] = parallelize(seq, 2)
+
+ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
+ new Accumulator(initialValue, param)
+
+ // TODO: Keep around a weak hash map of values to Cached versions?
+ def broadcast[T](value: T) = new Cached(value, local)
+
+ def textFile(path: String) = new HdfsTextFile(this, path)
+
+ val LOCAL_REGEX = """local\[([0-9]+)\]""".r
+
+ private var scheduler: Scheduler = master match {
+ case "local" => new LocalScheduler(1)
+ case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt)
+ case _ => { System.loadLibrary("nexus");
+ new NexusScheduler(master, frameworkName, createExecArg()) }
+ }
+
+ private val local = scheduler.isInstanceOf[LocalScheduler]
+
+ scheduler.start()
+
+ private def createExecArg(): Array[Byte] = {
+ // Our executor arg is an array containing all the spark.* system properties
+ val props = new ArrayBuffer[(String, String)]
+ val iter = System.getProperties.entrySet.iterator
+ while (iter.hasNext) {
+ val entry = iter.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark."))
+ props += (key, value)
+ }
+ return Utils.serialize(props.toArray)
+ }
+
+ def runTasks[T](tasks: Array[() => T]): Array[T] = {
+ runTaskObjects(tasks.map(f => new FunctionTask(f)))
+ }
+
+ private[spark] def runTaskObjects[T](tasks: Seq[Task[T]]): Array[T] = {
+ println("Running " + tasks.length + " tasks in parallel")
+ val start = System.nanoTime
+ val result = scheduler.runTasks(tasks.toArray)
+ println("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
+ return result
+ }
+
+ def stop() {
+ scheduler.stop()
+ scheduler = null
+ }
+
+ def waitForRegister() {
+ scheduler.waitForRegister()
+ }
+
+ // Clean a closure to make it ready to be serialized and sent to tasks
+ // (removes unreferenced variables in $outer's, updates REPL variables)
+ private[spark] def clean[F <: AnyRef](f: F): F = {
+ ClosureCleaner.clean(f)
+ return f
+ }
+}
+
+object SparkContext {
+ implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ def add(t1: Double, t2: Double): Double = t1 + t2
+ def zero(initialValue: Double) = 0.0
+ }
+ implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ def add(t1: Int, t2: Int): Int = t1 + t2
+ def zero(initialValue: Int) = 0
+ }
+ // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+}
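+
+// Hypothetical usage sketch (not part of this commit): builds a local context,
+// runs a parallel map, and fires a couple of ad-hoc tasks. It assumes the
+// closures involved are serializable, which Spark's task model requires.
+object SparkContextExample {
+ def main(args: Array[String]) {
+ val sc = new SparkContext("local[2]", "Example")
+ // ParallelArray.map returns a plain Array[U] in this version
+ val squares = sc.parallelize(1 to 10, 4).map(i => i * i)
+ println(squares.mkString(", "))
+ // runTasks runs each closure as one task and collects the results in order
+ val greetings = sc.runTasks(Array(() => "hello", () => "world"))
+ println(greetings.mkString(" "))
+ sc.stop()
+ }
+}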
diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala
new file mode 100644
index 0000000000..7257bf7b0c
--- /dev/null
+++ b/src/scala/spark/SparkException.scala
@@ -0,0 +1,7 @@
+package spark
+
+class SparkException(message: String) extends Exception(message) {
+ def this(message: String, errorCode: Int) {
+ this("%s (error code: %d)".format(message, errorCode))
+ }
+}
diff --git a/src/scala/spark/Task.scala b/src/scala/spark/Task.scala
new file mode 100644
index 0000000000..e559996a37
--- /dev/null
+++ b/src/scala/spark/Task.scala
@@ -0,0 +1,16 @@
+package spark
+
+import nexus._
+
+@serializable
+trait Task[T] {
+ def run: T
+ def prefers(slot: SlaveOffer): Boolean = true
+ def markStarted(slot: SlaveOffer) {}
+}
+
+@serializable
+class FunctionTask[T](body: () => T)
+extends Task[T] {
+ def run: T = body()
+}
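+
+// Hypothetical sketch of a custom Task with a locality preference; prefers()
+// is consulted by the scheduler when matching tasks to slave offers.
+@serializable
+class HostPreferringTask(host: String) extends Task[String] {
+ def run: String = "ran on some host"
+ override def prefers(slot: SlaveOffer): Boolean = slot.getHost == host
+}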
diff --git a/src/scala/spark/TaskResult.scala b/src/scala/spark/TaskResult.scala
new file mode 100644
index 0000000000..db33c9ff44
--- /dev/null
+++ b/src/scala/spark/TaskResult.scala
@@ -0,0 +1,9 @@
+package spark
+
+import scala.collection.mutable.Map
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+@serializable
+private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any])
diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala
new file mode 100644
index 0000000000..52bcb89f00
--- /dev/null
+++ b/src/scala/spark/Utils.scala
@@ -0,0 +1,28 @@
+package spark
+
+import java.io._
+
+private object Utils {
+ def serialize[T](o: T): Array[Byte] = {
+ val bos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream(bos)
+ oos.writeObject(o)
+ oos.close
+ return bos.toByteArray
+ }
+
+ def deserialize[T](bytes: Array[Byte]): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis)
+ return ois.readObject.asInstanceOf[T]
+ }
+
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
+ }
+}
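+
+// Illustration of the round trip above (sketch):
+// Utils.deserialize[List[Int]](Utils.serialize(List(1, 2, 3))) == List(1, 2, 3)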
diff --git a/src/scala/spark/repl/ExecutorClassLoader.scala b/src/scala/spark/repl/ExecutorClassLoader.scala
new file mode 100644
index 0000000000..7d91b20e79
--- /dev/null
+++ b/src/scala/spark/repl/ExecutorClassLoader.scala
@@ -0,0 +1,86 @@
+package spark.repl
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.net.{URI, URL, URLClassLoader}
+import java.util.concurrent.{Executors, ExecutorService}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.objectweb.asm._
+import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.Opcodes._
+
+
+// A ClassLoader that reads classes from a Hadoop FileSystem URL, used to load
+// classes defined by the interpreter when the REPL is in use
+class ExecutorClassLoader(classDir: String, parent: ClassLoader)
+extends ClassLoader(parent) {
+ val fileSystem = FileSystem.get(new URI(classDir), new Configuration())
+ val directory = new URI(classDir).getPath
+
+ override def findClass(name: String): Class[_] = {
+ try {
+ //println("repl.ExecutorClassLoader resolving " + name)
+ val path = new Path(directory, name.replace('.', '/') + ".class")
+ val bytes = readAndTransformClass(name, fileSystem.open(path))
+ return defineClass(name, bytes, 0, bytes.length)
+ } catch {
+ case e: Exception => throw new ClassNotFoundException(name, e)
+ }
+ }
+
+ def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
+ if (name.startsWith("line") && name.endsWith("$iw$")) {
+ // Class seems to be an interpreter "wrapper" object storing a val or var.
+ // Replace its constructor with a dummy one that does not run the
+ // initialization code placed there by the REPL. The val or var will
+ // be initialized later through reflection when it is used in a task.
+ val cr = new ClassReader(in)
+ val cw = new ClassWriter(
+ ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
+ val cleaner = new ConstructorCleaner(name, cw)
+ cr.accept(cleaner, 0)
+ return cw.toByteArray
+ } else {
+ // Pass the class through unmodified
+ val bos = new ByteArrayOutputStream
+ val bytes = new Array[Byte](4096)
+ var done = false
+ while (!done) {
+ val num = in.read(bytes)
+ if (num >= 0)
+ bos.write(bytes, 0, num)
+ else
+ done = true
+ }
+ return bos.toByteArray
+ }
+ }
+}
+
+class ConstructorCleaner(className: String, cv: ClassVisitor)
+extends ClassAdapter(cv) {
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ val mv = cv.visitMethod(access, name, desc, sig, exceptions)
+ if (name == "<init>" && (access & ACC_STATIC) == 0) {
+ // This is the constructor, time to clean it; just output some new
+ // instructions to mv that create the object and set the static MODULE$
+ // field in the class to point to it, but do nothing otherwise.
+ //println("Cleaning constructor of " + className)
+ mv.visitCode()
+ mv.visitVarInsn(ALOAD, 0) // load this
+ mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
+ mv.visitVarInsn(ALOAD, 0) // load this
+ //val classType = className.replace('.', '/')
+ //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+ mv.visitInsn(RETURN)
+ mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
+ mv.visitEnd()
+ return null
+ } else {
+ return mv
+ }
+ }
+}
diff --git a/src/scala/spark/repl/Main.scala b/src/scala/spark/repl/Main.scala
new file mode 100644
index 0000000000..f00df5aa58
--- /dev/null
+++ b/src/scala/spark/repl/Main.scala
@@ -0,0 +1,16 @@
+package spark.repl
+
+import scala.collection.mutable.Set
+
+object Main {
+ private var _interp: SparkInterpreterLoop = null
+
+ def interp = _interp
+
+ private[repl] def interp_=(i: SparkInterpreterLoop) { _interp = i }
+
+ def main(args: Array[String]) {
+ _interp = new SparkInterpreterLoop
+ _interp.main(args)
+ }
+}
diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala
new file mode 100644
index 0000000000..2377f0c7d6
--- /dev/null
+++ b/src/scala/spark/repl/SparkInterpreter.scala
@@ -0,0 +1,1004 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2009 LAMP/EPFL
+ * @author Martin Odersky
+ */
+// $Id: Interpreter.scala 17013 2009-02-02 11:59:53Z washburn $
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+import java.io.{File, IOException, PrintWriter, StringWriter, Writer}
+import java.lang.{Class, ClassLoader}
+import java.net.{MalformedURLException, URL, URLClassLoader}
+import java.util.UUID
+
+import scala.collection.immutable.ListSet
+import scala.collection.mutable
+import scala.collection.mutable.{ListBuffer, HashSet, ArrayBuffer}
+
+//import ast.parser.SyntaxAnalyzer
+import io.{PlainFile, VirtualDirectory}
+import reporters.{ConsoleReporter, Reporter}
+import symtab.Flags
+import util.{SourceFile,BatchSourceFile,ClassPath,NameTransformer}
+import nsc.{InterpreterResults=>IR}
+import scala.tools.nsc.interpreter._
+
+/** <p>
+ * An interpreter for Scala code.
+ * </p>
+ * <p>
+ * The main public entry points are <code>compile()</code>,
+ * <code>interpret()</code>, and <code>bind()</code>.
+ * The <code>compile()</code> method loads a
+ * complete Scala file. The <code>interpret()</code> method executes one
+ * line of Scala code at the request of the user. The <code>bind()</code>
+ * method binds an object to a variable that can then be used by later
+ * interpreted code.
+ * </p>
+ * <p>
+ * The overall approach is based on compiling the requested code and then
+ * using a Java classloader and Java reflection to run the code
+ * and access its results.
+ * </p>
+ * <p>
+ * In more detail, a single compiler instance is used
+ * to accumulate all successfully compiled or interpreted Scala code. To
+ * "interpret" a line of code, the compiler generates a fresh object that
+ * includes the line of code and which has public member(s) to export
+ * all variables defined by that code. To extract the result of an
+ * interpreted line to show the user, a second "result object" is created
+ * which imports the variables exported by the above object and then
+ * exports a single member named "result". To accommodate user expressions
+ * that read from variables or methods defined in previous statements, "import"
+ * statements are used.
+ * </p>
+ * <p>
+ * This interpreter shares the strengths and weaknesses of using the
+ * full compiler-to-Java. The main strength is that interpreted code
+ * behaves exactly as does compiled code, including running at full speed.
+ * The main weakness is that redefining classes and methods is not handled
+ * properly, because rebinding at the Java level is technically difficult.
+ * </p>
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ */
+class SparkInterpreter(val settings: Settings, out: PrintWriter) {
+ import symtab.Names
+
+ /* If the interpreter is running on pre-jvm-1.5 JVM,
+ it is necessary to force the target setting to jvm-1.4 */
+ private val major = System.getProperty("java.class.version").split("\\.")(0)
+ if (major.toInt < 49) {
+ this.settings.target.value = "jvm-1.4"
+ }
+
+ /** directory to save .class files to */
+ //val virtualDirectory = new VirtualDirectory("(memory)", None)
+ val virtualDirectory = {
+ val tmpDir = new File(System.getProperty("java.io.tmpdir"))
+ var attempts = 0
+ val maxAttempts = 10
+ var outputDir: File = null
+ while (outputDir == null) {
+ attempts += 1
+ if (attempts > maxAttempts) {
+ throw new IOException("Failed to create a temp directory " +
+ "after " + maxAttempts + " attempts!")
+ }
+ try {
+ outputDir = new File(tmpDir, "spark-" + UUID.randomUUID.toString)
+ if (outputDir.exists() || !outputDir.mkdirs())
+ outputDir = null
+ } catch { case e: IOException => ; }
+ }
+ System.setProperty("spark.repl.classdir",
+ "file://" + outputDir.getAbsolutePath + "/")
+ //println("Output dir: " + outputDir)
+ new PlainFile(outputDir)
+ }
+
+ /** the compiler to compile expressions with */
+ val compiler: scala.tools.nsc.Global = newCompiler(settings, reporter)
+
+ import compiler.Traverser
+ import compiler.{Tree, TermTree,
+ ValOrDefDef, ValDef, DefDef, Assign,
+ ClassDef, ModuleDef, Ident, Select, TypeDef,
+ Import, MemberDef, DocDef}
+ import compiler.CompilationUnit
+ import compiler.{Symbol,Name,Type}
+ import compiler.nme
+ import compiler.newTermName
+ import compiler.newTypeName
+ import compiler.nme.{INTERPRETER_VAR_PREFIX, INTERPRETER_SYNTHVAR_PREFIX}
+ import Interpreter.string2code
+
+ /** construct an interpreter that reports to Console */
+ def this(settings: Settings) =
+ this(settings,
+ new NewLinePrintWriter(new ConsoleWriter, true))
+
+ /** whether to print out result lines */
+ private var printResults: Boolean = true
+
+ /** Be quiet. Do not print out the results of each
+ * submitted command unless an exception is thrown. */
+ def beQuiet = { printResults = false }
+
+ /** Temporarily be quiet */
+ def beQuietDuring[T](operation: => T): T = {
+ val wasPrinting = printResults
+ try {
+ printResults = false
+ operation
+ } finally {
+ printResults = wasPrinting
+ }
+ }
+
+ /** interpreter settings */
+ lazy val isettings = new InterpreterSettings
+
+ object reporter extends ConsoleReporter(settings, null, out) {
+ //override def printMessage(msg: String) { out.println(clean(msg)) }
+ override def printMessage(msg: String) { out.print(clean(msg) + "\n"); out.flush() }
+ }
+
+ /** Instantiate a compiler. Subclasses can override this to
+ * change the compiler class used by this interpreter. */
+ protected def newCompiler(settings: Settings, reporter: Reporter) = {
+ val comp = new scala.tools.nsc.Global(settings, reporter)
+ comp.genJVM.outputDir = virtualDirectory
+ comp
+ }
+
+
+ /** the compiler's classpath, as URL's */
+ val compilerClasspath: List[URL] = {
+ val classpathPart =
+ (ClassPath.expandPath(compiler.settings.classpath.value).
+ map(s => new File(s).toURL))
+ def parseURL(s: String): Option[URL] =
+ try { Some(new URL(s)) }
+ catch { case _:MalformedURLException => None }
+ val codebasePart = (compiler.settings.Xcodebase.value.split(" ")).toList.flatMap(parseURL)
+ classpathPart ::: codebasePart
+ }
+
+ /* A single class loader is used for all commands interpreted by this Interpreter.
+ It would also be possible to create a new class loader for each command
+ to interpret. The advantages of the current approach are:
+
+ - Expressions are only evaluated one time. This is especially
+ significant for I/O, e.g. "val x = Console.readLine"
+
+ The main disadvantage is:
+
+ - Objects, classes, and methods cannot be rebound. Instead, definitions
+ shadow the old ones, and old code objects refer to the old
+ definitions.
+ */
+ /** class loader used to load compiled code */
+ private val classLoader = {
+ val parent =
+ if (parentClassLoader == null)
+ new URLClassLoader(compilerClasspath.toArray)
+ else
+ new URLClassLoader(compilerClasspath.toArray,
+ parentClassLoader)
+ val virtualDirUrl = new URL("file://" + virtualDirectory.path + "/")
+ new URLClassLoader(Array(virtualDirUrl), parent)
+ //new InterpreterClassLoader(Array(virtualDirUrl), parent)
+ //new AbstractFileClassLoader(virtualDirectory, parent)
+ }
+
+ /** Set the current Java "context" class loader to this
+ * interpreter's class loader */
+ def setContextClassLoader() {
+ Thread.currentThread.setContextClassLoader(classLoader)
+ }
+
+ protected def parentClassLoader: ClassLoader = this.getClass.getClassLoader()
+
+ /** the previous requests this interpreter has processed */
+ private val prevRequests = new ArrayBuffer[Request]()
+
+ /** next line number to use */
+ private var nextLineNo = 0
+
+ /** allocate a fresh line name */
+ private def newLineName = {
+ val num = nextLineNo
+ nextLineNo += 1
+ compiler.nme.INTERPRETER_LINE_PREFIX + num
+ }
+
+ /** next result variable number to use */
+ private var nextVarNameNo = 0
+
+ /** allocate a fresh variable name */
+ private def newVarName() = {
+ val num = nextVarNameNo
+ nextVarNameNo += 1
+ INTERPRETER_VAR_PREFIX + num
+ }
+
+ /** next internal variable number to use */
+ private var nextInternalVarNo = 0
+
+ /** allocate a fresh internal variable name */
+ private def newInternalVarName() = {
+ val num = nextInternalVarNo
+ nextInternalVarNo += 1
+ INTERPRETER_SYNTHVAR_PREFIX + num
+ }
+
+
+ /** Check if a name looks like it was generated by newVarName */
+ private def isGeneratedVarName(name: String): Boolean =
+ name.startsWith(INTERPRETER_VAR_PREFIX) && {
+ val suffix = name.drop(INTERPRETER_VAR_PREFIX.length)
+ suffix.forall(_.isDigit)
+ }
+
+
+ /** generate a string using a routine that wants to write on a stream */
+ private def stringFrom(writer: PrintWriter => Unit): String = {
+ val stringWriter = new StringWriter()
+ val stream = new NewLinePrintWriter(stringWriter)
+ writer(stream)
+ stream.close
+ stringWriter.toString
+ }
+
+ /** Truncate a string if it is longer than settings.maxPrintString */
+ private def truncPrintString(str: String): String = {
+ val maxpr = isettings.maxPrintString
+
+ if (maxpr <= 0)
+ return str
+
+ if (str.length <= maxpr)
+ return str
+
+ val trailer = "..."
+ if (maxpr >= trailer.length+1)
+ return str.substring(0, maxpr-3) + trailer
+
+ str.substring(0, maxpr)
+ }
+
+ /** Clean up a string for output */
+ private def clean(str: String) =
+ truncPrintString(Interpreter.stripWrapperGunk(str))
+
+ /** Indent some code by the width of the scala> prompt.
+ * This way, compiler error messages read better.
+ */
+ def indentCode(code: String) = {
+ val spaces = " "
+
+ stringFrom(str =>
+ for (line <- code.lines) {
+ str.print(spaces)
+ str.print(line + "\n")
+ str.flush()
+ })
+ }
+
+ implicit def name2string(name: Name) = name.toString
+
+ /** Compute imports that allow definitions from previous
+ * requests to be visible in a new request. Returns
+ * three pieces of related code:
+ *
+ * 1. An initial code fragment that should go before
+ * the code of the new request.
+ *
+ * 2. A code fragment that should go after the code
+ * of the new request.
+ *
+ * 3. An access path which can be traversed to access
+ * any bindings inside code wrapped by #1 and #2.
+ *
+ * The argument is a set of Names that need to be imported.
+ *
+ * Limitations: This method is not as precise as it could be.
+ * (1) It does not process wildcard imports to see what exactly
+ * they import.
+ * (2) If it imports any names from a request, it imports all
+ * of them, which is not really necessary.
+ * (3) It imports multiple same-named implicits, but only the
+ * last one imported is actually usable.
+ */
+ private def importsCode(wanted: Set[Name]): (String, String, String) = {
+ /** Narrow down the list of requests from which imports
+ * should be taken. Removes requests which cannot contribute
+ * useful imports for the specified set of wanted names.
+ */
+ def reqsToUse: List[(Request,MemberHandler)] = {
+ /** Loop through a list of MemberHandlers and select
+ * which ones to keep. 'wanted' is the set of
+ * names that need to be imported, and
+ * 'shadowed' is the list of names useless to import
+ * because a later request will re-import them anyway.
+ */
+ def select(reqs: List[(Request,MemberHandler)], wanted: Set[Name]):
+ List[(Request,MemberHandler)] = {
+ reqs match {
+ case Nil => Nil
+
+ case (req,handler)::rest =>
+ val keepit =
+ (handler.definesImplicit ||
+ handler.importsWildcard ||
+ handler.importedNames.exists(wanted.contains(_)) ||
+ handler.boundNames.exists(wanted.contains(_)))
+
+ val newWanted =
+ if (keepit) {
+ (wanted
+ ++ handler.usedNames
+ -- handler.boundNames
+ -- handler.importedNames)
+ } else {
+ wanted
+ }
+
+ val restToKeep = select(rest, newWanted)
+
+ if(keepit)
+ (req,handler) :: restToKeep
+ else
+ restToKeep
+ }
+ }
+
+ val rhpairs = for {
+ req <- prevRequests.toList.reverse
+ handler <- req.handlers
+ } yield (req, handler)
+
+ select(rhpairs, wanted).reverse
+ }
+
+ val code = new StringBuffer
+ val trailingLines = new ArrayBuffer[String]
+ val accessPath = new StringBuffer
+ val impname = compiler.nme.INTERPRETER_IMPORT_WRAPPER
+ val currentImps = mutable.Set.empty[Name]
+
+ // add code for a new object to hold some imports
+ /*def addWrapper() {
+ code.append("object " + impname + "{\n")
+ trailingLines.append("}\n")
+ accessPath.append("." + impname)
+ currentImps.clear
+ }*/
+ def addWrapper() {
+ code.append("@serializable class " + impname + "C {\n")
+ trailingLines.append("}\nval " + impname + " = new " + impname + "C;\n")
+ accessPath.append("." + impname)
+ currentImps.clear
+ }
+
+ addWrapper()
+
+ // loop through previous requests, adding imports
+ // for each one
+ for ((req,handler) <- reqsToUse) {
+ // If the user entered an import, then just use it
+
+ // add an import wrapping level if the import might
+ // conflict with some other import
+ if(handler.importsWildcard ||
+ currentImps.exists(handler.importedNames.contains))
+ if(!currentImps.isEmpty)
+ addWrapper()
+
+ if (handler.member.isInstanceOf[Import])
+ code.append(handler.member.toString + ";\n")
+
+ // give wildcard imports an import wrapper all their own
+ if(handler.importsWildcard)
+ addWrapper()
+ else
+ currentImps ++= handler.importedNames
+
+ // For other requests, import each bound variable.
+ // import them explicitly instead of with _, so that
+ // ambiguity errors will not be generated. Also, quote
+ // the name of the variable, so that we don't need to
+ // handle quoting keywords separately.
+ for (imv <- handler.boundNames) {
+ if (currentImps.contains(imv))
+ addWrapper()
+ code.append("val " + req.objectName + "$VAL = " + req.objectName + ".INSTANCE;\n")
+ code.append("import ")
+ code.append(req.objectName + "$VAL" + req.accessPath + ".`" + imv + "`;\n")
+ // The code below is less likely to pull in bad variables, but prevents use of vars & classes
+ //code.append("val `" + imv + "` = " + req.objectName + ".INSTANCE" + req.accessPath + ".`" + imv + "`;\n")
+ currentImps += imv
+ }
+ }
+
+ addWrapper() // Add one extra wrapper, to prevent warnings
+ // in the frequent case of redefining
+ // the value bound in the last interpreter
+ // request.
+
+ (code.toString, trailingLines.reverse.mkString, accessPath.toString)
+ }
+
+ /** Parse a line into a sequence of trees. Returns None if the input
+ * is incomplete. */
+ private def parse(line: String): Option[List[Tree]] = {
+ var justNeedsMore = false
+ reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) {
+ // simple parse: just parse it, nothing else
+ def simpleParse(code: String): List[Tree] = {
+ reporter.reset
+ val unit =
+ new CompilationUnit(
+ new BatchSourceFile("<console>", code.toCharArray()))
+ val parser = new compiler.syntaxAnalyzer.UnitParser(unit)
+ parser.templateStatSeq(false)._2
+ }
+ val trees = simpleParse(line)
+ if (reporter.hasErrors) {
+ Some(Nil) // the result did not parse, so stop
+ } else if (justNeedsMore) {
+ None
+ } else {
+ Some(trees)
+ }
+ }
+ }
+
+ /** Compile an nsc SourceFile. Returns true if there are
+ * no compilation errors, or false otherwise.
+ */
+ def compileSources(sources: List[SourceFile]): Boolean = {
+ val cr = new compiler.Run
+ reporter.reset
+ cr.compileSources(sources)
+ !reporter.hasErrors
+ }
+
+ /** Compile a string. Returns true if there are no
+ * compilation errors, or false otherwise.
+ */
+ def compileString(code: String): Boolean =
+ compileSources(List(new BatchSourceFile("<script>", code.toCharArray)))
+
+ /** Build a request from the user. <code>trees</code> is <code>line</code>
+ * after being parsed.
+ */
+ private def buildRequest(trees: List[Tree], line: String, lineName: String): Request =
+ new Request(line, lineName)
+
+ private def chooseHandler(member: Tree): Option[MemberHandler] =
+ member match {
+ case member: DefDef =>
+ Some(new DefHandler(member))
+ case member: ValDef =>
+ Some(new ValHandler(member))
+ case member@Assign(Ident(_), _) => Some(new AssignHandler(member))
+ case member: ModuleDef => Some(new ModuleHandler(member))
+ case member: ClassDef => Some(new ClassHandler(member))
+ case member: TypeDef => Some(new TypeAliasHandler(member))
+ case member: Import => Some(new ImportHandler(member))
+ case DocDef(_, documented) => chooseHandler(documented)
+ case member => Some(new GenericHandler(member))
+ }
+
+ /** <p>
+ * Interpret one line of input. All feedback, including parse errors
+ * and evaluation results, are printed via the supplied compiler's
+ * reporter. Values defined are available for future interpreted
+ * strings.
+ * </p>
+ * <p>
+ * The return value is whether the line was interpreted successfully,
+ * e.g. that there were no parse errors.
+ * </p>
+ *
+ * @param line ...
+ * @return ...
+ */
+ def interpret(line: String): IR.Result = {
+ if (prevRequests.isEmpty)
+ new compiler.Run // initialize the compiler
+
+ // parse
+ val trees = parse(indentCode(line)) match {
+ case None => return IR.Incomplete
+ case (Some(Nil)) => return IR.Error // parse error or empty input
+ case Some(trees) => trees
+ }
+
+ trees match {
+ case List(_:Assign) => ()
+
+ case List(_:TermTree) | List(_:Ident) | List(_:Select) =>
+ // Treat a single bare expression specially.
+ // This is necessary due to it being hard to modify
+ // code at a textual level, and it being hard to
+ // submit an AST to the compiler.
+ return interpret("val "+newVarName()+" = \n"+line)
+
+ case _ => ()
+ }
+
+ val lineName = newLineName
+
+ // figure out what kind of request
+ val req = buildRequest(trees, line, lineName)
+ if (req eq null) return IR.Error // a disallowed statement type
+
+ if (!req.compile)
+ return IR.Error // an error happened during compilation, e.g. a type error
+
+ val (interpreterResultString, succeeded) = req.loadAndRun
+
+ if (printResults || !succeeded) {
+ // print the result
+ out.print(clean(interpreterResultString))
+ }
+
+ // book-keeping
+ if (succeeded)
+ prevRequests += req
+
+ if (succeeded) IR.Success else IR.Error
+ }
+
+ /** A counter used for numbering objects created by <code>bind()</code>. */
+ private var binderNum = 0
+
+ /** Bind a specified name to a specified value. The name may
+ * later be used by expressions passed to interpret.
+ *
+ * @param name the variable name to bind
+ * @param boundType the type of the variable, as a string
+ * @param value the object value to bind to it
+ * @return an indication of whether the binding succeeded
+ */
+ def bind(name: String, boundType: String, value: Any): IR.Result = {
+ val binderName = "binder" + binderNum
+ binderNum += 1
+
+ compileString(
+ "object " + binderName +
+ "{ var value: " + boundType + " = _; " +
+ " def set(x: Any) = value=x.asInstanceOf[" + boundType + "]; }")
+
+ val binderObject =
+ Class.forName(binderName, true, classLoader)
+ val setterMethod =
+ (binderObject
+ .getDeclaredMethods
+ .toList
+ .find(meth => meth.getName == "set")
+ .get)
+ var argsHolder: Array[Any] = null // this roundabout approach is to try and
+ // make sure the value is boxed
+ argsHolder = List(value).toArray
+ setterMethod.invoke(null, argsHolder.asInstanceOf[Array[AnyRef]]: _*)
+
+ interpret("val " + name + " = " + binderName + ".value")
+ }
+
+
+ /** <p>
+ * This instance is no longer needed, so release any resources
+ * it is using. The reporter's output gets flushed.
+ * </p>
+ */
+ def close() {
+ reporter.flush()
+ }
+
+ /** A traverser that finds all mentioned identifiers, i.e. things
+ * that need to be imported.
+ * It might return extra names.
+ */
+ private class ImportVarsTraverser(definedVars: List[Name]) extends Traverser {
+ val importVars = new HashSet[Name]()
+
+ override def traverse(ast: Tree) {
+ ast match {
+ case Ident(name) => importVars += name
+ case _ => super.traverse(ast)
+ }
+ }
+ }
+
+
+ /** Class to handle one member among all the members included
+ * in a single interpreter request.
+ */
+ private sealed abstract class MemberHandler(val member: Tree) {
+ val usedNames: List[Name] = {
+ val ivt = new ImportVarsTraverser(boundNames)
+ ivt.traverseTrees(List(member))
+ ivt.importVars.toList
+ }
+ def boundNames: List[Name] = Nil
+ def valAndVarNames: List[Name] = Nil
+ def defNames: List[Name] = Nil
+ val importsWildcard = false
+ val importedNames: Seq[Name] = Nil
+ val definesImplicit =
+ member match {
+ case tree:MemberDef =>
+ tree.mods.hasFlag(symtab.Flags.IMPLICIT)
+ case _ => false
+ }
+
+ def extraCodeToEvaluate(req: Request, code: PrintWriter) { }
+ def resultExtractionCode(req: Request, code: PrintWriter) { }
+ }
+
+ private class GenericHandler(member: Tree) extends MemberHandler(member)
+
+ private class ValHandler(member: ValDef) extends MemberHandler(member) {
+ override lazy val boundNames = List(member.name)
+ override def valAndVarNames = boundNames
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ val vname = member.name
+ if (member.mods.isPublic &&
+ !(isGeneratedVarName(vname) &&
+ req.typeOf(compiler.encode(vname)) == "Unit"))
+ {
+ val prettyName = NameTransformer.decode(vname)
+ code.print(" + \"" + prettyName + ": " +
+ string2code(req.typeOf(vname)) +
+ " = \" + " +
+ " { val tmp = scala.runtime.ScalaRunTime.stringOf(" +
+ req.fullPath(vname) +
+ "); " +
+ " (if(tmp.toSeq.contains('\\n')) \"\\n\" else \"\") + tmp + \"\\n\"} ")
+ }
+ }
+ }
+
+ private class DefHandler(defDef: DefDef) extends MemberHandler(defDef) {
+ override lazy val boundNames = List(defDef.name)
+ override def defNames = boundNames
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ if (defDef.mods.isPublic)
+ code.print("+\""+string2code(defDef.name)+": "+
+ string2code(req.typeOf(defDef.name))+"\\n\"")
+ }
+ }
+
+ private class AssignHandler(member: Assign) extends MemberHandler(member) {
+ val lhs = member.lhs.asInstanceOf[Ident] // an unfortunate limitation
+
+ val helperName = newTermName(newInternalVarName())
+ override val valAndVarNames = List(helperName)
+
+ override def extraCodeToEvaluate(req: Request, code: PrintWriter) {
+ code.println("val "+helperName+" = "+member.lhs+";")
+ }
+
+ /** Print out lhs instead of the generated varName */
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ code.print(" + \"" + lhs + ": " +
+ string2code(req.typeOf(compiler.encode(helperName))) +
+ " = \" + " +
+ string2code(req.fullPath(helperName))
+ + " + \"\\n\"")
+ }
+ }
+
+ private class ModuleHandler(module: ModuleDef) extends MemberHandler(module) {
+ override lazy val boundNames = List(module.name)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ code.println(" + \"defined module " +
+ string2code(module.name)
+ + "\\n\"")
+ }
+ }
+
+ private class ClassHandler(classdef: ClassDef)
+ extends MemberHandler(classdef)
+ {
+ override lazy val boundNames =
+ List(classdef.name) :::
+ (if (classdef.mods.hasFlag(Flags.CASE))
+ List(classdef.name.toTermName)
+ else
+ Nil)
+
+ // TODO: MemberDef.keyword does not include "trait";
+ // otherwise it could be used here
+ def keyword: String =
+ if (classdef.mods.isTrait) "trait" else "class"
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ code.print(
+ " + \"defined " +
+ keyword +
+ " " +
+ string2code(classdef.name) +
+ "\\n\"")
+ }
+ }
+
+ private class TypeAliasHandler(typeDef: TypeDef)
+ extends MemberHandler(typeDef)
+ {
+ override lazy val boundNames =
+ if (typeDef.mods.isPublic && compiler.treeInfo.isAliasTypeDef(typeDef))
+ List(typeDef.name)
+ else
+ Nil
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ code.println(" + \"defined type alias " +
+ string2code(typeDef.name) + "\\n\"")
+ }
+ }
+
+ private class ImportHandler(imp: Import) extends MemberHandler(imp) {
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ code.println("+ \"" + imp.toString + "\\n\"")
+ }
+
+ /** Whether this import includes a wildcard import */
+ override val importsWildcard =
+ imp.selectors.map(_._1).contains(nme.USCOREkw)
+
+ /** The individual names imported by this statement */
+ override val importedNames: Seq[Name] =
+ for {
+ val (_,sel) <- imp.selectors
+ sel != null
+ sel != nme.USCOREkw
+ val name <- List(sel.toTypeName, sel.toTermName)
+ }
+ yield name
+ }
+
+ /** One line of code submitted by the user for interpretation */
+ private class Request(val line: String, val lineName: String) {
+ val trees = parse(line) match {
+ case Some(ts) => ts
+ case None => Nil
+ }
+
+ /** name to use for the object that will compute "line" */
+ def objectName = lineName + compiler.nme.INTERPRETER_WRAPPER_SUFFIX
+
+ /** name of the object that retrieves the result from the above object */
+ def resultObjectName = "RequestResult$" + objectName
+
+ val handlers: List[MemberHandler] = trees.flatMap(chooseHandler(_))
+
+ /** all (public) names defined by these statements */
+ val boundNames = (ListSet() ++ handlers.flatMap(_.boundNames)).toList
+
+ /** list of names used by this expression */
+ val usedNames: List[Name] = handlers.flatMap(_.usedNames)
+
+ def myImportsCode = importsCode(Set.empty ++ usedNames)
+
+ /** Code to append to objectName to access anything that
+ * the request binds. */
+ val accessPath = myImportsCode._3
+
+
+ /** Code to access a variable with the specified name */
+ def fullPath(vname: String): String =
+ objectName + ".INSTANCE" + accessPath + ".`" + vname + "`"
+
+ /** Code to access a variable with the specified name */
+ def fullPath(vname: Name): String = fullPath(vname.toString)
+
+ /** the line of code to compute */
+ def toCompute = line
+
+ /** generate the source code for the object that computes this request */
+ def objectSourceCode: String = {
+ val src = stringFrom { code =>
+ // header for the wrapper object
+ code.println("@serializable class " + objectName + " {")
+
+ val (importsPreamble, importsTrailer, _) = myImportsCode
+
+ code.print(importsPreamble)
+
+ code.println(indentCode(toCompute))
+
+ handlers.foreach(_.extraCodeToEvaluate(this,code))
+
+ code.println(importsTrailer)
+
+ //end the wrapper object
+ code.println(";}")
+
+ //create an object
+ code.println("object " + objectName + " {")
+ code.println(" val INSTANCE = new " + objectName + "();")
+ code.println("}")
+ }
+ //println(src)
+ src
+ }
+
+ /** Types of variables defined by this request. They are computed
+ after compilation of the main object */
+ var typeOf: Map[Name, String] = _
+
+ /** generate source code for the object that retrieves the result
+ from objectSourceCode */
+ def resultObjectSourceCode: String =
+ stringFrom(code => {
+ code.println("object " + resultObjectName)
+ code.println("{ val result: String = {")
+ code.println(objectName + ".INSTANCE" + accessPath + ";") // evaluate the object, to make sure its constructor is run
+ code.print("(\"\"") // print an initial empty string, so later code can
+ // uniformly be: + morestuff
+ handlers.foreach(_.resultExtractionCode(this, code))
+ code.println("\n)}")
+ code.println(";}")
+ })
+
+
+ /** Compile the object file. Returns whether the compilation succeeded.
+ * If all goes well, the "types" map is computed. */
+ def compile(): Boolean = {
+ reporter.reset // without this, error counting is not correct,
+ // and the interpreter sometimes overlooks compile failures!
+
+ // compile the main object
+ val objRun = new compiler.Run()
+ //println("source: "+objectSourceCode) //DEBUG
+ objRun.compileSources(
+ List(new BatchSourceFile("<console>", objectSourceCode.toCharArray))
+ )
+ if (reporter.hasErrors) return false
+
+
+ // extract and remember types
+ typeOf = findTypes(objRun)
+
+ // compile the result-extraction object
+ new compiler.Run().compileSources(
+ List(new BatchSourceFile("<console>", resultObjectSourceCode.toCharArray))
+ )
+
+ // success
+ !reporter.hasErrors
+ }
+
+ /** Dig the types of all bound variables out of the compiler run.
+ *
+ * @param objRun ...
+ * @return ...
+ */
+ def findTypes(objRun: compiler.Run): Map[Name, String] = {
+ def valAndVarNames = handlers.flatMap(_.valAndVarNames)
+ def defNames = handlers.flatMap(_.defNames)
+
+ def getTypes(names: List[Name], nameMap: Name=>Name): Map[Name, String] = {
+ /** the outermost wrapper object */
+ val outerResObjSym: Symbol =
+ compiler.definitions.getMember(compiler.definitions.EmptyPackage,
+ newTermName(objectName).toTypeName) // MATEI: added toTypeName
+
+ /** the innermost object inside the wrapper, found by
+ * following accessPath into the outer one. */
+ val resObjSym =
+ (accessPath.split("\\.")).foldLeft(outerResObjSym)((sym,name) =>
+ if(name == "") sym else
+ compiler.atPhase(objRun.typerPhase.next) {
+ sym.info.member(newTermName(name)) })
+
+ names.foldLeft(Map.empty[Name,String])((map, name) => {
+ val rawType =
+ compiler.atPhase(objRun.typerPhase.next) {
+ resObjSym.info.member(name).tpe
+ }
+
+ // the types are all =>T; remove the =>
+ val cleanedType= rawType match {
+ case compiler.PolyType(Nil, rt) => rt
+ case rawType => rawType
+ }
+
+ map + (name -> compiler.atPhase(objRun.typerPhase.next) { cleanedType.toString })
+ })
+ }
+
+ val names1 = getTypes(valAndVarNames, n => compiler.nme.getterToLocal(n))
+ val names2 = getTypes(defNames, identity)
+ names1 ++ names2
+ }
+
+ /** load and run the code using reflection */
+ def loadAndRun: (String, Boolean) = {
+ val interpreterResultObject: Class[_] =
+ Class.forName(resultObjectName, true, classLoader)
+ val resultValMethod: java.lang.reflect.Method =
+ interpreterResultObject.getMethod("result")
+ try {
+ (resultValMethod.invoke(interpreterResultObject).toString(),
+ true)
+ } catch {
+ case e =>
+ def caus(e: Throwable): Throwable =
+ if (e.getCause eq null) e else caus(e.getCause)
+ val orig = caus(e)
+ (stringFrom(str => orig.printStackTrace(str)), false)
+ }
+ }
+ }
+}
+
+/** Utility methods for the Interpreter. */
+object Interpreter {
+ /** Delete a directory tree recursively. Use with care!
+ */
+ def deleteRecursively(path: File) {
+ path match {
+ case _ if !path.exists =>
+ ()
+ case _ if path.isDirectory =>
+ for (p <- path.listFiles)
+ deleteRecursively(p)
+ path.delete
+ case _ =>
+ path.delete
+ }
+ }
+
+ /** Heuristically strip interpreter wrapper prefixes
+ * from an interpreter output string.
+ */
+ def stripWrapperGunk(str: String): String = {
+ //val wrapregex = "(line[0-9]+\\$object[$.])?(\\$iw[$.])*"
+ //str.replaceAll(wrapregex, "")
+ str
+ }
+
+ /** Convert a string into code that can recreate the string.
+ * This requires replacing all special characters by escape
+ * codes. It does not add the surrounding " marks. */
+ def string2code(str: String): String = {
+ /** Convert a character to a backslash-u escape */
+ def char2uescape(c: Char): String = {
+ var rest = c.toInt
+ val buf = new StringBuilder
+ for (i <- 1 to 4) {
+ buf ++= (rest % 16).toHexString
+ rest = rest / 16
+ }
+ "\\" + "u" + buf.toString.reverse
+ }
+
+
+ val res = new StringBuilder
+ for (c <- str) {
+ if ("'\"\\" contains c) {
+ res += '\\'
+ res += c
+ } else if (!c.isControl) {
+ res += c
+ } else {
+ res ++= char2uescape(c)
+ }
+ }
+ res.toString
+ }
+}
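+
+// Example of string2code above (illustrative): applied to the text
+//   it's "ok"
+// it produces the source fragment
+//   it\'s \"ok\"
+// which can be embedded directly between " marks; control characters would
+// instead be turned into backslash-u escapes.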
diff --git a/src/scala/spark/repl/SparkInterpreterLoop.scala b/src/scala/spark/repl/SparkInterpreterLoop.scala
new file mode 100644
index 0000000000..4aab60fd11
--- /dev/null
+++ b/src/scala/spark/repl/SparkInterpreterLoop.scala
@@ -0,0 +1,366 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2009 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+// $Id: InterpreterLoop.scala 16881 2009-01-09 16:28:11Z cunei $
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+import java.io.{BufferedReader, File, FileReader, PrintWriter}
+import java.io.IOException
+import java.lang.{ClassLoader, System}
+
+import scala.tools.nsc.{InterpreterResults => IR}
+import scala.tools.nsc.interpreter._
+
+import spark.SparkContext
+
+/** The
+ * <a href="http://scala-lang.org/" target="_top">Scala</a>
+ * interactive shell. It provides a read-eval-print loop around
+ * the Interpreter class.
+ * After instantiation, clients should call the <code>main()</code> method.
+ *
+ * <p>If no in0 is specified, then input will come from the console, and
+ * the class will attempt to provide input editing features such as
+ * input history.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ * @version 1.2
+ */
+class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter,
+ master: Option[String])
+{
+ def this(in0: BufferedReader, out: PrintWriter, master: String) =
+ this(Some(in0), out, Some(master))
+
+ def this(in0: BufferedReader, out: PrintWriter) =
+ this(Some(in0), out, None)
+
+ def this() = this(None, new PrintWriter(Console.out), None)
+
+ /** The input stream from which interpreter commands come */
+ var in: InteractiveReader = _ //set by main()
+
+ /** The context class loader at the time this object was created */
+ protected val originalClassLoader =
+ Thread.currentThread.getContextClassLoader
+
+ var settings: Settings = _ // set by main()
+ var interpreter: SparkInterpreter = null // set by createInterpreter()
+ def isettings = interpreter.isettings
+
+ /** A reverse list of commands to replay if the user
+ * requests a :replay */
+ var replayCommandsRev: List[String] = Nil
+
+ /** A list of commands to replay if the user requests a :replay */
+ def replayCommands = replayCommandsRev.reverse
+
+ /** Record a command for replay should the user request a :replay */
+ def addReplay(cmd: String) =
+ replayCommandsRev = cmd :: replayCommandsRev
+
+ /** Close the interpreter, if there is one, and set
+ * interpreter to <code>null</code>. */
+ def closeInterpreter() {
+ if (interpreter ne null) {
+ interpreter.close
+ interpreter = null
+ Thread.currentThread.setContextClassLoader(originalClassLoader)
+ }
+ }
+
+ /** Create a new interpreter. Close the old one, if there
+ * is one. */
+ def createInterpreter() {
+ //closeInterpreter()
+
+ interpreter = new SparkInterpreter(settings, out) {
+ override protected def parentClassLoader =
+ classOf[SparkInterpreterLoop].getClassLoader
+ }
+ interpreter.setContextClassLoader()
+ }
+
+ /** Bind the settings so that evaluated code can modify them */
+ def bindSettings() {
+ interpreter.beQuietDuring {
+ interpreter.compileString(InterpreterSettings.sourceCodeForClass)
+
+ interpreter.bind(
+ "settings",
+ "scala.tools.nsc.InterpreterSettings",
+ isettings)
+ }
+ }
+
+
+ /** print a friendly help message */
+ def printHelp {
+ //printWelcome
+ out.println("This is Scala " + Properties.versionString + " (" +
+ System.getProperty("java.vm.name") + ", Java " + System.getProperty("java.version") + ")." )
+ out.println("Type in expressions to have them evaluated.")
+ out.println("Type :load followed by a filename to load a Scala file.")
+ //out.println("Type :replay to reset execution and replay all previous commands.")
+ out.println("Type :quit to exit the interpreter.")
+ }
+
+ /** Print a welcome message */
+ def printWelcome() {
+ out.println("""Welcome to
+ ____ __
+ / __/__ ___ _____/ /__
+ _\ \/ _ \/ _ `/ __/ '_/
+ /___/ .__/\_,_/_/ /_/\_\ version 0.0
+ /_/
+""")
+
+ out.println("Using Scala " + Properties.versionString + " (" +
+ System.getProperty("java.vm.name") + ", Java " +
+ System.getProperty("java.version") + ")." )
+ out.flush()
+ }
+
+ def createSparkContext(): SparkContext = {
+ val master = this.master match {
+ case Some(m) => m
+ case None => {
+ val prop = System.getenv("MASTER")
+ if (prop != null) prop else "local"
+ }
+ }
+ new SparkContext(master, "Spark REPL")
+ }
+
+ /** Prompt to print when awaiting input */
+ val prompt = Properties.shellPromptString
+
+ /** The main read-eval-print loop for the interpreter. It calls
+ * <code>command()</code> for each line of input, and stops when
+ * <code>command()</code> returns <code>false</code>.
+ */
+ def repl() {
+ out.println("Intializing...")
+ out.flush()
+ interpreter.beQuietDuring {
+ command("""
+ spark.repl.Main.interp.out.println("Registering with Nexus...");
+ @transient val sc = spark.repl.Main.interp.createSparkContext();
+ sc.waitForRegister();
+ spark.repl.Main.interp.out.println("Spark context available as sc.")
+ """)
+ command("import spark.SparkContext._");
+ }
+ out.println("Type in expressions to have them evaluated.")
+ out.println("Type :help for more information.")
+ out.flush()
+
+ var first = true
+ while (true) {
+ out.flush()
+
+ val line =
+ if (first) {
+ /* For some reason, the first interpreted command always takes
+ * a second or two. So, wait until the welcome message
+ * has been printed before calling bindSettings. That way,
+ * the user can read the welcome message while this
+ * command executes.
+ */
+ val futLine = scala.concurrent.ops.future(in.readLine(prompt))
+
+ bindSettings()
+ first = false
+
+ futLine()
+ } else {
+ in.readLine(prompt)
+ }
+
+ if (line eq null)
+ return () // assumes null means EOF
+
+ val (keepGoing, finalLineMaybe) = command(line)
+
+ if (!keepGoing)
+ return
+
+ finalLineMaybe match {
+ case Some(finalLine) => addReplay(finalLine)
+ case None => ()
+ }
+ }
+ }
+
+ /** interpret all lines from a specified file */
+ def interpretAllFrom(filename: String) {
+ val fileIn = try {
+ new FileReader(filename)
+ } catch {
+ case _:IOException =>
+ out.println("Error opening file: " + filename)
+ return
+ }
+ val oldIn = in
+ val oldReplay = replayCommandsRev
+ try {
+ val inFile = new BufferedReader(fileIn)
+ in = new SimpleReader(inFile, out, false)
+ out.println("Loading " + filename + "...")
+ out.flush
+ repl
+ } finally {
+ in = oldIn
+ replayCommandsRev = oldReplay
+ fileIn.close
+ }
+ }
+
+ /** create a new interpreter and replay all commands so far */
+ def replay() {
+ closeInterpreter()
+ createInterpreter()
+ for (cmd <- replayCommands) {
+ out.println("Replaying: " + cmd)
+ out.flush() // because maybe cmd will have its own output
+ command(cmd)
+ out.println
+ }
+ }
+
+  /** Run one command submitted by the user. Two values are returned:
+   * (1) whether to keep running, (2) the line to record for replay,
+   * if any. */
+ def command(line: String): (Boolean, Option[String]) = {
+ def withFile(command: String)(action: String => Unit) {
+ val spaceIdx = command.indexOf(' ')
+ if (spaceIdx <= 0) {
+ out.println("That command requires a filename to be specified.")
+ return ()
+ }
+ val filename = command.substring(spaceIdx).trim
+ if (! new File(filename).exists) {
+ out.println("That file does not exist")
+ return ()
+ }
+ action(filename)
+ }
+
+ val helpRegexp = ":h(e(l(p)?)?)?"
+ val quitRegexp = ":q(u(i(t)?)?)?"
+ val loadRegexp = ":l(o(a(d)?)?)?.*"
+ //val replayRegexp = ":r(e(p(l(a(y)?)?)?)?)?.*"
+
+ var shouldReplay: Option[String] = None
+
+ if (line.matches(helpRegexp))
+ printHelp
+ else if (line.matches(quitRegexp))
+ return (false, None)
+ else if (line.matches(loadRegexp)) {
+ withFile(line)(f => {
+ interpretAllFrom(f)
+ shouldReplay = Some(line)
+ })
+ }
+ //else if (line matches replayRegexp)
+ // replay
+ else if (line startsWith ":")
+ out.println("Unknown command. Type :help for help.")
+ else
+ shouldReplay = interpretStartingWith(line)
+
+ (true, shouldReplay)
+ }
+
+ /** Interpret expressions starting with the first line.
+ * Read lines until a complete compilation unit is available
+ * or until a syntax error has been seen. If a full unit is
+ * read, go ahead and interpret it. Return the full string
+ * to be recorded for replay, if any.
+ */
+ def interpretStartingWith(code: String): Option[String] = {
+ interpreter.interpret(code) match {
+ case IR.Success => Some(code)
+ case IR.Error => None
+ case IR.Incomplete =>
+ if (in.interactive && code.endsWith("\n\n")) {
+ out.println("You typed two blank lines. Starting a new command.")
+ None
+ } else {
+ val nextLine = in.readLine(" | ")
+ if (nextLine == null)
+ None // end of file
+ else
+ interpretStartingWith(code + "\n" + nextLine)
+ }
+ }
+ }
+
+ def loadFiles(settings: Settings) {
+ settings match {
+ case settings: GenericRunnerSettings =>
+ for (filename <- settings.loadfiles.value) {
+ val cmd = ":load " + filename
+ command(cmd)
+ replayCommandsRev = cmd :: replayCommandsRev
+ out.println()
+ }
+ case _ =>
+ }
+ }
+
+ def main(settings: Settings) {
+ this.settings = settings
+
+ in =
+ in0 match {
+ case Some(in0) =>
+ new SimpleReader(in0, out, true)
+
+ case None =>
+ val emacsShell = System.getProperty("env.emacs", "") != ""
+ //println("emacsShell="+emacsShell) //debug
+ if (settings.Xnojline.value || emacsShell)
+ new SimpleReader()
+ else
+ InteractiveReader.createDefault()
+ }
+
+ createInterpreter()
+
+ loadFiles(settings)
+
+ try {
+ if (interpreter.reporter.hasErrors) {
+ return // it is broken on startup; go ahead and exit
+ }
+ printWelcome()
+ repl()
+ } finally {
+ closeInterpreter()
+ }
+ }
+
+ /** process command-line arguments and do as they request */
+ def main(args: Array[String]) {
+ def error1(msg: String) { out.println("scala: " + msg) }
+ val command = new InterpreterCommand(List.fromArray(args), error1)
+
+ if (!command.ok || command.settings.help.value || command.settings.Xhelp.value) {
+ // either the command line is wrong, or the user
+ // explicitly requested a help listing
+ if (command.settings.help.value) out.println(command.usageMsg)
+ if (command.settings.Xhelp.value) out.println(command.xusageMsg)
+ out.flush
+ }
+ else
+ main(command.settings)
+ }
+}
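For illustration (not part of this commit): a minimal sketch of driving the interpreter loop programmatically, mirroring what ReplSuite does further down. The scripted input and the "local" master string are placeholders.

import java.io.{BufferedReader, PrintWriter, StringReader, StringWriter}
import spark.repl.{Main, SparkInterpreterLoop}

object ReplDriverSketch {
  def main(args: Array[String]) {
    val in = new BufferedReader(new StringReader("1 + 1\n"))  // scripted input
    val out = new StringWriter()                              // captured output
    val interp = new SparkInterpreterLoop(in, new PrintWriter(out), "local")
    Main.interp = interp                 // code injected by repl() looks the interpreter up here
    interp.main(new Array[String](0))    // parse (empty) command-line args, then run the loop
    Main.interp = null
    print(out.toString)
  }
}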
diff --git a/src/scala/ubiquifs/Header.scala b/src/scala/ubiquifs/Header.scala
new file mode 100644
index 0000000000..bdca83a2d5
--- /dev/null
+++ b/src/scala/ubiquifs/Header.scala
@@ -0,0 +1,21 @@
+package ubiquifs
+
+import java.io.{DataInputStream, DataOutputStream}
+
+object RequestType {
+ val READ = 0
+ val WRITE = 1
+}
+
+class Header(val requestType: Int, val path: String) {
+ def write(out: DataOutputStream) {
+ out.write(requestType)
+ out.writeUTF(path)
+ }
+}
+
+object Header {
+ def read(in: DataInputStream): Header = {
+ new Header(in.read(), in.readUTF())
+ }
+}
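For illustration (not part of this commit): a round-trip sketch of the header wire format above, i.e. one request-type byte followed by a UTF-encoded path. The path is a placeholder.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import ubiquifs.{Header, RequestType}

object HeaderRoundTripSketch {
  def main(args: Array[String]) {
    // Serialize a WRITE header into a byte buffer and parse it back.
    val buf = new ByteArrayOutputStream()
    new Header(RequestType.WRITE, "/tmp/example").write(new DataOutputStream(buf))
    val header = Header.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray)))
    assert(header.requestType == RequestType.WRITE && header.path == "/tmp/example")
    println("round-tripped: " + header.path)
  }
}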
diff --git a/src/scala/ubiquifs/Master.scala b/src/scala/ubiquifs/Master.scala
new file mode 100644
index 0000000000..6854acd6a5
--- /dev/null
+++ b/src/scala/ubiquifs/Master.scala
@@ -0,0 +1,49 @@
+package ubiquifs
+
+import scala.actors.Actor
+import scala.actors.Actor._
+import scala.actors.remote.RemoteActor
+import scala.actors.remote.RemoteActor._
+import scala.actors.remote.Node
+import scala.collection.mutable.{ArrayBuffer, Map, Set}
+
+class Master(port: Int) extends Actor {
+ case class SlaveInfo(host: String, port: Int)
+
+ val files = Set[String]()
+ val slaves = new ArrayBuffer[SlaveInfo]()
+
+ def act() {
+ alive(port)
+ register('UbiquiFS, self)
+ println("Created UbiquiFS Master on port " + port)
+
+ loop {
+ react {
+ case RegisterSlave(host, port) =>
+ slaves += SlaveInfo(host, port)
+ sender ! RegisterSucceeded()
+
+ case Create(path) =>
+ if (files.contains(path)) {
+ sender ! CreateFailed("File already exists")
+ } else if (slaves.isEmpty) {
+ sender ! CreateFailed("No slaves registered")
+ } else {
+ files += path
+ sender ! CreateSucceeded(slaves(0).host, slaves(0).port)
+ }
+
+ case m: Any =>
+ println("Unknown message: " + m)
+ }
+ }
+ }
+}
+
+object MasterMain {
+ def main(args: Array[String]) {
+ val port = args(0).toInt
+ new Master(port).start()
+ }
+}
diff --git a/src/scala/ubiquifs/Message.scala b/src/scala/ubiquifs/Message.scala
new file mode 100644
index 0000000000..153542f8de
--- /dev/null
+++ b/src/scala/ubiquifs/Message.scala
@@ -0,0 +1,14 @@
+package ubiquifs
+
+sealed abstract class Message
+
+case class RegisterSlave(host: String, port: Int) extends Message
+case class RegisterSucceeded() extends Message
+
+case class Create(path: String) extends Message
+case class CreateSucceeded(host: String, port: Int) extends Message
+case class CreateFailed(message: String) extends Message
+
+case class Read(path: String) extends Message
+case class ReadSucceeded(host: String, port: Int) extends Message
+case class ReadFailed(message: String) extends Message
diff --git a/src/scala/ubiquifs/Slave.scala b/src/scala/ubiquifs/Slave.scala
new file mode 100644
index 0000000000..328b73c828
--- /dev/null
+++ b/src/scala/ubiquifs/Slave.scala
@@ -0,0 +1,141 @@
+package ubiquifs
+
+import java.io.{DataInputStream, DataOutputStream, IOException}
+import java.net.{InetAddress, Socket, ServerSocket}
+import java.util.concurrent.locks.ReentrantLock
+
+import scala.actors.Actor
+import scala.actors.Actor._
+import scala.actors.remote.RemoteActor
+import scala.actors.remote.RemoteActor._
+import scala.actors.remote.Node
+import scala.collection.mutable.{ArrayBuffer, Map, Set}
+
+class Slave(myPort: Int, master: String) extends Thread("UbiquiFS slave") {
+ val CHUNK_SIZE = 1024 * 1024
+
+ val buffers = Map[String, Buffer]()
+
+ override def run() {
+ // Create server socket
+ val socket = new ServerSocket(myPort)
+
+ // Register with master
+ val (masterHost, masterPort) = Utils.parseHostPort(master)
+ val masterActor = select(Node(masterHost, masterPort), 'UbiquiFS)
+ val myHost = InetAddress.getLocalHost.getHostName
+ val reply = masterActor !? RegisterSlave(myHost, myPort)
+ println("Registered with master, reply = " + reply)
+
+ while (true) {
+ val conn = socket.accept()
+ new ConnectionHandler(conn).start()
+ }
+ }
+
+  class ConnectionHandler(conn: Socket) extends Thread("ConnectionHandler") {
+    // Handle the request in run() so it executes on this handler thread once
+    // start() is called, rather than in the constructor on the accept loop's thread.
+    override def run() {
+      try {
+        val in = new DataInputStream(conn.getInputStream)
+        val out = new DataOutputStream(conn.getOutputStream)
+        val header = Header.read(in)
+        header.requestType match {
+          case RequestType.READ =>
+            performRead(header.path, out)
+          case RequestType.WRITE =>
+            performWrite(header.path, in)
+          case other =>
+            throw new IOException("Invalid header type " + other)
+        }
+      } catch {
+        case e: Exception => e.printStackTrace()
+      } finally {
+        conn.close()
+      }
+    }
+  }
+
+ def performWrite(path: String, in: DataInputStream) {
+ var buffer = new Buffer()
+ synchronized {
+ if (buffers.contains(path))
+ throw new IllegalArgumentException("Path " + path + " already exists")
+ buffers(path) = buffer
+ }
+ var chunk = new Array[Byte](CHUNK_SIZE)
+ var pos = 0
+ while (true) {
+ var numRead = in.read(chunk, pos, chunk.size - pos)
+ if (numRead == -1) {
+ buffer.addChunk(chunk.subArray(0, pos), true)
+ return
+ } else {
+ pos += numRead
+ if (pos == chunk.size) {
+ buffer.addChunk(chunk, false)
+ chunk = new Array[Byte](CHUNK_SIZE)
+ pos = 0
+ }
+ }
+ }
+ // TODO: launch a thread to write the data to disk, and when this finishes,
+ // remove the hard reference to buffer
+ }
+
+ def performRead(path: String, out: DataOutputStream) {
+ var buffer: Buffer = null
+ synchronized {
+ if (!buffers.contains(path))
+ throw new IllegalArgumentException("Path " + path + " doesn't exist")
+ buffer = buffers(path)
+ }
+ for (chunk <- buffer.iterator) {
+ out.write(chunk, 0, chunk.size)
+ }
+ }
+
+ class Buffer {
+ val chunks = new ArrayBuffer[Array[Byte]]
+ var finished = false
+ val mutex = new ReentrantLock
+ val chunksAvailable = mutex.newCondition()
+
+ def addChunk(chunk: Array[Byte], finish: Boolean) {
+ mutex.lock()
+ chunks += chunk
+ finished = finish
+ chunksAvailable.signalAll()
+ mutex.unlock()
+ }
+
+ def iterator = new Iterator[Array[Byte]] {
+ var index = 0
+
+ def hasNext: Boolean = {
+ mutex.lock()
+ while (index >= chunks.size && !finished)
+ chunksAvailable.await()
+ val ret = (index < chunks.size)
+ mutex.unlock()
+ return ret
+ }
+
+ def next: Array[Byte] = {
+ mutex.lock()
+ if (!hasNext)
+ throw new NoSuchElementException("End of file")
+      val ret = chunks(index) // hasNext guarantees index < chunks.size
+ index += 1
+ mutex.unlock()
+ return ret
+ }
+ }
+ }
+}
+
+object SlaveMain {
+ def main(args: Array[String]) {
+ val port = args(0).toInt
+ val master = args(1)
+ new Slave(port, master).start()
+ }
+}
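For illustration (not part of this commit): the raw socket protocol a client would speak to a slave, as handled by ConnectionHandler above. Send a Header, then either stream bytes until close (WRITE) or read bytes until end of stream (READ). Host, port and path are placeholders.

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import ubiquifs.{Header, RequestType}

object SlaveClientSketch {
  def main(args: Array[String]) {
    // Write: performWrite reads until end of stream, so closing the socket ends the file.
    val w = new Socket("localhost", 9001)
    val wOut = new DataOutputStream(w.getOutputStream)
    new Header(RequestType.WRITE, "/tmp/example").write(wOut)
    wOut.write("hello".getBytes("UTF-8"))
    wOut.flush()
    w.close()

    // Read: performRead streams back every buffered chunk, then the handler closes the socket.
    val r = new Socket("localhost", 9001)
    val rOut = new DataOutputStream(r.getOutputStream)
    new Header(RequestType.READ, "/tmp/example").write(rOut)
    rOut.flush()
    val rIn = new DataInputStream(r.getInputStream)
    var b = rIn.read()
    while (b != -1) { print(b.toChar); b = rIn.read() }
    r.close()
  }
}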
diff --git a/src/scala/ubiquifs/UbiquiFS.scala b/src/scala/ubiquifs/UbiquiFS.scala
new file mode 100644
index 0000000000..9ce0fd4f44
--- /dev/null
+++ b/src/scala/ubiquifs/UbiquiFS.scala
@@ -0,0 +1,11 @@
+package ubiquifs
+
+import java.io.{InputStream, OutputStream}
+
+class UbiquiFS(master: String) {
+ private val (masterHost, masterPort) = Utils.parseHostPort(master)
+
+ def create(path: String): OutputStream = null
+
+ def open(path: String): InputStream = null
+}
diff --git a/src/scala/ubiquifs/Utils.scala b/src/scala/ubiquifs/Utils.scala
new file mode 100644
index 0000000000..d6fd3f0181
--- /dev/null
+++ b/src/scala/ubiquifs/Utils.scala
@@ -0,0 +1,12 @@
+package ubiquifs
+
+private[ubiquifs] object Utils {
+ private val HOST_PORT_RE = "([a-zA-Z0-9.-]+):([0-9]+)".r
+
+ def parseHostPort(string: String): (String, Int) = {
+ string match {
+ case HOST_PORT_RE(host, port) => (host, port.toInt)
+ case _ => throw new IllegalArgumentException(string)
+ }
+ }
+}
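For illustration (not part of this commit): parseHostPort splits the "host:port" strings passed on the command line, such as the master address given to Slave. Because Utils is private[ubiquifs], this sketch lives inside that package; the address is a placeholder.

package ubiquifs

object ParseHostPortSketch {
  def main(args: Array[String]) {
    val (host, port) = Utils.parseHostPort("master.example.com:9000")
    println(host + " " + port)   // prints: master.example.com 9000
    // A string that does not match host:port throws IllegalArgumentException.
  }
}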
diff --git a/src/test/spark/ParallelArraySplitSuite.scala b/src/test/spark/ParallelArraySplitSuite.scala
new file mode 100644
index 0000000000..a1787bd6dd
--- /dev/null
+++ b/src/test/spark/ParallelArraySplitSuite.scala
@@ -0,0 +1,161 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+class ParallelArraySplitSuite extends FunSuite with Checkers {
+ test("one element per slice") {
+ val data = Array(1, 2, 3)
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === "1")
+ assert(slices(1).mkString(",") === "2")
+ assert(slices(2).mkString(",") === "3")
+ }
+
+ test("one slice") {
+ val data = Array(1, 2, 3)
+ val slices = ParallelArray.slice(data, 1)
+ assert(slices.size === 1)
+ assert(slices(0).mkString(",") === "1,2,3")
+ }
+
+ test("equal slices") {
+ val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === "1,2,3")
+ assert(slices(1).mkString(",") === "4,5,6")
+ assert(slices(2).mkString(",") === "7,8,9")
+ }
+
+ test("non-equal slices") {
+ val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === "1,2,3")
+ assert(slices(1).mkString(",") === "4,5,6")
+ assert(slices(2).mkString(",") === "7,8,9,10")
+ }
+
+ test("splitting exclusive range") {
+ val data = 0 until 100
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === (0 to 32).mkString(","))
+ assert(slices(1).mkString(",") === (33 to 65).mkString(","))
+ assert(slices(2).mkString(",") === (66 to 99).mkString(","))
+ }
+
+ test("splitting inclusive range") {
+ val data = 0 to 100
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices(0).mkString(",") === (0 to 32).mkString(","))
+ assert(slices(1).mkString(",") === (33 to 66).mkString(","))
+ assert(slices(2).mkString(",") === (67 to 100).mkString(","))
+ }
+
+ test("empty data") {
+ val data = new Array[Int](0)
+ val slices = ParallelArray.slice(data, 5)
+ assert(slices.size === 5)
+ for (slice <- slices) assert(slice.size === 0)
+ }
+
+ test("zero slices") {
+ val data = Array(1, 2, 3)
+ intercept[IllegalArgumentException] { ParallelArray.slice(data, 0) }
+ }
+
+ test("negative number of slices") {
+ val data = Array(1, 2, 3)
+ intercept[IllegalArgumentException] { ParallelArray.slice(data, -5) }
+ }
+
+ test("exclusive ranges sliced into ranges") {
+ val data = 1 until 100
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 99)
+ assert(slices.forall(_.isInstanceOf[SerializableRange]))
+ }
+
+ test("inclusive ranges sliced into ranges") {
+ val data = 1 to 100
+ val slices = ParallelArray.slice(data, 3)
+ assert(slices.size === 3)
+ assert(slices.map(_.size).reduceLeft(_+_) === 100)
+ assert(slices.forall(_.isInstanceOf[SerializableRange]))
+ }
+
+ test("large ranges don't overflow") {
+ val N = 100 * 1000 * 1000
+ val data = 0 until N
+ val slices = ParallelArray.slice(data, 40)
+ assert(slices.size === 40)
+ for (i <- 0 until 40) {
+ assert(slices(i).isInstanceOf[SerializableRange])
+ val range = slices(i).asInstanceOf[SerializableRange]
+ assert(range.start === i * (N / 40), "slice " + i + " start")
+ assert(range.end === (i+1) * (N / 40), "slice " + i + " end")
+ assert(range.step === 1, "slice " + i + " step")
+ }
+ }
+
+ test("random array tests") {
+ val gen = for {
+ d <- arbitrary[List[Int]]
+ n <- Gen.choose(1, 100)
+ } yield (d, n)
+ val prop = forAll(gen) {
+ (tuple: (List[Int], Int)) =>
+ val d = tuple._1
+ val n = tuple._2
+ val slices = ParallelArray.slice(d, n)
+ ("n slices" |: slices.size == n) &&
+ ("concat to d" |: Array.concat(slices: _*).mkString(",") == d.mkString(",")) &&
+ ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
+ }
+ check(prop)
+ }
+
+ test("random exclusive range tests") {
+ val gen = for {
+ a <- Gen.choose(-100, 100)
+ b <- Gen.choose(-100, 100)
+ step <- Gen.choose(-5, 5) suchThat (_ != 0)
+ n <- Gen.choose(1, 100)
+ } yield (a until b by step, n)
+ val prop = forAll(gen) {
+ case (d: Range, n: Int) =>
+ val slices = ParallelArray.slice(d, n)
+ ("n slices" |: slices.size == n) &&
+ ("all ranges" |: slices.forall(_.isInstanceOf[SerializableRange])) &&
+ ("concat to d" |: Array.concat(slices: _*).mkString(",") == d.mkString(",")) &&
+ ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
+ }
+ check(prop)
+ }
+
+ test("random inclusive range tests") {
+ val gen = for {
+ a <- Gen.choose(-100, 100)
+ b <- Gen.choose(-100, 100)
+ step <- Gen.choose(-5, 5) suchThat (_ != 0)
+ n <- Gen.choose(1, 100)
+ } yield (a to b by step, n)
+ val prop = forAll(gen) {
+ case (d: Range, n: Int) =>
+ val slices = ParallelArray.slice(d, n)
+ ("n slices" |: slices.size == n) &&
+ ("all ranges" |: slices.forall(_.isInstanceOf[SerializableRange])) &&
+ ("concat to d" |: Array.concat(slices: _*).mkString(",") == d.mkString(",")) &&
+ ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
+ }
+ check(prop)
+ }
+}
diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala
new file mode 100644
index 0000000000..d71fe20a94
--- /dev/null
+++ b/src/test/spark/repl/ReplSuite.scala
@@ -0,0 +1,124 @@
+package spark.repl
+
+import java.io._
+
+import org.scalatest.FunSuite
+
+class ReplSuite extends FunSuite {
+ def runInterpreter(master: String, input: String): String = {
+ val in = new BufferedReader(new StringReader(input + "\n"))
+ val out = new StringWriter()
+ val interp = new SparkInterpreterLoop(in, new PrintWriter(out), master)
+ spark.repl.Main.interp = interp
+ interp.main(new Array[String](0))
+ spark.repl.Main.interp = null
+ return out.toString
+ }
+
+ def assertContains(message: String, output: String) {
+ assert(output contains message,
+ "Interpreter output did not contain '" + message + "':\n" + output)
+ }
+
+ def assertDoesNotContain(message: String, output: String) {
+ assert(!(output contains message),
+ "Interpreter output contained '" + message + "':\n" + output)
+ }
+
+ test ("simple foreach with accumulator") {
+ val output = runInterpreter("local", """
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 10).foreach(x => accum += x)
+ accum.value
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res1: Int = 55", output)
+ }
+
+ test ("external vars") {
+ val output = runInterpreter("local", """
+ var v = 7
+ sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res2: Int = 100", output)
+ }
+
+ test ("external classes") {
+ val output = runInterpreter("local", """
+ class C {
+ def foo = 5
+ }
+ sc.parallelize(1 to 10).map(x => (new C).foo).toArray.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 50", output)
+ }
+
+ test ("external functions") {
+ val output = runInterpreter("local", """
+ def double(x: Int) = x + x
+ sc.parallelize(1 to 10).map(x => double(x)).toArray.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 110", output)
+ }
+
+ test ("external functions that access vars") {
+ val output = runInterpreter("local", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res2: Int = 100", output)
+ }
+
+ test ("cached vars") {
+ // Test that the value that a cached var had when it was created is used,
+ // even if that cached var is then modified in the driver program
+ val output = runInterpreter("local", """
+ var array = new Array[Int](5)
+ val cachedArray = sc.cache(array)
+ sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
+ }
+
+ test ("running on Nexus") {
+ val output = runInterpreter("localquiet", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+ var array = new Array[Int](5)
+ val cachedArray = sc.cache(array)
+ sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res2: Int = 100", output)
+ assertContains("res3: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res5: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
+}