path: root/core/src/main/scala
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala  73
-rw-r--r--  core/src/main/scala/spark/BoundedMemoryCache.scala  69
-rw-r--r--  core/src/main/scala/spark/Broadcast.scala  799
-rw-r--r--  core/src/main/scala/spark/Cache.scala  63
-rw-r--r--  core/src/main/scala/spark/ClosureCleaner.scala  159
-rw-r--r--  core/src/main/scala/spark/DfsShuffle.scala  120
-rw-r--r--  core/src/main/scala/spark/Executor.scala  116
-rw-r--r--  core/src/main/scala/spark/HadoopFile.scala  118
-rw-r--r--  core/src/main/scala/spark/HttpServer.scala  67
-rw-r--r--  core/src/main/scala/spark/Job.scala  18
-rw-r--r--  core/src/main/scala/spark/LocalFileShuffle.scala  171
-rw-r--r--  core/src/main/scala/spark/LocalScheduler.scala  72
-rw-r--r--  core/src/main/scala/spark/Logging.scala  49
-rw-r--r--  core/src/main/scala/spark/MesosScheduler.scala  294
-rw-r--r--  core/src/main/scala/spark/NumberedSplitRDD.scala  42
-rw-r--r--  core/src/main/scala/spark/ParallelArray.scala  76
-rw-r--r--  core/src/main/scala/spark/RDD.scala  418
-rw-r--r--  core/src/main/scala/spark/Scheduler.scala  10
-rw-r--r--  core/src/main/scala/spark/SerializableWritable.scala  26
-rw-r--r--  core/src/main/scala/spark/Shuffle.scala  15
-rw-r--r--  core/src/main/scala/spark/SimpleJob.scala  272
-rw-r--r--  core/src/main/scala/spark/SizeEstimator.scala  160
-rw-r--r--  core/src/main/scala/spark/SoftReferenceCache.scala  13
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala  175
-rw-r--r--  core/src/main/scala/spark/SparkException.scala  3
-rw-r--r--  core/src/main/scala/spark/Split.scala  13
-rw-r--r--  core/src/main/scala/spark/Task.scala  16
-rw-r--r--  core/src/main/scala/spark/TaskResult.scala  9
-rw-r--r--  core/src/main/scala/spark/Utils.scala  127
-rw-r--r--  core/src/main/scala/spark/WeakReferenceCache.scala  14
-rw-r--r--  core/src/main/scala/spark/repl/ExecutorClassLoader.scala  108
-rw-r--r--  core/src/main/scala/spark/repl/Main.scala  16
-rw-r--r--  core/src/main/scala/spark/repl/SparkCompletion.scala  353
-rw-r--r--  core/src/main/scala/spark/repl/SparkCompletionOutput.scala  92
-rw-r--r--  core/src/main/scala/spark/repl/SparkInteractiveReader.scala  60
-rw-r--r--  core/src/main/scala/spark/repl/SparkInterpreter.scala  1395
-rw-r--r--  core/src/main/scala/spark/repl/SparkInterpreterLoop.scala  659
-rw-r--r--  core/src/main/scala/spark/repl/SparkInterpreterSettings.scala  112
-rw-r--r--  core/src/main/scala/spark/repl/SparkJLineReader.scala  38
-rw-r--r--  core/src/main/scala/spark/repl/SparkSimpleReader.scala  33
40 files changed, 6443 insertions, 0 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
new file mode 100644
index 0000000000..ee93d3c85c
--- /dev/null
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -0,0 +1,73 @@
+package spark
+
+import java.io._
+
+import scala.collection.mutable.Map
+
+@serializable class Accumulator[T](
+ @transient initialValue: T, param: AccumulatorParam[T])
+{
+ val id = Accumulators.newId
+ @transient var value_ = initialValue // Current value on master
+ val zero = param.zero(initialValue) // Zero value to be passed to workers
+ var deserialized = false
+
+ Accumulators.register(this)
+
+ def += (term: T) { value_ = param.addInPlace(value_, term) }
+ def value = this.value_
+ def value_= (t: T) {
+ if (!deserialized) value_ = t
+ else throw new UnsupportedOperationException("Can't use value_= in task")
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject
+ value_ = zero
+ deserialized = true
+ Accumulators.register(this)
+ }
+
+ override def toString = value_.toString
+}
+
+@serializable trait AccumulatorParam[T] {
+ def addInPlace(t1: T, t2: T): T
+ def zero(initialValue: T): T
+}
+
+// TODO: The multi-thread support in accumulators is kind of lame; check
+// if there's a more intuitive way of doing it right
+private object Accumulators
+{
+ // TODO: Use soft references? => need to make readObject work properly then
+ val accums = Map[(Thread, Long), Accumulator[_]]()
+ var lastId: Long = 0
+
+ def newId: Long = synchronized { lastId += 1; return lastId }
+
+ def register(a: Accumulator[_]): Unit = synchronized {
+ accums((currentThread, a.id)) = a
+ }
+
+ def clear: Unit = synchronized {
+ accums.retain((key, accum) => key._1 != currentThread)
+ }
+
+ def values: Map[Long, Any] = synchronized {
+ val ret = Map[Long, Any]()
+ for(((thread, id), accum) <- accums if thread == currentThread)
+ ret(id) = accum.value
+ return ret
+ }
+
+ def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
+ for ((id, value) <- values) {
+ if (accums.contains((thread, id))) {
+ val accum = accums((thread, id))
+ accum.asInstanceOf[Accumulator[Any]] += value
+ }
+ }
+ }
+}
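A sketch, outside the diff, of how the accumulator API above is meant to be used; IntAccumulatorParam is a hypothetical param object for summing Ints:

object IntAccumulatorParam extends AccumulatorParam[Int] {
  def addInPlace(t1: Int, t2: Int): Int = t1 + t2
  def zero(initialValue: Int): Int = 0
}

val sum = new Accumulator[Int](0, IntAccumulatorParam)
sum += 5
sum += 7
println(sum.value)  // 12 on the master copy; on workers, += goes into a zeroed,
                    // deserialized copy that is collected via Accumulators.values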
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
new file mode 100644
index 0000000000..19d9bebfe5
--- /dev/null
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -0,0 +1,69 @@
+package spark
+
+import java.util.LinkedHashMap
+
+/**
+ * An implementation of Cache that estimates the sizes of its entries and
+ * attempts to limit its total memory usage to a fraction of the JVM heap.
+ * Objects' sizes are estimated using SizeEstimator, which has limitations;
+ * most notably, we will overestimate total memory used if some cache
+ * entries have pointers to a shared object. Nonetheless, this Cache should
+ * work well when most of the space is used by arrays of primitives or of
+ * simple classes.
+ */
+class BoundedMemoryCache extends Cache with Logging {
+ private val maxBytes: Long = getMaxBytes()
+ logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
+
+ private var currentBytes = 0L
+ private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true)
+
+ // An entry in our map; stores a cached object and its size in bytes
+ class Entry(val value: Any, val size: Long) {}
+
+ override def get(key: Any): Any = {
+ synchronized {
+ val entry = map.get(key)
+ if (entry != null) entry.value else null
+ }
+ }
+
+ override def put(key: Any, value: Any) {
+ logInfo("Asked to add key " + key)
+ val startTime = System.currentTimeMillis
+ val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Estimated size for key %s is %d".format(key, size))
+ logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
+ synchronized {
+ ensureFreeSpace(size)
+ logInfo("Adding key " + key)
+ map.put(key, new Entry(value, size))
+ currentBytes += size
+ logInfo("Number of entries is now " + map.size)
+ }
+ }
+
+ private def getMaxBytes(): Long = {
+ val memoryFractionToUse = System.getProperty(
+ "spark.boundedMemoryCache.memoryFraction", "0.75").toDouble
+ (Runtime.getRuntime.totalMemory * memoryFractionToUse).toLong
+ }
+
+ /**
+ * Remove least recently used entries from the map until at least space
+ * bytes are free. Assumes that a lock is held on the BoundedMemoryCache.
+ */
+ private def ensureFreeSpace(space: Long) {
+ logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format(
+ space, currentBytes, maxBytes))
+ val iter = map.entrySet.iterator
+ while (maxBytes - currentBytes < space && iter.hasNext) {
+ val mapEntry = iter.next()
+ logInfo("Dropping key %s of size %d to make space".format(
+ mapEntry.getKey, mapEntry.getValue.size))
+ currentBytes -= mapEntry.getValue.size
+ iter.remove()
+ }
+ }
+}
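The LinkedHashMap above is constructed with accessOrder = true, so iteration visits the least recently accessed entry first; ensureFreeSpace relies on that to evict in LRU order. A tiny sketch of the behavior it depends on:

import java.util.LinkedHashMap

val m = new LinkedHashMap[String, Int](32, 0.75f, true) // accessOrder = true
m.put("a", 1); m.put("b", 2); m.put("c", 3)
m.get("a")                        // touching "a" moves it to the back of the order
println(m.keySet.iterator.next()) // "b": now the least recently used entry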
diff --git a/core/src/main/scala/spark/Broadcast.scala b/core/src/main/scala/spark/Broadcast.scala
new file mode 100644
index 0000000000..5089dca82e
--- /dev/null
+++ b/core/src/main/scala/spark/Broadcast.scala
@@ -0,0 +1,799 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{UUID, PriorityQueue, Comparator}
+
+import java.util.concurrent.{Executors, ExecutorService}
+
+import scala.actors.Actor
+import scala.actors.Actor._
+
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+@serializable
+trait BroadcastRecipe {
+ val uuid = UUID.randomUUID
+
+ // We cannot have an abstract readObject here due to some weird issues with
+ // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
+ def sendBroadcast: Unit
+
+ override def toString = "spark.Broadcast(" + uuid + ")"
+}
+
+// TODO: Right now, no parallelization between multiple broadcasts
+@serializable
+class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean)
+extends BroadcastRecipe with Logging {
+
+ def value = value_
+
+ BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) }
+
+ if (!local) { sendBroadcast }
+
+ def sendBroadcast () {
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = blockifyObject (value_, BroadcastCS.blockSize)
+ // TODO: Even though this part is not in use now, there is a problem in the
+ // following statement: we shouldn't use a constant port and hostAddress anymore.
+ // val masterSource =
+ // new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort,
+ // variableInfo.totalBlocks, variableInfo.totalBytes, 0)
+ // variableInfo.pqOfSources.add (masterSource)
+
+ BroadcastCS.synchronized {
+ // BroadcastCS.valueInfos.put (uuid, variableInfo)
+
+ // TODO: Not using variableInfo in current implementation. Manually
+ // setting all the variables inside BroadcastCS object
+
+ BroadcastCS.initializeVariable (variableInfo)
+ }
+
+ // Now store a persistent copy in HDFS, just in case
+ val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
+ out.writeObject (value_)
+ out.close
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject (in: ObjectInputStream) {
+ in.defaultReadObject
+ BroadcastCS.synchronized {
+ val cachedVal = BroadcastCS.values.get (uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ // Only a single worker (the first one) on a given node can ever get
+ // here. The rest will always find the value already cached.
+ val start = System.nanoTime
+
+ val retByteArray = BroadcastCS.receiveBroadcast (uuid)
+ // If that does not succeed, fall back to the HDFS copy
+ if (retByteArray != null) {
+ value_ = byteArrayToObject[T] (retByteArray)
+ BroadcastCS.values.put (uuid, value_)
+ // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize)
+ // BroadcastCS.valueInfos.put (uuid, variableInfo)
+ } else {
+ val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ BroadcastCH.values.put(uuid, value_)
+ fileIn.close
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+
+ private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream (baos)
+ oos.writeObject (obj)
+ oos.close
+ baos.close
+ val byteArray = baos.toByteArray
+ val bais = new ByteArrayInputStream (byteArray)
+
+ var blockNum = (byteArray.length / blockSize)
+ if (byteArray.length % blockSize != 0)
+ blockNum += 1
+
+ var retVal = new Array[BroadcastBlock] (blockNum)
+ var blockID = 0
+
+ // TODO: What happens if byteArray.length == 0 (=> blockNum == 0)?
+ for (i <- 0 until (byteArray.length, blockSize)) {
+ val thisBlockSize = Math.min (blockSize, byteArray.length - i)
+ var tempByteArray = new Array[Byte] (thisBlockSize)
+ val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
+
+ retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close
+
+ var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
+ variableInfo.hasBlocks = blockNum
+
+ return variableInfo
+ }
+
+ private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
+ val retVal = in.readObject.asInstanceOf[A]
+ in.close
+ return retVal
+ }
+
+ private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = {
+ val bOut = new ByteArrayOutputStream
+ val out = new ObjectOutputStream (bOut)
+ out.writeObject (obj)
+ out.close
+ bOut.close
+ return bOut
+ }
+}
+
+@serializable
+class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean)
+extends BroadcastRecipe with Logging {
+
+ def value = value_
+
+ BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) }
+
+ if (!local) { sendBroadcast }
+
+ def sendBroadcast () {
+ val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
+ out.writeObject (value_)
+ out.close
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject
+ BroadcastCH.synchronized {
+ val cachedVal = BroadcastCH.values.get(uuid)
+ if (cachedVal != null) {
+ value_ = cachedVal.asInstanceOf[T]
+ } else {
+ val start = System.nanoTime
+
+ val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
+ value_ = fileIn.readObject.asInstanceOf[T]
+ BroadcastCH.values.put(uuid, value_)
+ fileIn.close
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
+ }
+ }
+ }
+}
+
+@serializable
+case class SourceInfo (val hostAddress: String, val listenPort: Int,
+ val totalBlocks: Int, val totalBytes: Int, val replicaID: Int)
+extends Comparable[SourceInfo]{
+
+ var currentLeechers = 0
+ var receptionFailed = false
+
+ def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
+}
+
+@serializable
+case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
+
+@serializable
+case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock],
+ val totalBlocks: Int, val totalBytes: Int) {
+
+ @transient var hasBlocks = 0
+
+ val listenPortLock = new AnyRef
+ val totalBlocksLock = new AnyRef
+ val hasBlocksLock = new AnyRef
+
+ @transient var pqOfSources = new PriorityQueue[SourceInfo]
+}
+
+private object Broadcast {
+ private var initialized = false
+
+ // Will be called by SparkContext or Executor before using Broadcast
+ // Calls all other initializers here
+ def initialize (isMaster: Boolean) {
+ synchronized {
+ if (!initialized) {
+ // Initialization for CentralizedHDFSBroadcast
+ BroadcastCH.initialize
+ // Initialization for ChainedStreamingBroadcast
+ // BroadcastCS.initialize (isMaster)
+
+ initialized = true
+ }
+ }
+ }
+}
+
+private object BroadcastCS extends Logging {
+ val values = Cache.newKeySpace()
+
+ // private var valueToPort = Map[UUID, Int] ()
+
+ private var initialized = false
+ private var isMaster_ = false
+
+ private var masterHostAddress_ = "127.0.0.1"
+ private var masterListenPort_ : Int = 11111
+ private var blockSize_ : Int = 512 * 1024
+ private var maxRetryCount_ : Int = 2
+ private var serverSocketTimout_ : Int = 50000
+ private var dualMode_ : Boolean = false
+
+ private val hostAddress = InetAddress.getLocalHost.getHostAddress
+ private var listenPort = -1
+
+ var arrayOfBlocks: Array[BroadcastBlock] = null
+ var totalBytes = -1
+ var totalBlocks = -1
+ var hasBlocks = 0
+
+ val listenPortLock = new Object
+ val totalBlocksLock = new Object
+ val hasBlocksLock = new Object
+
+ var pqOfSources = new PriorityQueue[SourceInfo]
+
+ private var serveMR: ServeMultipleRequests = null
+ private var guideMR: GuideMultipleRequests = null
+
+ def initialize (isMaster__ : Boolean) {
+ synchronized {
+ if (!initialized) {
+ masterHostAddress_ =
+ System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1")
+ masterListenPort_ =
+ System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt
+ blockSize_ =
+ System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
+ maxRetryCount_ =
+ System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt
+ serverSocketTimout_ =
+ System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt
+ dualMode_ =
+ System.getProperty ("spark.broadcast.dualMode", "false").toBoolean
+
+ isMaster_ = isMaster__
+
+ if (isMaster) {
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon (true)
+ guideMR.start
+ logInfo("GuideMultipleRequests started")
+ }
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon (true)
+ serveMR.start
+ logInfo("ServeMultipleRequests started")
+
+ logInfo("BroadcastCS object has been initialized")
+
+ initialized = true
+ }
+ }
+ }
+
+ // TODO: This should change in a future implementation.
+ // Called from the Master constructor to set up state for the particular
+ // variable that is being broadcast
+ def initializeVariable (variableInfo: VariableInfo) {
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks = variableInfo.totalBlocks
+
+ // listenPort should already be valid
+ assert (listenPort != -1)
+
+ pqOfSources = new PriorityQueue[SourceInfo]
+ val masterSource_0 =
+ new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0)
+ BroadcastCS.pqOfSources.add (masterSource_0)
+ // Add one more time to have two replicas of any seeds in the PQ
+ if (BroadcastCS.dualMode) {
+ val masterSource_1 =
+ new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1)
+ BroadcastCS.pqOfSources.add (masterSource_1)
+ }
+ }
+
+ def masterHostAddress = masterHostAddress_
+ def masterListenPort = masterListenPort_
+ def blockSize = blockSize_
+ def maxRetryCount = maxRetryCount_
+ def serverSocketTimout = serverSocketTimout_
+ def dualMode = dualMode_
+
+ def isMaster = isMaster_
+
+ def receiveBroadcast (variableUUID: UUID): Array[Byte] = {
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ // NO need to wait; ServeMultipleRequests is created much further ahead
+ while (listenPort == -1) {
+ listenPortLock.synchronized {
+ listenPortLock.wait
+ }
+ }
+
+ // Connect and receive broadcast from the specified source, retrying the
+ // specified number of times in case of failures
+ var retriesLeft = BroadcastCS.maxRetryCount
+ var retByteArray: Array[Byte] = null
+ do {
+ // Connect to Master and send this worker's Information
+ val clientSocketToMaster =
+ new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort)
+ logInfo("Connected to Master's guiding object")
+ // TODO: Guiding object connection is reusable
+ val oisMaster =
+ new ObjectInputStream (clientSocketToMaster.getInputStream)
+ val oosMaster =
+ new ObjectOutputStream (clientSocketToMaster.getOutputStream)
+
+ oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0))
+ oosMaster.flush
+
+ // Receive source information from Master
+ var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+ totalBlocks = sourceInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
+ totalBlocksLock.synchronized {
+ totalBlocksLock.notifyAll
+ }
+ totalBytes = sourceInfo.totalBytes
+ logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)
+
+ retByteArray = receiveSingleTransmission (sourceInfo)
+
+ logInfo("I got this from receiveSingleTransmission: " + retByteArray)
+
+ // TODO: Update sourceInfo to add error notifications for Master
+ if (retByteArray == null) { sourceInfo.receptionFailed = true }
+
+ // TODO: Supposed to update values here, but we don't support advanced
+ // statistics right now. Master can handle leecherCount by itself.
+
+ // Send back statistics to the Master
+ oosMaster.writeObject (sourceInfo)
+
+ oisMaster.close
+ oosMaster.close
+ clientSocketToMaster.close
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && retByteArray == null)
+
+ return retByteArray
+ }
+
+ // Tries to receive the broadcast from the given source and returns the
+ // received byte array (null on failure). This might be called multiple
+ // times to retry a defined number of times.
+ private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = {
+ var clientSocketToSource: Socket = null
+ var oisSource: ObjectInputStream = null
+ var oosSource: ObjectOutputStream = null
+
+ var retByteArray:Array[Byte] = null
+
+ try {
+ // Connect to the source to get the object itself
+ clientSocketToSource =
+ new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource =
+ new ObjectOutputStream (clientSocketToSource.getOutputStream)
+ oisSource =
+ new ObjectInputStream (clientSocketToSource.getInputStream)
+
+ logInfo("Inside receiveSingleTransmission")
+ logInfo("totalBlocks: " + totalBlocks + " " + "hasBlocks: " + hasBlocks)
+ retByteArray = new Array[Byte] (totalBytes)
+ for (i <- 0 until totalBlocks) {
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ System.arraycopy (bcBlock.byteArray, 0, retByteArray,
+ i * BroadcastCS.blockSize, bcBlock.byteArray.length)
+ arrayOfBlocks(hasBlocks) = bcBlock
+ hasBlocks += 1
+ hasBlocksLock.synchronized {
+ hasBlocksLock.notifyAll
+ }
+ logInfo("Received block: " + i + " " + bcBlock)
+ }
+ assert (hasBlocks == totalBlocks)
+ logInfo("After the receive loop")
+ } catch {
+ case e: Exception => {
+ retByteArray = null
+ logInfo("receiveSingleTransmission had a " + e)
+ }
+ } finally {
+ if (oisSource != null) { oisSource.close }
+ if (oosSource != null) { oosSource.close }
+ if (clientSocketToSource != null) { clientSocketToSource.close }
+ }
+
+ return retByteArray
+ }
+
+// class TrackMultipleValues extends Thread with Logging {
+// override def run = {
+// var threadPool = Executors.newCachedThreadPool
+// var serverSocket: ServerSocket = null
+//
+// serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
+// logInfo("TrackMultipleVariables" + serverSocket + " " + listenPort)
+//
+// var keepAccepting = true
+// try {
+// while (keepAccepting) {
+// var clientSocket: Socket = null
+// try {
+// serverSocket.setSoTimeout (serverSocketTimout)
+// clientSocket = serverSocket.accept
+// } catch {
+// case e: Exception => {
+// logInfo("TrackMultipleValues Timeout. Stopping listening...")
+// keepAccepting = false
+// }
+// }
+// logInfo("TrackMultipleValues:Got new request:" + clientSocket)
+// if (clientSocket != null) {
+// try {
+// threadPool.execute (new Runnable {
+// def run = {
+// val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+// val ois = new ObjectInputStream (clientSocket.getInputStream)
+// try {
+// val variableUUID = ois.readObject.asInstanceOf[UUID]
+// var contactPort = 0
+// // TODO: Add logic and data structures to find out UUID->port
+// // mapping. 0 = missed the broadcast, read from HDFS; <0 =
+// // Haven't started yet, wait & retry; >0 = Read from this port
+// oos.writeObject (contactPort)
+// } catch {
+// case e: Exception => { }
+// } finally {
+// ois.close
+// oos.close
+// clientSocket.close
+// }
+// }
+// })
+// } catch {
+// // In failure, close the socket here; else, the thread will close it
+// case ioe: IOException => clientSocket.close
+// }
+// }
+// }
+// } finally {
+// serverSocket.close
+// }
+// }
+// }
+//
+// class TrackSingleValue {
+//
+// }
+
+// public static ExecutorService newCachedThreadPool() {
+// return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60L, TimeUnit.SECONDS,
+// new SynchronousQueue<Runnable>());
+// }
+
+
+ class GuideMultipleRequests extends Thread with Logging {
+ override def run = {
+ var threadPool = Executors.newCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
+ // listenPort = BroadcastCS.masterListenPort
+ logInfo("GuideMultipleRequests" + serverSocket + " " + listenPort)
+
+ var keepAccepting = true
+ try {
+ while (keepAccepting) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (serverSocketTimout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("GuideMultipleRequests Timeout. Stopping listening...")
+ keepAccepting = false
+ }
+ }
+ if (clientSocket != null) {
+ logInfo("Guide:Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute (new GuideSingleRequest (clientSocket))
+ } catch {
+ // In failure, close the socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ serverSocket.close
+ }
+ }
+
+ class GuideSingleRequest (val clientSocket: Socket)
+ extends Runnable with Logging {
+ private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+ private var selectedSourceInfo: SourceInfo = null
+ private var thisWorkerInfo:SourceInfo = null
+
+ def run = {
+ try {
+ logInfo("new GuideSingleRequest is running")
+ // The connecting worker sends the hostAddress and listenPort it will be
+ // listening on. ReplicaID is 0 and the other fields are invalid (-1)
+ var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Select a suitable source and send it back to the worker
+ selectedSourceInfo = selectSuitableSource (sourceInfo)
+ logInfo("Sending selectedSourceInfo:" + selectedSourceInfo)
+ oos.writeObject (selectedSourceInfo)
+ oos.flush
+
+ // Add this new source (assuming it manages to finish) to the PQ of sources
+ thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress,
+ sourceInfo.listenPort, totalBlocks, totalBytes, 0)
+ logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo)
+ pqOfSources.synchronized {
+ pqOfSources.add (thisWorkerInfo)
+ }
+
+ // Wait till the whole transfer is done. Then receive and update source
+ // statistics in pqOfSources
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ pqOfSources.synchronized {
+ // This should work since SourceInfo is a case class
+ assert (pqOfSources.contains (selectedSourceInfo))
+
+ // Remove first
+ pqOfSources.remove (selectedSourceInfo)
+ // TODO: Removing a source based on just one failure notification!
+ // Update leecher count and put it back in IF reception succeeded
+ if (!sourceInfo.receptionFailed) {
+ selectedSourceInfo.currentLeechers -= 1
+ pqOfSources.add (selectedSourceInfo)
+
+ // No need to find and update thisWorkerInfo, but add its replica
+ if (BroadcastCS.dualMode) {
+ pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress,
+ thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1))
+ }
+ }
+ }
+ } catch {
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ // Assume the exception was caused by a receiver worker failure
+ // Remove failed worker from pqOfSources and update leecherCount of
+ // corresponding source worker
+ pqOfSources.synchronized {
+ if (selectedSourceInfo != null) {
+ // Remove first
+ pqOfSources.remove (selectedSourceInfo)
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ pqOfSources.add (selectedSourceInfo)
+ }
+
+ // Remove thisWorkerInfo
+ if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) }
+ }
+ }
+ } finally {
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+
+ // TODO: If a worker fails to get the broadcast variable from a source and
+ // comes back to the Master, this function might choose the worker itself as
+ // a source, creating a dependency cycle (the worker was put into pqOfSources
+ // as a streaming source when it first arrived). The length of this cycle can
+ // be arbitrarily long.
+ private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+ // Select one with the lowest number of leechers
+ pqOfSources.synchronized {
+ // poll removes the head of the PQ (the source with the fewest leechers)
+ var selectedSource = pqOfSources.poll
+ assert (selectedSource != null)
+ // Update leecher count
+ selectedSource.currentLeechers += 1
+ // Add it back and then return
+ pqOfSources.add (selectedSource)
+ return selectedSource
+ }
+ }
+ }
+ }
+
+ class ServeMultipleRequests extends Thread with Logging {
+ override def run = {
+ var threadPool = Executors.newCachedThreadPool
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket (0)
+ listenPort = serverSocket.getLocalPort
+ logInfo("ServeMultipleRequests" + serverSocket + " " + listenPort)
+
+ listenPortLock.synchronized {
+ listenPortLock.notifyAll
+ }
+
+ var keepAccepting = true
+ try {
+ while (keepAccepting) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout (serverSocketTimout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ logInfo("ServeMultipleRequests Timeout. Stopping listening...")
+ keepAccepting = false
+ }
+ }
+ if (clientSocket != null) {
+ logInfo("Serve:Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute (new ServeSingleRequest (clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close
+ }
+ }
+ }
+ } finally {
+ serverSocket.close
+ }
+ }
+
+ class ServeSingleRequest (val clientSocket: Socket)
+ extends Runnable with Logging {
+ private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+ private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+ def run = {
+ try {
+ logInfo("new ServeSingleRequest is running")
+ sendObject
+ } catch {
+ // TODO: Need to add better exception handling here
+ // If something went wrong, e.g., the worker at the other end died etc.
+ // then close everything up
+ case e: Exception => {
+ logInfo("ServeSingleRequest had a " + e)
+ }
+ } finally {
+ logInfo("ServeSingleRequest is closing streams and sockets")
+ ois.close
+ oos.close
+ clientSocket.close
+ }
+ }
+
+ private def sendObject = {
+ // Wait till receiving the SourceInfo from Master
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized {
+ totalBlocksLock.wait
+ }
+ }
+
+ for (i <- 0 until totalBlocks) {
+ while (i == hasBlocks) {
+ hasBlocksLock.synchronized {
+ hasBlocksLock.wait
+ }
+ }
+ try {
+ oos.writeObject (arrayOfBlocks(i))
+ oos.flush
+ } catch {
+ case e: Exception => { }
+ }
+ logInfo("Send block: " + i + " " + arrayOfBlocks(i))
+ }
+ }
+ }
+ }
+}
+
+private object BroadcastCH extends Logging {
+ val values = Cache.newKeySpace()
+
+ private var initialized = false
+
+ private var fileSystem: FileSystem = null
+ private var workDir: String = null
+ private var compress: Boolean = false
+ private var bufferSize: Int = 65536
+
+ def initialize () {
+ synchronized {
+ if (!initialized) {
+ bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val dfs = System.getProperty("spark.dfs", "file:///")
+ if (!dfs.startsWith("file://")) {
+ val conf = new Configuration()
+ conf.setInt("io.file.buffer.size", bufferSize)
+ val rep = System.getProperty("spark.dfs.replication", "3").toInt
+ conf.setInt("dfs.replication", rep)
+ fileSystem = FileSystem.get(new URI(dfs), conf)
+ }
+ workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+ compress = System.getProperty("spark.compress", "false").toBoolean
+
+ initialized = true
+ }
+ }
+ }
+
+ private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid)
+
+ def openFileForReading(uuid: UUID): InputStream = {
+ val fileStream = if (fileSystem != null) {
+ fileSystem.open(getPath(uuid))
+ } else {
+ // Local filesystem
+ new FileInputStream(getPath(uuid).toString)
+ }
+ if (compress)
+ new LZFInputStream(fileStream) // LZF stream does its own buffering
+ else if (fileSystem == null)
+ new BufferedInputStream(fileStream, bufferSize)
+ else
+ fileStream // Hadoop streams do their own buffering
+ }
+
+ def openFileForWriting(uuid: UUID): OutputStream = {
+ val fileStream = if (fileSystem != null) {
+ fileSystem.create(getPath(uuid))
+ } else {
+ // Local filesystem
+ new FileOutputStream(getPath(uuid).toString)
+ }
+ if (compress)
+ new LZFOutputStream(fileStream) // LZF stream does its own buffering
+ else if (fileSystem == null)
+ new BufferedOutputStream(fileStream, bufferSize)
+ else
+ fileStream // Hadoop streams do their own buffering
+ }
+}
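A sketch of the HDFS-backed broadcast path (CentralizedHDFSBroadcast), mirroring the initialization order Executor.init uses; the snippet assumes it is compiled inside package spark (Broadcast and the cache key spaces are package-private) and relies on the default file:/// settings:

package spark

object BroadcastSketch {
  def main(args: Array[String]) {
    Cache.initialize()          // installs spark.SoftReferenceCache by default
    Broadcast.initialize(true)  // isMaster = true; currently only sets up BroadcastCH

    val table = Map(1 -> "one", 2 -> "two")
    // local = false, so sendBroadcast immediately writes a persistent copy
    // through BroadcastCH.openFileForWriting (under spark.dfs.workdir)
    val bv = new CentralizedHDFSBroadcast(table, false)

    // After deserialization on a worker, bv.value comes from the local cache
    // or, failing that, from the file written above.
    println(bv.value(1))
  }
}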
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
new file mode 100644
index 0000000000..9887520758
--- /dev/null
+++ b/core/src/main/scala/spark/Cache.scala
@@ -0,0 +1,63 @@
+package spark
+
+import java.util.concurrent.atomic.AtomicLong
+
+
+/**
+ * An interface for caches in Spark, to allow for multiple implementations.
+ * Caches are used to store both partitions of cached RDDs and broadcast
+ * variables on Spark executors.
+ *
+ * A single Cache instance gets created on each machine and is shared by all
+ * caches (i.e. both the RDD split cache and the broadcast variable cache),
+ * to enable global replacement policies. However, because these several
+ * independent modules all perform caching, it is important to give them
+ * separate key namespaces, so that an RDD and a broadcast variable (for
+ * example) do not use the same key. For this purpose, Cache has the
+ * notion of KeySpaces. Each client module must first ask for a KeySpace,
+ * and then call get() and put() on that space using its own keys.
+ * This abstract class handles the creation of key spaces, so that subclasses
+ * need only deal with keys that are unique across modules.
+ */
+abstract class Cache {
+ private val nextKeySpaceId = new AtomicLong(0)
+ private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
+
+ def newKeySpace() = new KeySpace(this, newKeySpaceId())
+
+ def get(key: Any): Any
+ def put(key: Any, value: Any): Unit
+}
+
+
+/**
+ * A key namespace in a Cache.
+ */
+class KeySpace(cache: Cache, id: Long) {
+ def get(key: Any): Any = cache.get((id, key))
+ def put(key: Any, value: Any): Unit = cache.put((id, key), value)
+}
+
+
+/**
+ * The Cache object maintains a global Cache instance, of the type specified
+ * by the spark.cache.class property.
+ */
+object Cache {
+ private var instance: Cache = null
+
+ def initialize() {
+ val cacheClass = System.getProperty("spark.cache.class",
+ "spark.SoftReferenceCache")
+ instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+ }
+
+ def getInstance(): Cache = {
+ if (instance == null) {
+ throw new SparkException("Cache.getInstance called before initialize")
+ }
+ instance
+ }
+
+ def newKeySpace(): KeySpace = getInstance().newKeySpace()
+}
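A sketch of the key-space contract: two modules sharing one Cache instance without key collisions. MapCache is a hypothetical stand-in for the real implementations in this commit (BoundedMemoryCache, SoftReferenceCache, WeakReferenceCache):

// A trivial Cache backed by a ConcurrentHashMap, for illustration only.
class MapCache extends Cache {
  private val map = new java.util.concurrent.ConcurrentHashMap[Any, Any]
  override def get(key: Any): Any = map.get(key)
  override def put(key: Any, value: Any) { map.put(key, value) }
}

val cache: Cache = new MapCache
val rddSpace = cache.newKeySpace()       // e.g. used for cached RDD splits
val broadcastSpace = cache.newKeySpace() // e.g. used for broadcast values
rddSpace.put("split-0", Array(1, 2, 3))
broadcastSpace.put("split-0", "something else") // no collision: distinct key spaces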
diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala
new file mode 100644
index 0000000000..0e0b3954d4
--- /dev/null
+++ b/core/src/main/scala/spark/ClosureCleaner.scala
@@ -0,0 +1,159 @@
+package spark
+
+import scala.collection.mutable.Map
+import scala.collection.mutable.Set
+
+import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
+import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.Opcodes._
+
+
+object ClosureCleaner extends Logging {
+ private def getClassReader(cls: Class[_]): ClassReader = {
+ new ClassReader(cls.getResourceAsStream(
+ cls.getName.replaceFirst("^.*\\.", "") + ".class"))
+ }
+
+ private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+ for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
+ f.setAccessible(true)
+ return f.getType :: getOuterClasses(f.get(obj))
+ }
+ return Nil
+ }
+
+ private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
+ for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
+ f.setAccessible(true)
+ return f.get(obj) :: getOuterObjects(f.get(obj))
+ }
+ return Nil
+ }
+
+ private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
+ val seen = Set[Class[_]](obj.getClass)
+ var stack = List[Class[_]](obj.getClass)
+ while (!stack.isEmpty) {
+ val cr = getClassReader(stack.head)
+ stack = stack.tail
+ val set = Set[Class[_]]()
+ cr.accept(new InnerClosureFinder(set), 0)
+ for (cls <- set -- seen) {
+ seen += cls
+ stack = cls :: stack
+ }
+ }
+ return (seen - obj.getClass).toList
+ }
+
+ private def createNullValue(cls: Class[_]): AnyRef = {
+ if (cls.isPrimitive)
+ new java.lang.Byte(0: Byte) // Should be convertible to any primitive type
+ else
+ null
+ }
+
+ def clean(func: AnyRef): Unit = {
+ // TODO: cache outerClasses / innerClasses / accessedFields
+ val outerClasses = getOuterClasses(func)
+ val innerClasses = getInnerClasses(func)
+ val outerObjects = getOuterObjects(func)
+
+ val accessedFields = Map[Class[_], Set[String]]()
+ for (cls <- outerClasses)
+ accessedFields(cls) = Set[String]()
+ for (cls <- func.getClass :: innerClasses)
+ getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
+
+ var outer: AnyRef = null
+ for ((cls, obj) <- (outerClasses zip outerObjects).reverse) {
+ outer = instantiateClass(cls, outer);
+ for (fieldName <- accessedFields(cls)) {
+ val field = cls.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ val value = field.get(obj)
+ //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
+ field.set(outer, value)
+ }
+ }
+
+ if (outer != null) {
+ //logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
+ val field = func.getClass.getDeclaredField("$outer")
+ field.setAccessible(true)
+ field.set(func, outer)
+ }
+ }
+
+ private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = {
+ if (spark.repl.Main.interp == null) {
+ // This is a bona fide closure class, whose constructor has no effects
+ // other than to set its fields, so use its constructor
+ val cons = cls.getConstructors()(0)
+ val params = cons.getParameterTypes.map(createNullValue).toArray
+ if (outer != null)
+ params(0) = outer // First param is always outer object
+ return cons.newInstance(params: _*).asInstanceOf[AnyRef]
+ } else {
+ // Use reflection to instantiate object without calling constructor
+ val rf = sun.reflect.ReflectionFactory.getReflectionFactory();
+ val parentCtor = classOf[java.lang.Object].getDeclaredConstructor();
+ val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
+ val obj = newCtor.newInstance().asInstanceOf[AnyRef];
+ if (outer != null) {
+ //logInfo("3: Setting $outer on " + cls + " to " + outer);
+ val field = cls.getDeclaredField("$outer")
+ field.setAccessible(true)
+ field.set(obj, outer)
+ }
+ return obj
+ }
+ }
+}
+
+
+class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ return new EmptyVisitor {
+ override def visitFieldInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ if (op == GETFIELD)
+ for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
+ output(cl) += name
+ }
+
+ override def visitMethodInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer"))
+ for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
+ output(cl) += name
+ }
+ }
+ }
+}
+
+
+class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
+ var myName: String = null
+
+ override def visit(version: Int, access: Int, name: String, sig: String,
+ superName: String, interfaces: Array[String]) {
+ myName = name
+ }
+
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ return new EmptyVisitor {
+ override def visitMethodInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ val argTypes = Type.getArgumentTypes(desc)
+ if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
+ && argTypes(0).toString.startsWith("L") // is it an object?
+ && argTypes(0).getInternalName == myName)
+ output += Class.forName(owner.replace('/', '.'), false,
+ Thread.currentThread.getContextClassLoader)
+ }
+ }
+ }
+}
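The cleaner walks the chain of $outer references that Scala generates for closures nested inside classes and REPL line wrappers, then rebuilds that chain while copying across only the fields the closure's bytecode actually reads. A hedged sketch of the capture pattern it targets, with hypothetical names:

class Driver {
  val wanted = 42
  val unused = new Array[Byte](64 * 1024 * 1024) // never read by the closure below

  def makeClosure(): Int => Int = {
    val f = (x: Int) => x + wanted   // the function object holds an $outer reference to this Driver
    // Rebuilds f's $outer chain so that only accessed fields ("wanted") are copied
    // over; in the REPL case this keeps unused interpreter state off the wire.
    ClosureCleaner.clean(f)
    f
  }
}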
diff --git a/core/src/main/scala/spark/DfsShuffle.scala b/core/src/main/scala/spark/DfsShuffle.scala
new file mode 100644
index 0000000000..7a42bf2d06
--- /dev/null
+++ b/core/src/main/scala/spark/DfsShuffle.scala
@@ -0,0 +1,120 @@
+package spark
+
+import java.io.{EOFException, ObjectInputStream, ObjectOutputStream}
+import java.net.URI
+import java.util.UUID
+
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+
+/**
+ * A simple implementation of shuffle using a distributed file system.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val dir = DfsShuffle.newTempDirectory()
+ logInfo("Intermediate data directory: " + dir)
+
+ val numberedSplitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = numberedSplitRdd.splits.size
+
+ // Run a parallel foreach to write the intermediate data files
+ numberedSplitRdd.foreach((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+ val fs = DfsShuffle.getFileSystem()
+ for (i <- 0 until numOutputSplits) {
+ val path = new Path(dir, "%d-to-%d".format(myIndex, i))
+ val out = new ObjectOutputStream(fs.create(path, true))
+ buckets(i).foreach(pair => out.writeObject(pair))
+ out.close()
+ }
+ })
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myIndex: Int) => {
+ val combiners = new HashMap[K, C]
+ val fs = DfsShuffle.getFileSystem()
+ for (i <- Utils.shuffle(0 until numInputSplits)) {
+ val path = new Path(dir, "%d-to-%d".format(i, myIndex))
+ val inputStream = new ObjectInputStream(fs.open(path))
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => {}
+ }
+ inputStream.close()
+ }
+ combiners
+ })
+ }
+}
+
+
+/**
+ * Companion object of DfsShuffle; responsible for initializing a Hadoop
+ * FileSystem object based on the spark.dfs property and generating names
+ * for temporary directories.
+ */
+object DfsShuffle {
+ private var initialized = false
+ private var fileSystem: FileSystem = null
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val dfs = System.getProperty("spark.dfs", "file:///")
+ val conf = new Configuration()
+ conf.setInt("io.file.buffer.size", bufferSize)
+ conf.setInt("dfs.replication", 1)
+ fileSystem = FileSystem.get(new URI(dfs), conf)
+ initialized = true
+ }
+ }
+
+ def getFileSystem(): FileSystem = {
+ initializeIfNeeded()
+ return fileSystem
+ }
+
+ def newTempDirectory(): String = {
+ val fs = getFileSystem()
+ val workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+ val uuid = UUID.randomUUID()
+ val path = workDir + "/shuffle-" + uuid
+ fs.mkdirs(new Path(path))
+ return path
+ }
+}
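The three function arguments to compute follow the usual combine-by-key pattern. A word-count-style sketch of calling it (pairs is assumed to be an existing RDD of (word, 1) tuples):

def wordCounts(pairs: RDD[(String, Int)], numOutputSplits: Int): RDD[(String, Int)] =
  new DfsShuffle[String, Int, Int]().compute(
    pairs,
    numOutputSplits,
    (v: Int) => v,                  // createCombiner: first value seen for a key
    (c: Int, v: Int) => c + v,      // mergeValue: fold another value into the combiner
    (c1: Int, c2: Int) => c1 + c2)  // mergeCombiners: merge combiners from different splits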
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
new file mode 100644
index 0000000000..b4d023b428
--- /dev/null
+++ b/core/src/main/scala/spark/Executor.scala
@@ -0,0 +1,116 @@
+package spark
+
+import java.io.{File, FileOutputStream}
+import java.net.{URI, URL, URLClassLoader}
+import java.util.concurrent.{Executors, ExecutorService}
+
+import scala.collection.mutable.ArrayBuffer
+
+import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
+import mesos.{TaskDescription, TaskState, TaskStatus}
+
+/**
+ * The Mesos executor for Spark.
+ */
+class Executor extends mesos.Executor with Logging {
+ var classLoader: ClassLoader = null
+ var threadPool: ExecutorService = null
+
+ override def init(d: ExecutorDriver, args: ExecutorArgs) {
+ // Read spark.* system properties from executor arg
+ val props = Utils.deserialize[Array[(String, String)]](args.getData)
+ for ((key, value) <- props)
+ System.setProperty(key, value)
+
+ // Initialize cache and broadcast system (uses some properties read above)
+ Cache.initialize()
+ Broadcast.initialize(false)
+
+ // Create our ClassLoader (using spark properties) and set it on this thread
+ classLoader = createClassLoader()
+ Thread.currentThread.setContextClassLoader(classLoader)
+
+ // Start worker thread pool (they will inherit our context ClassLoader)
+ threadPool = Executors.newCachedThreadPool()
+ }
+
+ override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
+ // Pull taskId and arg out of TaskDescription because it won't be a
+ // valid pointer after this method call (TODO: fix this in C++/SWIG)
+ val taskId = desc.getTaskId
+ val arg = desc.getArg
+ threadPool.execute(new Runnable() {
+ def run() = {
+ logInfo("Running task ID " + taskId)
+ try {
+ Accumulators.clear
+ val task = Utils.deserialize[Task[Any]](arg, classLoader)
+ val value = task.run
+ val accumUpdates = Accumulators.values
+ val result = new TaskResult(value, accumUpdates)
+ d.sendStatusUpdate(new TaskStatus(
+ taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
+ logInfo("Finished task ID " + taskId)
+ } catch {
+ case e: Exception => {
+ // TODO: Handle errors in tasks less dramatically
+ logError("Exception in task ID " + taskId, e)
+ System.exit(1)
+ }
+ }
+ }
+ })
+ }
+
+ // Create a ClassLoader for use in tasks, adding any JARs specified by the
+ // user or any classes created by the interpreter to the search path
+ private def createClassLoader(): ClassLoader = {
+ var loader = this.getClass.getClassLoader
+
+ // If any JAR URIs are given through spark.jar.uris, fetch them to the
+ // current directory and put them all on the classpath. We assume that
+ // each URL has a unique file name so that no local filenames will clash
+ // in this process. This is guaranteed by MesosScheduler.
+ val uris = System.getProperty("spark.jar.uris", "")
+ val localFiles = ArrayBuffer[String]()
+ for (uri <- uris.split(",").filter(_.size > 0)) {
+ val url = new URL(uri)
+ val filename = url.getPath.split("/").last
+ downloadFile(url, filename)
+ localFiles += filename
+ }
+ if (localFiles.size > 0) {
+ val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
+ loader = new URLClassLoader(urls, loader)
+ }
+
+ // If the REPL is in use, add another ClassLoader that will read
+ // new classes defined by the REPL as the user types code
+ val classUri = System.getProperty("spark.repl.class.uri")
+ if (classUri != null) {
+ logInfo("Using REPL class URI: " + classUri)
+ loader = new repl.ExecutorClassLoader(classUri, loader)
+ }
+
+ return loader
+ }
+
+ // Download a file from a given URL to the local filesystem
+ private def downloadFile(url: URL, localPath: String) {
+ val in = url.openStream()
+ val out = new FileOutputStream(localPath)
+ Utils.copyStream(in, out, true)
+ }
+}
+
+/**
+ * Executor entry point.
+ */
+object Executor extends Logging {
+ def main(args: Array[String]) {
+ System.loadLibrary("mesos")
+ // Create a new Executor and start it running
+ val exec = new Executor
+ new MesosExecutorDriver(exec).run()
+ }
+}
diff --git a/core/src/main/scala/spark/HadoopFile.scala b/core/src/main/scala/spark/HadoopFile.scala
new file mode 100644
index 0000000000..a63c9d8a94
--- /dev/null
+++ b/core/src/main/scala/spark/HadoopFile.scala
@@ -0,0 +1,118 @@
+package spark
+
+import mesos.SlaveOffer
+
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.InputSplit
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapred.RecordReader
+import org.apache.hadoop.mapred.Reporter
+import org.apache.hadoop.util.ReflectionUtils
+
+/** A Spark split class that wraps around a Hadoop InputSplit */
+@serializable class HadoopSplit(@transient s: InputSplit)
+extends Split {
+ val inputSplit = new SerializableWritable[InputSplit](s)
+
+ // Hadoop gives each split a unique toString value, so use this as our ID
+ override def getId() = "HadoopSplit(" + inputSplit.toString + ")"
+}
+
+
+/**
+ * An RDD that reads a Hadoop file (from HDFS, S3, the local filesystem, etc)
+ * and represents it as a set of key-value pairs using a given InputFormat.
+ */
+class HadoopFile[K, V](
+ sc: SparkContext,
+ path: String,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V])
+extends RDD[(K, V)](sc) {
+ @transient val splits_ : Array[Split] = ConfigureLock.synchronized {
+ val conf = new JobConf()
+ FileInputFormat.setInputPaths(conf, path)
+ val inputFormat = createInputFormat(conf)
+ val inputSplits = inputFormat.getSplits(conf, sc.numCores)
+ inputSplits.map(x => new HadoopSplit(x): Split).toArray
+ }
+
+ def createInputFormat(conf: JobConf): InputFormat[K, V] = {
+ ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
+ .asInstanceOf[InputFormat[K, V]]
+ }
+
+ override def splits = splits_
+
+ override def iterator(theSplit: Split) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopSplit]
+ var reader: RecordReader[K, V] = null
+
+ ConfigureLock.synchronized {
+ val conf = new JobConf()
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ val fmt = createInputFormat(conf)
+ reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+ }
+
+ val key: K = keyClass.newInstance()
+ val value: V = valueClass.newInstance()
+ var gotNext = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!gotNext) {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eofe: java.io.EOFException =>
+ finished = true
+ }
+ gotNext = true
+ }
+ !finished
+ }
+
+ override def next: (K, V) = {
+ if (!gotNext) {
+ finished = !reader.next(key, value)
+ }
+ if (finished) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ (key, value)
+ }
+ }
+
+ override def preferredLocations(split: Split) = {
+ // TODO: Filtering out "localhost" in case of file:// URLs
+ val hadoopSplit = split.asInstanceOf[HadoopSplit]
+ hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ }
+}
+
+
+/**
+ * Convenience class for Hadoop files read using TextInputFormat that
+ * represents the file as an RDD of Strings.
+ */
+class HadoopTextFile(sc: SparkContext, path: String)
+extends MappedRDD[String, (LongWritable, Text)](
+ new HadoopFile(sc, path, classOf[TextInputFormat],
+ classOf[LongWritable], classOf[Text]),
+ { pair: (LongWritable, Text) => pair._2.toString }
+)
+
+
+/**
+ * Object used to ensure that only one thread at a time is configuring Hadoop
+ * InputFormat classes. Apparently configuring them is not thread safe!
+ */
+object ConfigureLock {}
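A one-line sketch of how the convenience class above is used; constructing the SparkContext itself is outside this file:

// Given an existing SparkContext 'sc', read a file as an RDD of lines.
def textLines(sc: SparkContext, path: String): RDD[String] =
  new HadoopTextFile(sc, path)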
diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala
new file mode 100644
index 0000000000..d2a663ac1f
--- /dev/null
+++ b/core/src/main/scala/spark/HttpServer.scala
@@ -0,0 +1,67 @@
+package spark
+
+import java.io.File
+import java.net.InetAddress
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.handler.DefaultHandler
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+
+
+/**
+ * Exception type thrown by HttpServer when it is in the wrong state
+ * for an operation.
+ */
+class ServerStateException(message: String) extends Exception(message)
+
+
+/**
+ * An HTTP server for static content used to allow worker nodes to access JARs
+ * added to SparkContext as well as classes created by the interpreter when
+ * the user types in code. This is just a wrapper around a Jetty server.
+ */
+class HttpServer(resourceBase: File) extends Logging {
+ private var server: Server = null
+ private var port: Int = -1
+
+ def start() {
+ if (server != null) {
+ throw new ServerStateException("Server is already started")
+ } else {
+ server = new Server(0)
+ val threadPool = new QueuedThreadPool
+ threadPool.setDaemon(true)
+ server.setThreadPool(threadPool)
+ val resHandler = new ResourceHandler
+ resHandler.setResourceBase(resourceBase.getAbsolutePath)
+ val handlerList = new HandlerList
+ handlerList.setHandlers(Array(resHandler, new DefaultHandler))
+ server.setHandler(handlerList)
+ server.start()
+ port = server.getConnectors()(0).getLocalPort()
+ }
+ }
+
+ def stop() {
+ if (server == null) {
+ throw new ServerStateException("Server is already stopped")
+ } else {
+ server.stop()
+ port = -1
+ server = null
+ }
+ }
+
+ /**
+ * Get the URI of this HTTP server (http://host:port)
+ */
+ def uri: String = {
+ if (server == null) {
+ throw new ServerStateException("Server is not started")
+ } else {
+ return "http://" + Utils.localIpAddress + ":" + port
+ }
+ }
+}
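A sketch of the server's lifecycle; the port is picked automatically because the Jetty Server is created on port 0 (the directory path here is hypothetical):

import java.io.File

val jarDir = new File("/tmp/spark-jars")      // directory whose contents will be served
val httpServer = new HttpServer(jarDir)
httpServer.start()
println("Serving files at " + httpServer.uri) // e.g. http://<local-ip>:<ephemeral-port>
// ... workers fetch <uri>/<filename> ...
httpServer.stop()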
diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala
new file mode 100644
index 0000000000..6abbcbce51
--- /dev/null
+++ b/core/src/main/scala/spark/Job.scala
@@ -0,0 +1,18 @@
+package spark
+
+import mesos._
+
+/**
+ * Class representing a parallel job in MesosScheduler. Schedules the
+ * job by implementing various callbacks.
+ */
+abstract class Job(jobId: Int) {
+ def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
+ : Option[TaskDescription]
+
+ def statusUpdate(t: TaskStatus): Unit
+
+ def error(code: Int, message: String): Unit
+
+ def getId(): Int = jobId
+}
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala
new file mode 100644
index 0000000000..367599cfb4
--- /dev/null
+++ b/core/src/main/scala/spark/LocalFileShuffle.scala
@@ -0,0 +1,171 @@
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+
+/**
+ * A simple implementation of shuffle using local files served through HTTP.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = LocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+ for (i <- 0 until numOutputSplits) {
+ val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+ val out = new ObjectOutputStream(new FileOutputStream(file))
+ buckets(i).foreach(pair => out.writeObject(pair))
+ out.close()
+ }
+ (myIndex, LocalFileShuffle.serverUri)
+ }).collect()
+
+ // Build a hashmap from server URI to list of splits (to facilitate
+ // fetching all the URIs on a server within a single connection)
+ val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
+ for ((inputId, serverUri) <- outputLocs) {
+ splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
+ }
+
+ // TODO: Could broadcast splitsByUri
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ val combiners = new HashMap[K, C]
+ for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
+ for (i <- inputIds) {
+ val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
+ val inputStream = new ObjectInputStream(new URL(url).openStream())
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => {}
+ }
+ inputStream.close()
+ }
+ }
+ combiners
+ })
+ }
+}
+
+
+object LocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+ private var server: HttpServer = null
+ private var serverUri: String = null
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID()
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists()) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+ val extServerPort = System.getProperty(
+ "spark.localFileShuffle.external.server.port", "-1").toInt
+ if (extServerPort != -1) {
+ // We're using an external HTTP server; set URI relative to its root
+ var extServerPath = System.getProperty(
+ "spark.localFileShuffle.external.server.path", "")
+ if (extServerPath != "" && !extServerPath.endsWith("/")) {
+ extServerPath += "/"
+ }
+ serverUri = "http://%s:%d/%s/spark-local-%s".format(
+ Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+ } else {
+ // Create our own server
+ server = new HttpServer(localDir)
+ server.start()
+ serverUri = server.uri
+ }
+ initialized = true
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "" + outputId)
+ return file
+ }
+
+ def getServerUri(): String = {
+ initializeIfNeeded()
+ serverUri
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+}
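
The map side of compute() above is essentially hash partitioning plus per-bucket combining. A standalone sketch of that step using plain Scala collections; BucketingSketch and bucketize are hypothetical names, not part of this commit. The real code then writes each bucket to a file under the shuffle directory and serves it over HTTP.

import scala.collection.mutable.HashMap

object BucketingSketch {
  // Assign each key to one of numOutputSplits buckets (fixing up negative
  // hash codes) and fold values into per-key combiners, as compute() does.
  def bucketize[K, V, C](records: Iterator[(K, V)],
                         numOutputSplits: Int,
                         createCombiner: V => C,
                         mergeValue: (C, V) => C): Array[HashMap[K, C]] = {
    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
    for ((k, v) <- records) {
      var bucketId = k.hashCode % numOutputSplits
      if (bucketId < 0) {
        bucketId += numOutputSplits
      }
      val bucket = buckets(bucketId)
      bucket(k) = bucket.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None => createCombiner(v)
      }
    }
    buckets
  }

  def main(args: Array[String]) {
    // Word-count style combiners spread over two output splits
    val records = Iterator(("a", 1), ("b", 1), ("a", 1))
    val buckets = bucketize[String, Int, Int](records, 2, v => v, _ + _)
    buckets.zipWithIndex.foreach { case (b, i) => println("bucket " + i + ": " + b) }
  }
}

To route PairRDDExtras.combineByKey() through this implementation, the spark.shuffle.class system property can be set to spark.LocalFileShuffle (the default, as read in combineByKey, is spark.DfsShuffle).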
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
new file mode 100644
index 0000000000..20954a1224
--- /dev/null
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -0,0 +1,72 @@
+package spark
+
+import java.util.concurrent._
+
+import scala.collection.mutable.Map
+
+/**
+ * A simple Scheduler implementation that runs tasks locally in a thread pool.
+ */
+private class LocalScheduler(threads: Int) extends Scheduler with Logging {
+ var threadPool: ExecutorService =
+ Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+
+ override def start() {}
+
+ override def waitForRegister() {}
+
+ override def runTasks[T](tasks: Array[Task[T]])(implicit m: ClassManifest[T])
+ : Array[T] = {
+ val futures = new Array[Future[TaskResult[T]]](tasks.length)
+
+ for (i <- 0 until tasks.length) {
+ futures(i) = threadPool.submit(new Callable[TaskResult[T]]() {
+ def call(): TaskResult[T] = {
+ logInfo("Running task " + i)
+ try {
+ // Serialize and deserialize the task so that accumulators are
+ // changed to thread-local ones; this adds a bit of unnecessary
+            // overhead but matches how the Mesos Executor works
+ Accumulators.clear
+ val bytes = Utils.serialize(tasks(i))
+ logInfo("Size of task " + i + " is " + bytes.size + " bytes")
+ val task = Utils.deserialize[Task[T]](
+ bytes, currentThread.getContextClassLoader)
+ val value = task.run
+ val accumUpdates = Accumulators.values
+ logInfo("Finished task " + i)
+ new TaskResult[T](value, accumUpdates)
+ } catch {
+ case e: Exception => {
+ // TODO: Do something nicer here
+ logError("Exception in task " + i, e)
+ System.exit(1)
+ null
+ }
+ }
+ }
+ })
+ }
+
+ val taskResults = futures.map(_.get)
+ for (result <- taskResults)
+ Accumulators.add(currentThread, result.accumUpdates)
+ return taskResults.map(_.value).toArray(m)
+ }
+
+ override def stop() {}
+
+ override def numCores() = threads
+}
+
+
+/**
+ * A ThreadFactory that creates daemon threads
+ */
+private object DaemonThreadFactory extends ThreadFactory {
+ override def newThread(r: Runnable): Thread = {
+ val t = new Thread(r);
+ t.setDaemon(true)
+ return t
+ }
+}
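
The serialize-then-deserialize step above is what gives each locally run task its own accumulator state. A minimal sketch of that round trip on its own, using FunctionTask from Task.scala later in this commit; the object name is hypothetical.

import spark.{FunctionTask, Task, Utils}

object TaskRoundTripSketch {
  def main(args: Array[String]) {
    // Serialize a task and run the deserialized copy, as LocalScheduler
    // does, so that anything the task captures is private to the copy.
    val task: Task[Int] = new FunctionTask(() => 21 * 2)
    val bytes = Utils.serialize(task)
    val copy = Utils.deserialize[Task[Int]](
      bytes, Thread.currentThread.getContextClassLoader)
    println("Result: " + copy.run + ", serialized size: " + bytes.size + " bytes")
  }
}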
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
new file mode 100644
index 0000000000..2d1feebbb1
--- /dev/null
+++ b/core/src/main/scala/spark/Logging.scala
@@ -0,0 +1,49 @@
+package spark
+
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * Utility trait for classes that want to log data. Creates an SLF4J logger
+ * for the class and allows logging messages at different levels using
+ * methods that only evaluate parameters lazily if the log level is enabled.
+ */
+trait Logging {
+ // Make the log field transient so that objects with Logging can
+ // be serialized and used on another machine
+ @transient private var log_ : Logger = null
+
+ // Method to get or create the logger for this object
+ def log: Logger = {
+ if (log_ == null) {
+ var className = this.getClass().getName()
+ // Ignore trailing $'s in the class names for Scala objects
+ if (className.endsWith("$"))
+ className = className.substring(0, className.length - 1)
+ log_ = LoggerFactory.getLogger(className)
+ }
+ return log_
+ }
+
+ // Log methods that take only a String
+ def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg)
+
+ def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg)
+
+ def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg)
+
+ def logError(msg: => String) = if (log.isErrorEnabled) log.error(msg)
+
+ // Log methods that take Throwables (Exceptions/Errors) too
+  def logInfo(msg: => String, throwable: Throwable) =
+    if (log.isInfoEnabled) log.info(msg, throwable)
+
+  def logDebug(msg: => String, throwable: Throwable) =
+    if (log.isDebugEnabled) log.debug(msg, throwable)
+
+ def logWarning(msg: => String, throwable: Throwable) =
+ if (log.isWarnEnabled) log.warn(msg, throwable)
+
+ def logError(msg: => String, throwable: Throwable) =
+ if (log.isErrorEnabled) log.error(msg, throwable)
+}
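
Because every msg parameter above is by-name (=> String), callers can log freely without paying for string construction when a level is disabled. A small usage sketch with a hypothetical class:

import spark.Logging

// Hypothetical example: mixing in Logging gives lazily evaluated log calls.
class WordCounter extends Logging {
  def countWords(text: String): Int = {
    val words = text.split("\\s+")
    // This message is only built if DEBUG is enabled for this class's logger
    logDebug("Input was: " + text)
    logInfo("Counted " + words.length + " words")
    words.length
  }
}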
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala
new file mode 100644
index 0000000000..c45eff64d4
--- /dev/null
+++ b/core/src/main/scala/spark/MesosScheduler.scala
@@ -0,0 +1,294 @@
+package spark
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Map
+import scala.collection.mutable.Queue
+import scala.collection.JavaConversions._
+
+import mesos.{Scheduler => MScheduler}
+import mesos._
+
+/**
+ * The main Scheduler implementation, which runs jobs on Mesos. Clients should
+ * first call start(), then submit tasks through the runTasks method.
+ */
+private class MesosScheduler(
+ sc: SparkContext, master: String, frameworkName: String)
+extends MScheduler with spark.Scheduler with Logging
+{
+ // Environment variables to pass to our executors
+ val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
+ "SPARK_MEM",
+ "SPARK_CLASSPATH",
+ "SPARK_LIBRARY_PATH"
+ )
+
+ // Lock used to wait for scheduler to be registered
+ private var isRegistered = false
+ private val registeredLock = new Object()
+
+ private var activeJobs = new HashMap[Int, Job]
+ private var activeJobsQueue = new Queue[Job]
+
+ private var taskIdToJobId = new HashMap[Int, Int]
+ private var jobTasks = new HashMap[Int, HashSet[Int]]
+
+ // Incrementing job and task IDs
+ private var nextJobId = 0
+ private var nextTaskId = 0
+
+ // Driver for talking to Mesos
+ var driver: SchedulerDriver = null
+
+ // JAR server, if any JARs were added by the user to the SparkContext
+ var jarServer: HttpServer = null
+
+ // URIs of JARs to pass to executor
+ var jarUris: String = ""
+
+ def newJobId(): Int = this.synchronized {
+ val id = nextJobId
+ nextJobId += 1
+ return id
+ }
+
+ def newTaskId(): Int = {
+    val id = nextTaskId
+    nextTaskId += 1
+ return id
+ }
+
+ override def start() {
+ if (sc.jars.size > 0) {
+ // If the user added any JARS to the SparkContext, create an HTTP server
+ // to serve them to our executors
+ createJarServer()
+ }
+ new Thread("Spark scheduler") {
+ setDaemon(true)
+ override def run {
+ val sched = MesosScheduler.this
+ sched.driver = new MesosSchedulerDriver(sched, master)
+ sched.driver.run()
+ }
+ }.start
+ }
+
+ override def getFrameworkName(d: SchedulerDriver): String = frameworkName
+
+ override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = {
+ val sparkHome = sc.getSparkHome match {
+ case Some(path) => path
+ case None =>
+ throw new SparkException("Spark home is not set; set it through the " +
+ "spark.home system property, the SPARK_HOME environment variable " +
+ "or the SparkContext constructor")
+ }
+ val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
+ val params = new JHashMap[String, String]
+ for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
+ if (System.getenv(key) != null) {
+ params("env." + key) = System.getenv(key)
+ }
+ }
+ new ExecutorInfo(execScript, createExecArg())
+ }
+
+ /**
+ * The primary means to submit a job to the scheduler. Given a list of tasks,
+ * runs them and returns an array of the results.
+ */
+ override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
+ waitForRegister()
+ val jobId = newJobId()
+ val myJob = new SimpleJob(this, tasks, jobId)
+ try {
+ this.synchronized {
+ activeJobs(jobId) = myJob
+ activeJobsQueue += myJob
+ jobTasks(jobId) = new HashSet()
+ }
+ driver.reviveOffers();
+ return myJob.join();
+ } finally {
+ this.synchronized {
+ activeJobs -= jobId
+ activeJobsQueue.dequeueAll(x => (x == myJob))
+ taskIdToJobId --= jobTasks(jobId)
+ jobTasks.remove(jobId)
+ }
+ }
+ }
+
+ override def registered(d: SchedulerDriver, frameworkId: String) {
+ logInfo("Registered as framework ID " + frameworkId)
+ registeredLock.synchronized {
+ isRegistered = true
+ registeredLock.notifyAll()
+ }
+ }
+
+ override def waitForRegister() {
+ registeredLock.synchronized {
+ while (!isRegistered)
+ registeredLock.wait()
+ }
+ }
+
+ /**
+   * Method called by Mesos to offer resources on slaves. We respond by asking
+ * our active jobs for tasks in FIFO order. We fill each node with tasks in
+ * a round-robin manner so that tasks are balanced across the cluster.
+ */
+ override def resourceOffer(
+ d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) {
+ synchronized {
+ val tasks = new JArrayList[TaskDescription]
+ val availableCpus = offers.map(_.getParams.get("cpus").toInt)
+ val availableMem = offers.map(_.getParams.get("mem").toInt)
+ var launchedTask = false
+ for (job <- activeJobsQueue) {
+ do {
+ launchedTask = false
+ for (i <- 0 until offers.size.toInt) {
+ try {
+ job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match {
+ case Some(task) =>
+ tasks.add(task)
+ taskIdToJobId(task.getTaskId) = job.getId
+ jobTasks(job.getId) += task.getTaskId
+ availableCpus(i) -= task.getParams.get("cpus").toInt
+ availableMem(i) -= task.getParams.get("mem").toInt
+ launchedTask = true
+ case None => {}
+ }
+ } catch {
+ case e: Exception => logError("Exception in resourceOffer", e)
+ }
+ }
+ } while (launchedTask)
+ }
+ val params = new JHashMap[String, String]
+ params.put("timeout", "1")
+ d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout?
+ }
+ }
+
+ // Check whether a Mesos task state represents a finished task
+ def isFinished(state: TaskState) = {
+ state == TaskState.TASK_FINISHED ||
+ state == TaskState.TASK_FAILED ||
+ state == TaskState.TASK_KILLED ||
+ state == TaskState.TASK_LOST
+ }
+
+ override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+ synchronized {
+ try {
+ taskIdToJobId.get(status.getTaskId) match {
+ case Some(jobId) =>
+ if (activeJobs.contains(jobId)) {
+ activeJobs(jobId).statusUpdate(status)
+ }
+ if (isFinished(status.getState)) {
+ taskIdToJobId.remove(status.getTaskId)
+ jobTasks(jobId) -= status.getTaskId
+ }
+ case None =>
+ logInfo("TID " + status.getTaskId + " already finished")
+ }
+ } catch {
+ case e: Exception => logError("Exception in statusUpdate", e)
+ }
+ }
+ }
+
+ override def error(d: SchedulerDriver, code: Int, message: String) {
+ logError("Mesos error: %s (error code: %d)".format(message, code))
+ synchronized {
+ if (activeJobs.size > 0) {
+ // Have each job throw a SparkException with the error
+ for ((jobId, activeJob) <- activeJobs) {
+ try {
+ activeJob.error(code, message)
+ } catch {
+ case e: Exception => logError("Exception in error callback", e)
+ }
+ }
+ } else {
+ // No jobs are active but we still got an error. Just exit since this
+ // must mean the error is during registration.
+ // It might be good to do something smarter here in the future.
+ System.exit(1)
+ }
+ }
+ }
+
+ override def stop() {
+ if (driver != null) {
+ driver.stop()
+ }
+ if (jarServer != null) {
+ jarServer.stop()
+ }
+ }
+
+ // TODO: query Mesos for number of cores
+ override def numCores() =
+ System.getProperty("spark.default.parallelism", "2").toInt
+
+ // Create a server for all the JARs added by the user to SparkContext.
+ // We first copy the JARs to a temp directory for easier server setup.
+ private def createJarServer() {
+ val jarDir = Utils.createTempDir()
+ logInfo("Temp directory for JARs: " + jarDir)
+ val filenames = ArrayBuffer[String]()
+ // Copy each JAR to a unique filename in the jarDir
+ for ((path, index) <- sc.jars.zipWithIndex) {
+ val file = new File(path)
+ val filename = index + "_" + file.getName
+ copyFile(file, new File(jarDir, filename))
+ filenames += filename
+ }
+ // Create the server
+ jarServer = new HttpServer(jarDir)
+ jarServer.start()
+ // Build up the jar URI list
+ val serverUri = jarServer.uri
+ jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+ logInfo("JAR server started at " + serverUri)
+ }
+
+ // Copy a file on the local file system
+ private def copyFile(source: File, dest: File) {
+ val in = new FileInputStream(source)
+ val out = new FileOutputStream(dest)
+ Utils.copyStream(in, out, true)
+ }
+
+ // Create and serialize the executor argument to pass to Mesos.
+ // Our executor arg is an array containing all the spark.* system properties
+ // in the form of (String, String) pairs.
+ private def createExecArg(): Array[Byte] = {
+ val props = new HashMap[String, String]
+ val iter = System.getProperties.entrySet.iterator
+ while (iter.hasNext) {
+ val entry = iter.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark.")) {
+ props(key) = value
+ }
+ }
+ // Set spark.jar.uris to our JAR URIs, regardless of system property
+ props("spark.jar.uris") = jarUris
+ // Serialize the map as an array of (String, String) pairs
+ return Utils.serialize(props.toArray)
+ }
+}
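
createExecArg() above ships every spark.* system property (plus spark.jar.uris) to executors as a serialized Array[(String, String)]. A sketch of how the receiving side could unpack it; ExecArgSketch is a hypothetical helper, and the actual Executor in this commit has its own setup code.

import spark.Utils

object ExecArgSketch {
  // Deserialize the (String, String) pairs produced by createExecArg()
  // and install them as system properties before any task runs.
  def applyExecArg(arg: Array[Byte]) {
    val props = Utils.deserialize[Array[(String, String)]](arg)
    for ((key, value) <- props) {
      System.setProperty(key, value)
    }
  }
}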
diff --git a/core/src/main/scala/spark/NumberedSplitRDD.scala b/core/src/main/scala/spark/NumberedSplitRDD.scala
new file mode 100644
index 0000000000..7b12210d84
--- /dev/null
+++ b/core/src/main/scala/spark/NumberedSplitRDD.scala
@@ -0,0 +1,42 @@
+package spark
+
+import mesos.SlaveOffer
+
+
+/**
+ * An RDD that takes the splits of a parent RDD and gives them unique indexes.
+ * This is useful for a variety of shuffle implementations.
+ */
+class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
+extends RDD[(Int, Iterator[T])](prev.sparkContext) {
+ @transient val splits_ = {
+ prev.splits.zipWithIndex.map {
+ case (s, i) => new NumberedSplitRDDSplit(s, i): Split
+ }.toArray
+ }
+
+ override def splits = splits_
+
+ override def preferredLocations(split: Split) = {
+ val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+ prev.preferredLocations(nsplit.prev)
+ }
+
+ override def iterator(split: Split) = {
+ val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+ Iterator((nsplit.index, prev.iterator(nsplit.prev)))
+ }
+
+ override def taskStarted(split: Split, slot: SlaveOffer) = {
+ val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+ prev.taskStarted(nsplit.prev, slot)
+ }
+}
+
+
+/**
+ * A split in a NumberedSplitRDD.
+ */
+class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
+ override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
+}
diff --git a/core/src/main/scala/spark/ParallelArray.scala b/core/src/main/scala/spark/ParallelArray.scala
new file mode 100644
index 0000000000..a01904d61c
--- /dev/null
+++ b/core/src/main/scala/spark/ParallelArray.scala
@@ -0,0 +1,76 @@
+package spark
+
+import mesos.SlaveOffer
+
+import java.util.concurrent.atomic.AtomicLong
+
+@serializable class ParallelArraySplit[T: ClassManifest](
+ val arrayId: Long, val slice: Int, values: Seq[T])
+extends Split {
+ def iterator(): Iterator[T] = values.iterator
+
+ override def hashCode(): Int = (41 * (41 + arrayId) + slice).toInt
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ParallelArraySplit[_] =>
+ (this.arrayId == that.arrayId && this.slice == that.slice)
+ case _ => false
+ }
+
+ override def getId() =
+ "ParallelArraySplit(arrayId %d, slice %d)".format(arrayId, slice)
+}
+
+class ParallelArray[T: ClassManifest](
+ sc: SparkContext, @transient data: Seq[T], numSlices: Int)
+extends RDD[T](sc) {
+ // TODO: Right now, each split sends along its full data, even if later down
+ // the RDD chain it gets cached. It might be worthwhile to write the data to
+ // a file in the DFS and read it in the split instead.
+
+ val id = ParallelArray.newId()
+
+ @transient val splits_ = {
+ val slices = ParallelArray.slice(data, numSlices).toArray
+ slices.indices.map(i => new ParallelArraySplit(id, i, slices(i))).toArray
+ }
+
+ override def splits = splits_.asInstanceOf[Array[Split]]
+
+ override def iterator(s: Split) = s.asInstanceOf[ParallelArraySplit[T]].iterator
+
+ override def preferredLocations(s: Split): Seq[String] = Nil
+}
+
+private object ParallelArray {
+ val nextId = new AtomicLong(0) // Creates IDs for ParallelArrays (on master)
+ def newId() = nextId.getAndIncrement()
+
+ def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
+ if (numSlices < 1)
+ throw new IllegalArgumentException("Positive number of slices required")
+ seq match {
+ case r: Range.Inclusive => {
+ val sign = if (r.step < 0) -1 else 1
+ slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]],
+ numSlices)
+ }
+ case r: Range => {
+ (0 until numSlices).map(i => {
+ val start = ((i * r.length.toLong) / numSlices).toInt
+ val end = (((i+1) * r.length.toLong) / numSlices).toInt
+ new Range(
+ r.start + start * r.step, r.start + end * r.step, r.step)
+ }).asInstanceOf[Seq[Seq[T]]]
+ }
+ case _ => {
+ val array = seq.toArray // To prevent O(n^2) operations for List etc
+ (0 until numSlices).map(i => {
+ val start = ((i * array.length.toLong) / numSlices).toInt
+ val end = (((i+1) * array.length.toLong) / numSlices).toInt
+ array.slice(start, end).toSeq
+ })
+ }
+ }
+ }
+}
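
ParallelArray.slice keeps slice sizes balanced by computing boundaries with Long arithmetic, and keeps Range inputs as Ranges so no data is materialized up front. A quick sketch of what it produces; SliceSketch is hypothetical and sits in package spark because slice is package-private.

package spark

object SliceSketch {
  def main(args: Array[String]) {
    // A Range is split into sub-Ranges: here 0..9 becomes
    // three slices of sizes 3, 3 and 4.
    ParallelArray.slice(0 until 10, 3).foreach(println)

    // Any other Seq falls through to the generic array-slicing case;
    // here four elements become two slices of two elements each.
    ParallelArray.slice(List("a", "b", "c", "d"), 2).foreach(println)
  }
}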
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
new file mode 100644
index 0000000000..bac59319a0
--- /dev/null
+++ b/core/src/main/scala/spark/RDD.scala
@@ -0,0 +1,418 @@
+package spark
+
+import java.util.concurrent.atomic.AtomicLong
+import java.util.HashSet
+import java.util.Random
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
+
+import SparkContext._
+
+import mesos._
+
+
+@serializable
+abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
+ def splits: Array[Split]
+ def iterator(split: Split): Iterator[T]
+ def preferredLocations(split: Split): Seq[String]
+
+ def taskStarted(split: Split, slot: SlaveOffer) {}
+
+ def sparkContext = sc
+
+ def map[U: ClassManifest](f: T => U) = new MappedRDD(this, sc.clean(f))
+ def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
+ def cache() = new CachedRDD(this)
+
+ def sample(withReplacement: Boolean, frac: Double, seed: Int) =
+ new SampledRDD(this, withReplacement, frac, seed)
+
+ def flatMap[U: ClassManifest](f: T => Traversable[U]) =
+ new FlatMappedRDD(this, sc.clean(f))
+
+ def foreach(f: T => Unit) {
+ val cleanF = sc.clean(f)
+ val tasks = splits.map(s => new ForeachTask(this, s, cleanF)).toArray
+ sc.runTaskObjects(tasks)
+ }
+
+ def collect(): Array[T] = {
+ val tasks = splits.map(s => new CollectTask(this, s))
+ val results = sc.runTaskObjects(tasks)
+ Array.concat(results: _*)
+ }
+
+ def toArray(): Array[T] = collect()
+
+ def reduce(f: (T, T) => T): T = {
+ val cleanF = sc.clean(f)
+    val tasks = splits.map(s => new ReduceTask(this, s, cleanF))
+ val results = new ArrayBuffer[T]
+ for (option <- sc.runTaskObjects(tasks); elem <- option)
+ results += elem
+ if (results.size == 0)
+ throw new UnsupportedOperationException("empty collection")
+ else
+ return results.reduceLeft(f)
+ }
+
+ def take(num: Int): Array[T] = {
+ if (num == 0)
+ return new Array[T](0)
+ val buf = new ArrayBuffer[T]
+ for (split <- splits; elem <- iterator(split)) {
+ buf += elem
+ if (buf.length == num)
+ return buf.toArray
+ }
+ return buf.toArray
+ }
+
+ def first: T = take(1) match {
+ case Array(t) => t
+ case _ => throw new UnsupportedOperationException("empty collection")
+ }
+
+ def count(): Long = {
+ try {
+ map(x => 1L).reduce(_+_)
+ } catch {
+ case e: UnsupportedOperationException => 0L // No elements in RDD
+ }
+ }
+
+ def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other))
+
+ def ++(other: RDD[T]) = this.union(other)
+
+ def splitRdd() = new SplitRDD(this)
+
+ def cartesian[U: ClassManifest](other: RDD[U]) =
+ new CartesianRDD(sc, this, other)
+
+ def groupBy[K](func: T => K, numSplits: Int): RDD[(K, Seq[T])] =
+ this.map(t => (func(t), t)).groupByKey(numSplits)
+
+ def groupBy[K](func: T => K): RDD[(K, Seq[T])] =
+ groupBy[K](func, sc.numCores)
+}
+
+@serializable
+abstract class RDDTask[U: ClassManifest, T: ClassManifest](
+ val rdd: RDD[T], val split: Split)
+extends Task[U] {
+ override def preferredLocations() = rdd.preferredLocations(split)
+ override def markStarted(slot: SlaveOffer) { rdd.taskStarted(split, slot) }
+}
+
+class ForeachTask[T: ClassManifest](
+ rdd: RDD[T], split: Split, func: T => Unit)
+extends RDDTask[Unit, T](rdd, split) with Logging {
+ override def run() {
+ logInfo("Processing " + split)
+ rdd.iterator(split).foreach(func)
+ }
+}
+
+class CollectTask[T](
+ rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
+extends RDDTask[Array[T], T](rdd, split) with Logging {
+ override def run(): Array[T] = {
+ logInfo("Processing " + split)
+ rdd.iterator(split).toArray(m)
+ }
+}
+
+class ReduceTask[T: ClassManifest](
+ rdd: RDD[T], split: Split, f: (T, T) => T)
+extends RDDTask[Option[T], T](rdd, split) with Logging {
+ override def run(): Option[T] = {
+ logInfo("Processing " + split)
+ val iter = rdd.iterator(split)
+ if (iter.hasNext)
+ Some(iter.reduceLeft(f))
+ else
+ None
+ }
+}
+
+class MappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T], f: T => U)
+extends RDD[U](prev.sparkContext) {
+ override def splits = prev.splits
+ override def preferredLocations(split: Split) = prev.preferredLocations(split)
+ override def iterator(split: Split) = prev.iterator(split).map(f)
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
+class FilteredRDD[T: ClassManifest](
+ prev: RDD[T], f: T => Boolean)
+extends RDD[T](prev.sparkContext) {
+ override def splits = prev.splits
+ override def preferredLocations(split: Split) = prev.preferredLocations(split)
+ override def iterator(split: Split) = prev.iterator(split).filter(f)
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T], f: T => Traversable[U])
+extends RDD[U](prev.sparkContext) {
+ override def splits = prev.splits
+ override def preferredLocations(split: Split) = prev.preferredLocations(split)
+ override def iterator(split: Split) =
+ prev.iterator(split).toStream.flatMap(f).iterator
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
+class SplitRDD[T: ClassManifest](prev: RDD[T])
+extends RDD[Array[T]](prev.sparkContext) {
+ override def splits = prev.splits
+ override def preferredLocations(split: Split) = prev.preferredLocations(split)
+ override def iterator(split: Split) = Iterator.fromArray(Array(prev.iterator(split).toArray))
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
+
+@serializable class SeededSplit(val prev: Split, val seed: Int) extends Split {
+ override def getId() =
+ "SeededSplit(" + prev.getId() + ", seed " + seed + ")"
+}
+
+class SampledRDD[T: ClassManifest](
+ prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int)
+extends RDD[T](prev.sparkContext) {
+
+ @transient val splits_ = { val rg = new Random(seed); prev.splits.map(x => new SeededSplit(x, rg.nextInt)) }
+
+ override def splits = splits_.asInstanceOf[Array[Split]]
+
+ override def preferredLocations(split: Split) = prev.preferredLocations(split.asInstanceOf[SeededSplit].prev)
+
+ override def iterator(splitIn: Split) = {
+ val split = splitIn.asInstanceOf[SeededSplit]
+ val rg = new Random(split.seed);
+ // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
+ if (withReplacement) {
+ val oldData = prev.iterator(split.prev).toArray
+ val sampleSize = (oldData.size * frac).ceil.toInt
+ val sampledData = for (i <- 1 to sampleSize) yield oldData(rg.nextInt(oldData.size)) // all of oldData's indices are candidates, even if sampleSize < oldData.size
+ sampledData.iterator
+ }
+ // Sampling without replacement
+ else {
+ prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac))
+ }
+ }
+
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split.asInstanceOf[SeededSplit].prev, slot)
+}
+
+
+class CachedRDD[T](
+ prev: RDD[T])(implicit m: ClassManifest[T])
+extends RDD[T](prev.sparkContext) with Logging {
+ val id = CachedRDD.newId()
+ @transient val cacheLocs = Map[Split, List[String]]()
+
+ override def splits = prev.splits
+
+ override def preferredLocations(split: Split) = {
+ if (cacheLocs.contains(split))
+ cacheLocs(split)
+ else
+ prev.preferredLocations(split)
+ }
+
+ override def iterator(split: Split): Iterator[T] = {
+ val key = id + "::" + split.getId()
+ logInfo("CachedRDD split key is " + key)
+ val cache = CachedRDD.cache
+ val loading = CachedRDD.loading
+ val cachedVal = cache.get(key)
+ if (cachedVal != null) {
+ // Split is in cache, so just return its values
+ return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
+ } else {
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ return Iterator.fromArray(cache.get(key).asInstanceOf[Array[T]])
+ } else {
+ loading.add(key)
+ }
+ }
+ // If we got here, we have to load the split
+ logInfo("Loading and caching " + split)
+ val array = prev.iterator(split).toArray(m)
+ cache.put(key, array)
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ return Iterator.fromArray(array)
+ }
+ }
+
+ override def taskStarted(split: Split, slot: SlaveOffer) {
+ val oldList = cacheLocs.getOrElse(split, Nil)
+ val host = slot.getHost
+ if (!oldList.contains(host))
+ cacheLocs(split) = host :: oldList
+ }
+}
+
+private object CachedRDD {
+ val nextId = new AtomicLong(0) // Generates IDs for cached RDDs (on master)
+ def newId() = nextId.getAndIncrement()
+
+ // Stores map results for various splits locally (on workers)
+ val cache = Cache.newKeySpace()
+
+ // Remembers which splits are currently being loaded (on workers)
+ val loading = new HashSet[String]
+}
+
+@serializable
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+ def iterator() = rdd.iterator(split)
+ def preferredLocations() = rdd.preferredLocations(split)
+ override def getId() = "UnionSplit(" + split.getId() + ")"
+}
+
+@serializable
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
+extends RDD[T](sc) {
+ @transient val splits_ : Array[Split] = {
+ val splits: Seq[Split] =
+ for (rdd <- rdds; split <- rdd.splits)
+ yield new UnionSplit(rdd, split)
+ splits.toArray
+ }
+
+ override def splits = splits_
+
+ override def iterator(s: Split): Iterator[T] =
+ s.asInstanceOf[UnionSplit[T]].iterator()
+
+ override def preferredLocations(s: Split): Seq[String] =
+ s.asInstanceOf[UnionSplit[T]].preferredLocations()
+}
+
+@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split {
+ override def getId() =
+ "CartesianSplit(" + s1.getId() + ", " + s2.getId() + ")"
+}
+
+@serializable
+class CartesianRDD[T: ClassManifest, U:ClassManifest](
+ sc: SparkContext, rdd1: RDD[T], rdd2: RDD[U])
+extends RDD[Pair[T, U]](sc) {
+ @transient val splits_ = {
+    // Create the cross product of the two RDDs' splits
+ rdd2.splits.map(y => rdd1.splits.map(x => new CartesianSplit(x, y))).flatten
+ }
+
+ override def splits = splits_.asInstanceOf[Array[Split]]
+
+ override def preferredLocations(split: Split) = {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
+ }
+
+ override def iterator(split: Split) = {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y)
+ }
+
+ override def taskStarted(split: Split, slot: SlaveOffer) = {
+ val currSplit = split.asInstanceOf[CartesianSplit]
+ rdd1.taskStarted(currSplit.s1, slot)
+ rdd2.taskStarted(currSplit.s2, slot)
+ }
+}
+
+@serializable class PairRDDExtras[K, V](self: RDD[(K, V)]) {
+ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
+ def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
+ for ((k, v) <- m2) {
+ m1.get(k) match {
+ case None => m1(k) = v
+ case Some(w) => m1(k) = func(w, v)
+ }
+ }
+ return m1
+ }
+ self.map(pair => HashMap(pair)).reduce(mergeMaps)
+ }
+
+ def combineByKey[C](createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ numSplits: Int)
+ : RDD[(K, C)] =
+ {
+ val shufClass = Class.forName(System.getProperty(
+ "spark.shuffle.class", "spark.DfsShuffle"))
+ val shuf = shufClass.newInstance().asInstanceOf[Shuffle[K, V, C]]
+ shuf.compute(self, numSplits, createCombiner, mergeValue, mergeCombiners)
+ }
+
+ def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
+ combineByKey[V]((v: V) => v, func, func, numSplits)
+ }
+
+ def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
+ def createCombiner(v: V) = ArrayBuffer(v)
+ def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
+ val bufs = combineByKey[ArrayBuffer[V]](
+ createCombiner _, mergeValue _, mergeCombiners _, numSplits)
+ bufs.asInstanceOf[RDD[(K, Seq[V])]]
+ }
+
+ def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
+ val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
+ val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
+ (vs ++ ws).groupByKey(numSplits).flatMap {
+ case (k, seq) => {
+ val vbuf = new ArrayBuffer[V]
+ val wbuf = new ArrayBuffer[W]
+ seq.foreach(_ match {
+ case Left(v) => vbuf += v
+ case Right(w) => wbuf += w
+ })
+ for (v <- vbuf; w <- wbuf) yield (k, (v, w))
+ }
+ }
+ }
+
+ def combineByKey[C](createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, numCores)
+ }
+
+ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
+ reduceByKey(func, numCores)
+ }
+
+ def groupByKey(): RDD[(K, Seq[V])] = {
+ groupByKey(numCores)
+ }
+
+ def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
+ join(other, numCores)
+ }
+
+ def numCores = self.sparkContext.numCores
+
+ def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
+}
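
The operations above compose in the usual way once rddToPairRDDExtras (defined in the SparkContext companion object later in this commit) brings reduceByKey into scope. A word-count sketch assuming local mode; WordCountSketch and the input strings are hypothetical, and the local-file shuffle is selected because the default DfsShuffle expects a DFS to be configured.

import spark.SparkContext
import spark.SparkContext._

object WordCountSketch {
  def main(args: Array[String]) {
    // Use the local-file shuffle instead of the default DfsShuffle
    System.setProperty("spark.shuffle.class", "spark.LocalFileShuffle")
    val sc = new SparkContext("local[2]", "WordCountSketch")
    val lines = sc.parallelize(Seq("to be or not to be", "to see or not"), 2)
    // flatMap and map come from RDD; reduceByKey comes from PairRDDExtras
    // via the implicit conversion imported above.
    val words = lines.flatMap(line => line.split(" "))
    val counts = words.map(word => (word, 1)).reduceByKey(_ + _)
    counts.collect().foreach { case (w, n) => println(w + ": " + n) }
    sc.stop()
  }
}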
diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala
new file mode 100644
index 0000000000..b9f3128c82
--- /dev/null
+++ b/core/src/main/scala/spark/Scheduler.scala
@@ -0,0 +1,10 @@
+package spark
+
+// Scheduler trait, implemented by both MesosScheduler and LocalScheduler.
+private trait Scheduler {
+ def start()
+ def waitForRegister()
+ def runTasks[T](tasks: Array[Task[T]])(implicit m: ClassManifest[T]): Array[T]
+ def stop()
+ def numCores(): Int
+}
diff --git a/core/src/main/scala/spark/SerializableWritable.scala b/core/src/main/scala/spark/SerializableWritable.scala
new file mode 100644
index 0000000000..ae393d06d3
--- /dev/null
+++ b/core/src/main/scala/spark/SerializableWritable.scala
@@ -0,0 +1,26 @@
+package spark
+
+import java.io._
+
+import org.apache.hadoop.io.ObjectWritable
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred.JobConf
+
+@serializable
+class SerializableWritable[T <: Writable](@transient var t: T) {
+ def value = t
+ override def toString = t.toString
+
+ private def writeObject(out: ObjectOutputStream) {
+ out.defaultWriteObject()
+ new ObjectWritable(t).write(out)
+ }
+
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ val ow = new ObjectWritable()
+ ow.setConf(new JobConf())
+ ow.readFields(in)
+ t = ow.get().asInstanceOf[T]
+ }
+}
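
A Writable such as org.apache.hadoop.io.Text is not java.io.Serializable, but wrapping it as above lets it survive Java serialization by re-encoding it through ObjectWritable. A usage sketch with a hypothetical object name, assuming Hadoop is on the classpath:

import org.apache.hadoop.io.Text
import spark.{SerializableWritable, Utils}

object WritableSketch {
  def main(args: Array[String]) {
    // Text itself would fail Java serialization; the wrapper survives it.
    val wrapped = new SerializableWritable(new Text("hello"))
    val bytes = Utils.serialize(wrapped)
    val copy = Utils.deserialize[SerializableWritable[Text]](bytes)
    println(copy.value)   // prints "hello"
  }
}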
diff --git a/core/src/main/scala/spark/Shuffle.scala b/core/src/main/scala/spark/Shuffle.scala
new file mode 100644
index 0000000000..4c5649b537
--- /dev/null
+++ b/core/src/main/scala/spark/Shuffle.scala
@@ -0,0 +1,15 @@
+package spark
+
+/**
+ * A trait for shuffle implementations. Given an input RDD and combiner functions
+ * for PairRDDExtras.combineByKey(), returns an output RDD.
+ */
+@serializable
+trait Shuffle[K, V, C] {
+ def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)]
+}
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala
new file mode 100644
index 0000000000..09846ccc34
--- /dev/null
+++ b/core/src/main/scala/spark/SimpleJob.scala
@@ -0,0 +1,272 @@
+package spark
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import mesos._
+
+
+/**
+ * A Job that runs a set of tasks with no interdependencies.
+ */
+class SimpleJob[T: ClassManifest](
+ sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
+extends Job(jobId) with Logging
+{
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ // CPUs and memory to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+ val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
+
+ // Maximum times a task is allowed to fail before failing the job
+ val MAX_TASK_FAILURES = 4
+
+ val callingThread = currentThread
+ val numTasks = tasks.length
+ val results = new Array[T](numTasks)
+ val launched = new Array[Boolean](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val tidToIndex = HashMap[Int, Int]()
+
+ var allFinished = false
+ val joinLock = new Object() // Used to wait for all tasks to finish
+
+ var tasksLaunched = 0
+ var tasksFinished = 0
+
+ // Last time when we launched a preferred task (for delay scheduling)
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+ // List of pending tasks for each node. These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+  // tasks that repeatedly fail because whenever a task fails, it is put
+  // back on top of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List containing pending tasks with no locality preferences
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // List containing all pending tasks (also used as a stack, as above)
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Did the job fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Add a task to all the pending-task lists that it should be on.
+ def addPendingTask(index: Int) {
+ val locations = tasks(index).preferredLocations
+ if (locations.size == 0) {
+ pendingTasksWithNoPrefs += index
+ } else {
+ for (host <- locations) {
+ val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+ list += index
+ }
+ }
+ allPendingTasks += index
+ }
+
+ // Mark the job as finished and wake up any threads waiting on it
+ def setAllFinished() {
+ joinLock.synchronized {
+ allFinished = true
+ joinLock.notifyAll()
+ }
+ }
+
+ // Wait until the job finishes and return its results
+ def join(): Array[T] = {
+ joinLock.synchronized {
+ while (!allFinished) {
+ joinLock.wait()
+ }
+ if (failed) {
+ throw new SparkException(causeOfFailure)
+ } else {
+ return results
+ }
+ }
+ }
+
+ // Return the pending tasks list for a given host, or an empty list if
+ // there is no map entry for that host
+ def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Dequeue a pending task from the given list and return its index.
+ // Return None if the list is empty.
+ // This method also cleans up any tasks in the list that have already
+ // been launched, since we want that to happen lazily.
+ def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (!launched(index) && !finished(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ // Dequeue a pending task for a given node and return its index.
+ // If localOnly is set to false, allow non-local tasks as well.
+ def findTask(host: String, localOnly: Boolean): Option[Int] = {
+ val localTask = findTaskFromList(getPendingTasksForHost(host))
+ if (localTask != None) {
+ return localTask
+ }
+ val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
+ if (noPrefTask != None) {
+ return noPrefTask
+ }
+ if (!localOnly) {
+ return findTaskFromList(allPendingTasks) // Look for non-local task
+ } else {
+ return None
+ }
+ }
+
+ // Does a host count as a preferred location for a task? This is true if
+ // either the task has preferred locations and this host is one, or it has
+  // no preferred locations (in which case we still count the launch as preferred).
+ def isPreferredLocation(task: Task[T], host: String): Boolean = {
+ val locs = task.preferredLocations
+ return (locs.contains(host) || locs.isEmpty)
+ }
+
+ // Respond to an offer of a single slave from the scheduler by finding a task
+ def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int)
+ : Option[TaskDescription] = {
+ if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK &&
+ availableMem >= MEM_PER_TASK) {
+ val time = System.currentTimeMillis
+ val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+ val host = offer.getHost
+ findTask(host, localOnly) match {
+ case Some(index) => {
+ // Found a task; do some bookkeeping and return a Mesos task for it
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ val preferred = isPreferredLocation(task, host)
+          val prefStr = if (preferred) "preferred" else "non-preferred"
+ val message =
+ "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
+ jobId, index, taskId, offer.getSlaveId, host, prefStr)
+ logInfo(message)
+ // Do various bookkeeping
+ tidToIndex(taskId) = index
+ task.markStarted(offer)
+ launched(index) = true
+ tasksLaunched += 1
+ if (preferred)
+ lastPreferredLaunchTime = time
+ // Create and return the Mesos task object
+ val params = new JHashMap[String, String]
+ params.put("cpus", CPUS_PER_TASK.toString)
+ params.put("mem", MEM_PER_TASK.toString)
+ val serializedTask = Utils.serialize(task)
+ logDebug("Serialized size: " + serializedTask.size)
+ val taskName = "task %d:%d".format(jobId, index)
+ return Some(new TaskDescription(
+ taskId, offer.getSlaveId, taskName, params, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ def statusUpdate(status: TaskStatus) {
+ status.getState match {
+ case TaskState.TASK_FINISHED =>
+ taskFinished(status)
+ case TaskState.TASK_LOST =>
+ taskLost(status)
+ case TaskState.TASK_FAILED =>
+ taskLost(status)
+ case TaskState.TASK_KILLED =>
+ taskLost(status)
+ case _ =>
+ }
+ }
+
+ def taskFinished(status: TaskStatus) {
+ val tid = status.getTaskId
+ val index = tidToIndex(tid)
+ if (!finished(index)) {
+ tasksFinished += 1
+ logInfo("Finished TID %d (progress: %d/%d)".format(
+ tid, tasksFinished, numTasks))
+ // Deserialize task result
+ val result = Utils.deserialize[TaskResult[T]](status.getData)
+ results(index) = result.value
+ // Update accumulators
+ Accumulators.add(callingThread, result.accumUpdates)
+ // Mark finished and stop if we've finished all the tasks
+ finished(index) = true
+ if (tasksFinished == numTasks)
+ setAllFinished()
+ } else {
+ logInfo("Ignoring task-finished event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def taskLost(status: TaskStatus) {
+ val tid = status.getTaskId
+ val index = tidToIndex(tid)
+ if (!finished(index)) {
+ logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index))
+ launched(index) = false
+ tasksLaunched -= 1
+ // Re-enqueue the task as pending
+ addPendingTask(index)
+ // Mark it as failed
+ if (status.getState == TaskState.TASK_FAILED ||
+ status.getState == TaskState.TASK_LOST) {
+ numFailures(index) += 1
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ logError("Task %d:%d failed more than %d times; aborting job".format(
+ jobId, index, MAX_TASK_FAILURES))
+ abort("Task %d failed more than %d times".format(
+ index, MAX_TASK_FAILURES))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def error(code: Int, message: String) {
+ // Save the error message
+ abort("Mesos error: %s (error code: %d)".format(message, code))
+ }
+
+ def abort(message: String) {
+ joinLock.synchronized {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ // Indicate to any joining thread that we're done
+ setAllFinished()
+ }
+ }
+}
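
The delay-scheduling policy above reduces to a single time comparison: offers are restricted to preferred hosts until LOCALITY_WAIT milliseconds have passed since the last preferred launch. A sketch of that predicate in isolation; DelaySchedulingSketch is a hypothetical helper mirroring the localOnly test in slaveOffer.

object DelaySchedulingSketch {
  // Keep insisting on locality until we have gone localityWait ms
  // without managing a preferred launch.
  def localOnly(now: Long, lastPreferredLaunchTime: Long,
                localityWait: Long): Boolean =
    now - lastPreferredLaunchTime < localityWait

  def main(args: Array[String]) {
    val wait = 3000L                         // default spark.locality.wait
    println(localOnly(2000L, 1000L, wait))   // true: only preferred hosts for now
    println(localOnly(6000L, 1000L, wait))   // false: non-local tasks allowed
  }
}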
diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
new file mode 100644
index 0000000000..12dd19d704
--- /dev/null
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -0,0 +1,160 @@
+package spark
+
+import java.lang.reflect.Field
+import java.lang.reflect.Modifier
+import java.lang.reflect.{Array => JArray}
+import java.util.IdentityHashMap
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * Estimates the sizes of Java objects (number of bytes of memory they occupy),
+ * for use in memory-aware caches.
+ *
+ * Based on the following JavaWorld article:
+ * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
+ */
+object SizeEstimator {
+ private val OBJECT_SIZE = 8 // Minimum size of a java.lang.Object
+ private val POINTER_SIZE = 4 // Size of an object reference
+
+ // Sizes of primitive types
+ private val BYTE_SIZE = 1
+ private val BOOLEAN_SIZE = 1
+ private val CHAR_SIZE = 2
+ private val SHORT_SIZE = 2
+ private val INT_SIZE = 4
+ private val LONG_SIZE = 8
+ private val FLOAT_SIZE = 4
+ private val DOUBLE_SIZE = 8
+
+ // A cache of ClassInfo objects for each class
+ private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo]
+ classInfos.put(classOf[Object], new ClassInfo(OBJECT_SIZE, Nil))
+
+ /**
+ * The state of an ongoing size estimation. Contains a stack of objects
+ * to visit as well as an IdentityHashMap of visited objects, and provides
+ * utility methods for enqueueing new objects to visit.
+ */
+ private class SearchState {
+ val visited = new IdentityHashMap[AnyRef, AnyRef]
+ val stack = new ArrayBuffer[AnyRef]
+ var size = 0L
+
+ def enqueue(obj: AnyRef) {
+ if (obj != null && !visited.containsKey(obj)) {
+ visited.put(obj, null)
+ stack += obj
+ }
+ }
+
+ def isFinished(): Boolean = stack.isEmpty
+
+ def dequeue(): AnyRef = {
+ val elem = stack.last
+ stack.trimEnd(1)
+ return elem
+ }
+ }
+
+ /**
+ * Cached information about each class. We remember two things: the
+ * "shell size" of the class (size of all non-static fields plus the
+ * java.lang.Object size), and any fields that are pointers to objects.
+ */
+ private class ClassInfo(
+ val shellSize: Long,
+ val pointerFields: List[Field]) {}
+
+ def estimate(obj: AnyRef): Long = {
+ val state = new SearchState
+ state.enqueue(obj)
+ while (!state.isFinished) {
+ visitSingleObject(state.dequeue(), state)
+ }
+ return state.size
+ }
+
+ private def visitSingleObject(obj: AnyRef, state: SearchState) {
+ val cls = obj.getClass
+ if (cls.isArray) {
+ visitArray(obj, cls, state)
+ } else {
+ val classInfo = getClassInfo(cls)
+ state.size += classInfo.shellSize
+ for (field <- classInfo.pointerFields) {
+ state.enqueue(field.get(obj))
+ }
+ }
+ }
+
+ private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
+ val length = JArray.getLength(array)
+ val elementClass = cls.getComponentType
+ if (elementClass.isPrimitive) {
+ state.size += length * primitiveSize(elementClass)
+ } else {
+ state.size += length * POINTER_SIZE
+ for (i <- 0 until length) {
+ state.enqueue(JArray.get(array, i))
+ }
+ }
+ }
+
+ private def primitiveSize(cls: Class[_]): Long = {
+ if (cls == classOf[Byte])
+ BYTE_SIZE
+ else if (cls == classOf[Boolean])
+ BOOLEAN_SIZE
+ else if (cls == classOf[Char])
+ CHAR_SIZE
+ else if (cls == classOf[Short])
+ SHORT_SIZE
+ else if (cls == classOf[Int])
+ INT_SIZE
+ else if (cls == classOf[Long])
+ LONG_SIZE
+ else if (cls == classOf[Float])
+ FLOAT_SIZE
+ else if (cls == classOf[Double])
+ DOUBLE_SIZE
+ else throw new IllegalArgumentException(
+ "Non-primitive class " + cls + " passed to primitiveSize()")
+ }
+
+ /**
+ * Get or compute the ClassInfo for a given class.
+ */
+ private def getClassInfo(cls: Class[_]): ClassInfo = {
+ // Check whether we've already cached a ClassInfo for this class
+ val info = classInfos.get(cls)
+ if (info != null) {
+ return info
+ }
+
+ val parent = getClassInfo(cls.getSuperclass)
+ var shellSize = parent.shellSize
+ var pointerFields = parent.pointerFields
+
+ for (field <- cls.getDeclaredFields) {
+ if (!Modifier.isStatic(field.getModifiers)) {
+ val fieldClass = field.getType
+ if (fieldClass.isPrimitive) {
+ shellSize += primitiveSize(fieldClass)
+ } else {
+ field.setAccessible(true) // Enable future get()'s on this field
+ shellSize += POINTER_SIZE
+ pointerFields = field :: pointerFields
+ }
+ }
+ }
+
+ // Create and cache a new ClassInfo
+ val newInfo = new ClassInfo(shellSize, pointerFields)
+ classInfos.put(cls, newInfo)
+ return newInfo
+ }
+}
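
A short usage sketch of the estimator above; because visited objects are tracked in an IdentityHashMap, a shared reference is only charged once. The object name is hypothetical and the printed values depend on the constants above.

object SizeEstimatorSketch {
  def main(args: Array[String]) {
    // A primitive array is charged element size times length:
    // 1000 longs come out as 8000 bytes with the constants above.
    println(spark.SizeEstimator.estimate(new Array[Long](1000)))

    // The same string referenced twice is only visited once, so this is
    // far less than twice the size of the string itself.
    val s = "spark" * 1000
    println(spark.SizeEstimator.estimate(Array(s, s)))
  }
}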
diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala
new file mode 100644
index 0000000000..e84aa57efa
--- /dev/null
+++ b/core/src/main/scala/spark/SoftReferenceCache.scala
@@ -0,0 +1,13 @@
+package spark
+
+import com.google.common.collect.MapMaker
+
+/**
+ * An implementation of Cache that uses soft references.
+ */
+class SoftReferenceCache extends Cache {
+ val map = new MapMaker().softValues().makeMap[Any, Any]()
+
+ override def get(key: Any): Any = map.get(key)
+ override def put(key: Any, value: Any) = map.put(key, value)
+}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
new file mode 100644
index 0000000000..02e80c7756
--- /dev/null
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -0,0 +1,175 @@
+package spark
+
+import java.io._
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.SequenceFileInputFormat
+
+
+class SparkContext(
+ master: String,
+ frameworkName: String,
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil)
+extends Logging {
+ private var scheduler: Scheduler = {
+ // Regular expression used for local[N] master format
+ val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ master match {
+ case "local" =>
+ new LocalScheduler(1)
+ case LOCAL_N_REGEX(threads) =>
+ new LocalScheduler(threads.toInt)
+ case _ =>
+ System.loadLibrary("mesos")
+ new MesosScheduler(this, master, frameworkName)
+ }
+ }
+
+ private val isLocal = scheduler.isInstanceOf[LocalScheduler]
+
+ // Start the scheduler, the cache and the broadcast system
+ scheduler.start()
+ Cache.initialize()
+ Broadcast.initialize(true)
+
+ // Methods for creating RDDs
+
+ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
+ new ParallelArray[T](this, seq, numSlices)
+
+ def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
+ parallelize(seq, numCores)
+
+ def textFile(path: String): RDD[String] =
+ new HadoopTextFile(this, path)
+
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ def hadoopFile[K, V](path: String,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V])
+ : RDD[(K, V)] = {
+ new HadoopFile(this, path, inputFormatClass, keyClass, valueClass)
+ }
+
+ /**
+ * Smarter version of hadoopFile() that uses class manifests to figure out
+ * the classes of keys, values and the InputFormat so that users don't need
+ * to pass them directly.
+ */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+ : RDD[(K, V)] = {
+ hadoopFile(path,
+ fm.erasure.asInstanceOf[Class[F]],
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]])
+ }
+
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types */
+ def sequenceFile[K, V](path: String,
+ keyClass: Class[K],
+ valueClass: Class[V]): RDD[(K, V)] = {
+ val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
+ hadoopFile(path, inputFormatClass, keyClass, valueClass)
+ }
+
+ /**
+ * Smarter version of sequenceFile() that obtains the key and value classes
+ * from ClassManifests instead of requiring the user to pass them directly.
+ */
+ def sequenceFile[K, V](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V]): RDD[(K, V)] = {
+ sequenceFile(path,
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]])
+ }
+
+ /** Build the union of a list of RDDs. */
+ def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+ new UnionRDD(this, rdds)
+
+ // Methods for creating shared variables
+
+ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
+ new Accumulator(initialValue, param)
+
+ // TODO: Keep around a weak hash map of values to Cached versions?
+ def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, isLocal)
+ //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, isLocal)
+
+ // Stop the SparkContext
+ def stop() {
+ scheduler.stop()
+ scheduler = null
+ }
+
+ // Wait for the scheduler to be registered
+ def waitForRegister() {
+ scheduler.waitForRegister()
+ }
+
+ // Get Spark's home location from either a value set through the constructor,
+ // or the spark.home Java property, or the SPARK_HOME environment variable
+  // (in that order of preference). If none of these is set, return None.
+ def getSparkHome(): Option[String] = {
+ if (sparkHome != null)
+ Some(sparkHome)
+ else if (System.getProperty("spark.home") != null)
+ Some(System.getProperty("spark.home"))
+ else if (System.getenv("SPARK_HOME") != null)
+ Some(System.getenv("SPARK_HOME"))
+ else
+ None
+ }
+
+ // Submit an array of tasks (passed as functions) to the scheduler
+ def runTasks[T: ClassManifest](tasks: Array[() => T]): Array[T] = {
+ runTaskObjects(tasks.map(f => new FunctionTask(f)))
+ }
+
+ // Run an array of spark.Task objects
+ private[spark] def runTaskObjects[T: ClassManifest](tasks: Seq[Task[T]])
+ : Array[T] = {
+ logInfo("Running " + tasks.length + " tasks in parallel")
+ val start = System.nanoTime
+ val result = scheduler.runTasks(tasks.toArray)
+ logInfo("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
+ return result
+ }
+
+  // Clean a closure to make it ready to be serialized and sent to tasks
+ // (removes unreferenced variables in $outer's, updates REPL variables)
+ private[spark] def clean[F <: AnyRef](f: F): F = {
+ ClosureCleaner.clean(f)
+ return f
+ }
+
+ // Get the number of cores available to run tasks (as reported by Scheduler)
+ def numCores = scheduler.numCores
+}
+
+
+/**
+ * The SparkContext object contains a number of implicit conversions and
+ * parameters for use with various Spark features.
+ */
+object SparkContext {
+ implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ def addInPlace(t1: Double, t2: Double): Double = t1 + t2
+ def zero(initialValue: Double) = 0.0
+ }
+
+ implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ def addInPlace(t1: Int, t2: Int): Int = t1 + t2
+ def zero(initialValue: Int) = 0
+ }
+
+ // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
+ implicit def rddToPairRDDExtras[K, V](rdd: RDD[(K, V)]) =
+ new PairRDDExtras(rdd)
+}
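
A sketch of the shared-variable side of the API: an Int accumulator created through the implicit IntAccumulatorParam above and updated from tasks with foreach. AccumulatorSketch is hypothetical and assumes local mode.

import spark.SparkContext
import spark.SparkContext._

object AccumulatorSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[4]", "AccumulatorSketch")
    // IntAccumulatorParam is found implicitly via the import above
    val sum = sc.accumulator(0)
    sc.parallelize(1 to 100, 4).foreach(x => sum += x)
    println("Sum: " + sum.value)   // 5050 once every task has reported back
    sc.stop()
  }
}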
diff --git a/core/src/main/scala/spark/SparkException.scala b/core/src/main/scala/spark/SparkException.scala
new file mode 100644
index 0000000000..6f9be1a94f
--- /dev/null
+++ b/core/src/main/scala/spark/SparkException.scala
@@ -0,0 +1,3 @@
+package spark
+
+class SparkException(message: String) extends Exception(message) {}
diff --git a/core/src/main/scala/spark/Split.scala b/core/src/main/scala/spark/Split.scala
new file mode 100644
index 0000000000..116cd16370
--- /dev/null
+++ b/core/src/main/scala/spark/Split.scala
@@ -0,0 +1,13 @@
+package spark
+
+/**
+ * A partition of an RDD.
+ */
+@serializable trait Split {
+ /**
+ * Get a unique ID for this split which can be used, for example, to
+ * set up caches based on it. The ID should stay the same if we serialize
+ * and then deserialize the split.
+ */
+ def getId(): String
+}
diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala
new file mode 100644
index 0000000000..6e94009f6e
--- /dev/null
+++ b/core/src/main/scala/spark/Task.scala
@@ -0,0 +1,16 @@
+package spark
+
+import mesos._
+
+@serializable
+trait Task[T] {
+ def run: T
+ def preferredLocations: Seq[String] = Nil
+ def markStarted(offer: SlaveOffer) {}
+}
+
+@serializable
+class FunctionTask[T](body: () => T)
+extends Task[T] {
+ def run: T = body()
+}
diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala
new file mode 100644
index 0000000000..db33c9ff44
--- /dev/null
+++ b/core/src/main/scala/spark/TaskResult.scala
@@ -0,0 +1,9 @@
+package spark
+
+import scala.collection.mutable.Map
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+@serializable
+private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any])
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
new file mode 100644
index 0000000000..e333dd9c91
--- /dev/null
+++ b/core/src/main/scala/spark/Utils.scala
@@ -0,0 +1,127 @@
+package spark
+
+import java.io._
+import java.net.InetAddress
+import java.util.UUID
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+/**
+ * Various utility methods used by Spark.
+ */
+object Utils {
+ def serialize[T](o: T): Array[Byte] = {
+ val bos = new ByteArrayOutputStream()
+ val oos = new ObjectOutputStream(bos)
+ oos.writeObject(o)
+ oos.close
+ return bos.toByteArray
+ }
+
+ def deserialize[T](bytes: Array[Byte]): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis)
+ return ois.readObject.asInstanceOf[T]
+ }
+
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
+ }
+
+ def isAlpha(c: Char) = {
+ (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
+ }
+
+ def splitWords(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var i = 0
+ while (i < s.length) {
+ var j = i
+ while (j < s.length && isAlpha(s.charAt(j))) {
+ j += 1
+ }
+ if (j > i) {
+ buf += s.substring(i, j);
+ }
+ i = j
+ while (i < s.length && !isAlpha(s.charAt(i))) {
+ i += 1
+ }
+ }
+ return buf
+ }
+
+ // Create a temporary directory inside the given parent directory
+ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File =
+ {
+ var attempts = 0
+ val maxAttempts = 10
+ var dir: File = null
+ while (dir == null) {
+ attempts += 1
+ if (attempts > maxAttempts) {
+ throw new IOException("Failed to create a temp directory " +
+ "after " + maxAttempts + " attempts!")
+ }
+ try {
+ dir = new File(root, "spark-" + UUID.randomUUID.toString)
+ if (dir.exists() || !dir.mkdirs()) {
+ dir = null
+ }
+ } catch { case e: IOException => ; }
+ }
+ return dir
+ }
+
+ // Copy all data from an InputStream to an OutputStream
+ def copyStream(in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false)
+ {
+ val buf = new Array[Byte](8192)
+ var n = 0
+ while (n != -1) {
+ n = in.read(buf)
+ if (n != -1) {
+ out.write(buf, 0, n)
+ }
+ }
+ if (closeStreams) {
+ in.close()
+ out.close()
+ }
+ }
+
+ // Shuffle the elements of a collection into a random order, returning the
+ // result in a new collection. Unlike scala.util.Random.shuffle, this method
+ // uses a local random number generator, avoiding inter-thread contention.
+ def shuffle[T](seq: TraversableOnce[T]): Seq[T] = {
+ val buf = new ArrayBuffer[T]()
+ buf ++= seq
+ val rand = new Random()
+ for (i <- (buf.size - 1) to 1 by -1) {
+      val j = rand.nextInt(i + 1)  // bound inclusive of i keeps the shuffle uniform (Fisher-Yates)
+ val tmp = buf(j)
+ buf(j) = buf(i)
+ buf(i) = tmp
+ }
+ buf
+ }
+
+ /**
+ * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
+ */
+ def localIpAddress(): String = {
+ // Get local IP as an array of four bytes
+ val bytes = InetAddress.getLocalHost().getAddress()
+ // Convert the bytes to ints (keeping in mind that they may be negative)
+ // and join them into a string
+ return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
+ }
+}
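
A few throwaway calls exercising the helpers above (illustration only; the expected values in the comments follow directly from the code in this file):

    val bytes = Utils.serialize(List(1, 2, 3))
    val back  = Utils.deserialize[List[Int]](bytes)          // List(1, 2, 3), via Java serialization
    val words = Utils.splitWords("map, filter & 42 reduce")  // map, filter, reduce -- digits are not alpha
    val dir   = Utils.createTempDir()                        // e.g. <java.io.tmpdir>/spark-<uuid>
    val mixed = Utils.shuffle(1 to 5)                        // the same five elements in a random order
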
diff --git a/core/src/main/scala/spark/WeakReferenceCache.scala b/core/src/main/scala/spark/WeakReferenceCache.scala
new file mode 100644
index 0000000000..ddca065454
--- /dev/null
+++ b/core/src/main/scala/spark/WeakReferenceCache.scala
@@ -0,0 +1,14 @@
+package spark
+
+import com.google.common.collect.MapMaker
+
+/**
+ * An implementation of Cache that uses weak references.
+ */
+class WeakReferenceCache extends Cache {
+ val map = new MapMaker().weakValues().makeMap[Any, Any]()
+
+ override def get(key: Any): Any = map.get(key)
+ override def put(key: Any, value: Any) = map.put(key, value)
+}
+
diff --git a/core/src/main/scala/spark/repl/ExecutorClassLoader.scala b/core/src/main/scala/spark/repl/ExecutorClassLoader.scala
new file mode 100644
index 0000000000..13d81ec1cf
--- /dev/null
+++ b/core/src/main/scala/spark/repl/ExecutorClassLoader.scala
@@ -0,0 +1,108 @@
+package spark.repl
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.net.{URI, URL, URLClassLoader, URLEncoder}
+import java.util.concurrent.{Executors, ExecutorService}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.objectweb.asm._
+import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.Opcodes._
+
+
+/**
+ * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI,
+ * used to load classes defined by the interpreter when the REPL is used
+ */
+class ExecutorClassLoader(classUri: String, parent: ClassLoader)
+extends ClassLoader(parent) {
+ val uri = new URI(classUri)
+ val directory = uri.getPath
+
+ // Hadoop FileSystem object for our URI, if it isn't using HTTP
+ var fileSystem: FileSystem = {
+ if (uri.getScheme() == "http")
+ null
+ else
+ FileSystem.get(uri, new Configuration())
+ }
+
+ override def findClass(name: String): Class[_] = {
+ try {
+ val pathInDirectory = name.replace('.', '/') + ".class"
+ val inputStream = {
+ if (fileSystem != null)
+ fileSystem.open(new Path(directory, pathInDirectory))
+ else
+ new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream()
+ }
+ val bytes = readAndTransformClass(name, inputStream)
+ inputStream.close()
+ return defineClass(name, bytes, 0, bytes.length)
+ } catch {
+ case e: Exception => throw new ClassNotFoundException(name, e)
+ }
+ }
+
+ def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
+ if (name.startsWith("line") && name.endsWith("$iw$")) {
+ // Class seems to be an interpreter "wrapper" object storing a val or var.
+ // Replace its constructor with a dummy one that does not run the
+ // initialization code placed there by the REPL. The val or var will
+ // be initialized later through reflection when it is used in a task.
+ val cr = new ClassReader(in)
+ val cw = new ClassWriter(
+ ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
+ val cleaner = new ConstructorCleaner(name, cw)
+ cr.accept(cleaner, 0)
+ return cw.toByteArray
+ } else {
+ // Pass the class through unmodified
+ val bos = new ByteArrayOutputStream
+ val bytes = new Array[Byte](4096)
+ var done = false
+ while (!done) {
+ val num = in.read(bytes)
+ if (num >= 0)
+ bos.write(bytes, 0, num)
+ else
+ done = true
+ }
+ return bos.toByteArray
+ }
+ }
+
+ /**
+ * URL-encode a string, preserving only slashes
+ */
+ def urlEncode(str: String): String = {
+ str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
+ }
+}
+
+class ConstructorCleaner(className: String, cv: ClassVisitor)
+extends ClassAdapter(cv) {
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ val mv = cv.visitMethod(access, name, desc, sig, exceptions)
+ if (name == "<init>" && (access & ACC_STATIC) == 0) {
+      // This is the constructor, time to clean it; just output new instructions
+      // to mv that call the superclass constructor and return, doing nothing
+      // else (setting the static MODULE$ field is left commented out below).
+ mv.visitCode()
+ mv.visitVarInsn(ALOAD, 0) // load this
+ mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
+ mv.visitVarInsn(ALOAD, 0) // load this
+ //val classType = className.replace('.', '/')
+ //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+ mv.visitInsn(RETURN)
+ mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
+ mv.visitEnd()
+ return null
+ } else {
+ return mv
+ }
+ }
+}
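
A hedged sketch of how the executor side would plug this loader in; the spark.repl.class.uri property matches what SparkInterpreter sets later in this diff, but the surrounding wiring is illustrative rather than quoted from Executor.scala:

    val classUri = System.getProperty("spark.repl.class.uri")
    if (classUri != null) {
      // Anything not served at classUri falls through to the parent loader.
      val loader = new ExecutorClassLoader(classUri, Thread.currentThread.getContextClassLoader)
      Thread.currentThread.setContextClassLoader(loader)
      // REPL-generated classes can now be fetched over HTTP (or a Hadoop
      // FileSystem), and wrapper classes matching the "line...$iw$" pattern get
      // their constructors replaced by ConstructorCleaner before being defined.
    }
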
diff --git a/core/src/main/scala/spark/repl/Main.scala b/core/src/main/scala/spark/repl/Main.scala
new file mode 100644
index 0000000000..f00df5aa58
--- /dev/null
+++ b/core/src/main/scala/spark/repl/Main.scala
@@ -0,0 +1,16 @@
+package spark.repl
+
+import scala.collection.mutable.Set
+
+object Main {
+ private var _interp: SparkInterpreterLoop = null
+
+ def interp = _interp
+
+ private[repl] def interp_=(i: SparkInterpreterLoop) { _interp = i }
+
+ def main(args: Array[String]) {
+ _interp = new SparkInterpreterLoop
+ _interp.main(args)
+ }
+}
diff --git a/core/src/main/scala/spark/repl/SparkCompletion.scala b/core/src/main/scala/spark/repl/SparkCompletion.scala
new file mode 100644
index 0000000000..9fa41736f3
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkCompletion.scala
@@ -0,0 +1,353 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import jline._
+import java.util.{ List => JList }
+import util.returning
+
+object SparkCompletion {
+ def looksLikeInvocation(code: String) = (
+ (code != null)
+ && (code startsWith ".")
+ && !(code == ".")
+ && !(code startsWith "./")
+ && !(code startsWith "..")
+ )
+
+ object Forwarder {
+ def apply(forwardTo: () => Option[CompletionAware]): CompletionAware = new CompletionAware {
+ def completions(verbosity: Int) = forwardTo() map (_ completions verbosity) getOrElse Nil
+ override def follow(s: String) = forwardTo() flatMap (_ follow s)
+ }
+ }
+}
+import SparkCompletion._
+
+// REPL completor - queries supplied interpreter for valid
+// completions based on current contents of buffer.
+class SparkCompletion(val repl: SparkInterpreter) extends SparkCompletionOutput {
+ // verbosity goes up with consecutive tabs
+ private var verbosity: Int = 0
+ def resetVerbosity() = verbosity = 0
+
+ def isCompletionDebug = repl.isCompletionDebug
+ def DBG(msg: => Any) = if (isCompletionDebug) println(msg.toString)
+ def debugging[T](msg: String): T => T = (res: T) => returning[T](res)(x => DBG(msg + x))
+
+ lazy val global: repl.compiler.type = repl.compiler
+ import global._
+ import definitions.{ PredefModule, RootClass, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage }
+
+ // XXX not yet used.
+ lazy val dottedPaths = {
+ def walk(tp: Type): scala.List[Symbol] = {
+ val pkgs = tp.nonPrivateMembers filter (_.isPackage)
+ pkgs ++ (pkgs map (_.tpe) flatMap walk)
+ }
+ walk(RootClass.tpe)
+ }
+
+ def getType(name: String, isModule: Boolean) = {
+ val f = if (isModule) definitions.getModule(_: Name) else definitions.getClass(_: Name)
+ try Some(f(name).tpe)
+ catch { case _: MissingRequirementError => None }
+ }
+
+ def typeOf(name: String) = getType(name, false)
+ def moduleOf(name: String) = getType(name, true)
+
+ trait CompilerCompletion {
+ def tp: Type
+ def effectiveTp = tp match {
+ case MethodType(Nil, resType) => resType
+ case PolyType(Nil, resType) => resType
+ case _ => tp
+ }
+
+    // for some reason Any's members don't show up in subclasses, which
+ // we need so 5.<tab> offers asInstanceOf etc.
+ private def anyMembers = AnyClass.tpe.nonPrivateMembers
+ def anyRefMethodsToShow = List("isInstanceOf", "asInstanceOf", "toString")
+
+ def tos(sym: Symbol) = sym.name.decode.toString
+ def memberNamed(s: String) = members find (x => tos(x) == s)
+ def hasMethod(s: String) = methods exists (x => tos(x) == s)
+
+ // XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the
+ // compiler to crash for reasons not yet known.
+ def members = (effectiveTp.nonPrivateMembers ++ anyMembers) filter (_.isPublic)
+ def methods = members filter (_.isMethod)
+ def packages = members filter (_.isPackage)
+ def aliases = members filter (_.isAliasType)
+
+ def memberNames = members map tos
+ def methodNames = methods map tos
+ def packageNames = packages map tos
+ def aliasNames = aliases map tos
+ }
+
+ object TypeMemberCompletion {
+ def apply(tp: Type): TypeMemberCompletion = {
+ if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp)
+ else new TypeMemberCompletion(tp)
+ }
+ def imported(tp: Type) = new ImportCompletion(tp)
+ }
+
+ class TypeMemberCompletion(val tp: Type) extends CompletionAware with CompilerCompletion {
+ def excludeEndsWith: List[String] = Nil
+ def excludeStartsWith: List[String] = List("<") // <byname>, <repeated>, etc.
+ def excludeNames: List[String] = anyref.methodNames -- anyRefMethodsToShow ++ List("_root_")
+
+ def methodSignatureString(sym: Symbol) = {
+ def asString = new MethodSymbolOutput(sym).methodString()
+
+ if (isCompletionDebug)
+ repl.power.showAtAllPhases(asString)
+
+ atPhase(currentRun.typerPhase)(asString)
+ }
+
+ def exclude(name: String): Boolean = (
+ (name contains "$") ||
+ (excludeNames contains name) ||
+ (excludeEndsWith exists (name endsWith _)) ||
+ (excludeStartsWith exists (name startsWith _))
+ )
+ def filtered(xs: List[String]) = xs filterNot exclude distinct
+
+ def completions(verbosity: Int) =
+ debugging(tp + " completions ==> ")(filtered(memberNames))
+
+ override def follow(s: String): Option[CompletionAware] =
+ debugging(tp + " -> '" + s + "' ==> ")(memberNamed(s) map (x => TypeMemberCompletion(x.tpe)))
+
+ override def alternativesFor(id: String): List[String] =
+ debugging(id + " alternatives ==> ") {
+ val alts = members filter (x => x.isMethod && tos(x) == id) map methodSignatureString
+
+ if (alts.nonEmpty) "" :: alts else Nil
+ }
+
+ override def toString = "TypeMemberCompletion(%s)".format(tp)
+ }
+
+ class PackageCompletion(tp: Type) extends TypeMemberCompletion(tp) {
+ override def excludeNames = anyref.methodNames
+ }
+
+ class LiteralCompletion(lit: Literal) extends TypeMemberCompletion(lit.value.tpe) {
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(memberNames)
+ case _ => memberNames
+ }
+ }
+
+ class ImportCompletion(tp: Type) extends TypeMemberCompletion(tp) {
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(members filterNot (_.isSetter) map tos)
+ case _ => super.completions(verbosity)
+ }
+ }
+
+ // not for completion but for excluding
+ object anyref extends TypeMemberCompletion(AnyRefClass.tpe) { }
+
+ // the unqualified vals/defs/etc visible in the repl
+ object ids extends CompletionAware {
+ override def completions(verbosity: Int) = repl.unqualifiedIds ::: List("classOf")
+ // we try to use the compiler and fall back on reflection if necessary
+ // (which at present is for anything defined in the repl session.)
+ override def follow(id: String) =
+ if (completions(0) contains id) {
+ for (clazz <- repl clazzForIdent id) yield {
+ // XXX The isMemberClass check is a workaround for the crasher described
+ // in the comments of #3431. The issue as described by iulian is:
+ //
+ // Inner classes exist as symbols
+ // inside their enclosing class, but also inside their package, with a mangled
+ // name (A$B). The mangled names should never be loaded, and exist only for the
+ // optimizer, which sometimes cannot get the right symbol, but it doesn't care
+ // and loads the bytecode anyway.
+ //
+ // So this solution is incorrect, but in the short term the simple fix is
+ // to skip the compiler any time completion is requested on a nested class.
+ if (clazz.isMemberClass) new InstanceCompletion(clazz)
+ else (typeOf(clazz.getName) map TypeMemberCompletion.apply) getOrElse new InstanceCompletion(clazz)
+ }
+ }
+ else None
+ }
+
+ // wildcard imports in the repl like "import global._" or "import String._"
+ private def imported = repl.wildcardImportedTypes map TypeMemberCompletion.imported
+
+ // literal Ints, Strings, etc.
+ object literals extends CompletionAware {
+ def simpleParse(code: String): Tree = {
+ val unit = new CompilationUnit(new util.BatchSourceFile("<console>", code))
+ val scanner = new syntaxAnalyzer.UnitParser(unit)
+ val tss = scanner.templateStatSeq(false)._2
+
+ if (tss.size == 1) tss.head else EmptyTree
+ }
+
+ def completions(verbosity: Int) = Nil
+
+ override def follow(id: String) = simpleParse(id) match {
+ case x: Literal => Some(new LiteralCompletion(x))
+ case _ => None
+ }
+ }
+
+ // top level packages
+ object rootClass extends TypeMemberCompletion(RootClass.tpe) { }
+ // members of Predef
+ object predef extends TypeMemberCompletion(PredefModule.tpe) {
+ override def excludeEndsWith = super.excludeEndsWith ++ List("Wrapper", "ArrayOps")
+ override def excludeStartsWith = super.excludeStartsWith ++ List("wrap")
+ override def excludeNames = anyref.methodNames
+
+ override def exclude(name: String) = super.exclude(name) || (
+ (name contains "2")
+ )
+
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => Nil
+ case _ => super.completions(verbosity)
+ }
+ }
+ // members of scala.*
+ object scalalang extends PackageCompletion(ScalaPackage.tpe) {
+ def arityClasses = List("Product", "Tuple", "Function")
+ def skipArity(name: String) = arityClasses exists (x => name != x && (name startsWith x))
+ override def exclude(name: String) = super.exclude(name) || (
+ skipArity(name)
+ )
+
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(packageNames ++ aliasNames)
+ case _ => super.completions(verbosity)
+ }
+ }
+ // members of java.lang.*
+ object javalang extends PackageCompletion(JavaLangPackage.tpe) {
+ override lazy val excludeEndsWith = super.excludeEndsWith ++ List("Exception", "Error")
+ override lazy val excludeStartsWith = super.excludeStartsWith ++ List("CharacterData")
+
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(packageNames)
+ case _ => super.completions(verbosity)
+ }
+ }
+
+ // the list of completion aware objects which should be consulted
+ lazy val topLevelBase: List[CompletionAware] = List(ids, rootClass, predef, scalalang, javalang, literals)
+ def topLevel = topLevelBase ++ imported
+
+ // the first tier of top level objects (doesn't include file completion)
+ def topLevelFor(parsed: Parsed) = topLevel flatMap (_ completionsFor parsed)
+
+ // the most recent result
+ def lastResult = Forwarder(() => ids follow repl.mostRecentVar)
+
+ def lastResultFor(parsed: Parsed) = {
+ /** The logic is a little tortured right now because normally '.' is
+ * ignored as a delimiter, but on .<tab> it needs to be propagated.
+ */
+ val xs = lastResult completionsFor parsed
+ if (parsed.isEmpty) xs map ("." + _) else xs
+ }
+
+ // chasing down results which won't parse
+ def execute(line: String): Option[Any] = {
+ val parsed = Parsed(line)
+ def noDotOrSlash = line forall (ch => ch != '.' && ch != '/')
+
+ if (noDotOrSlash) None // we defer all unqualified ids to the repl.
+ else {
+ (ids executionFor parsed) orElse
+ (rootClass executionFor parsed) orElse
+ (FileCompletion executionFor line)
+ }
+ }
+
+ // generic interface for querying (e.g. interpreter loop, testing)
+ def completions(buf: String): List[String] =
+ topLevelFor(Parsed.dotted(buf + ".", buf.length + 1))
+
+ // jline's entry point
+ lazy val jline: ArgumentCompletor =
+ returning(new ArgumentCompletor(new JLineCompletion, new JLineDelimiter))(_ setStrict false)
+
+ /** This gets a little bit hairy. It's no small feat delegating everything
+ * and also keeping track of exactly where the cursor is and where it's supposed
+ * to end up. The alternatives mechanism is a little hacky: if there is an empty
+ * string in the list of completions, that means we are expanding a unique
+ * completion, so don't update the "last" buffer because it'll be wrong.
+ */
+ class JLineCompletion extends Completor {
+ // For recording the buffer on the last tab hit
+ private var lastBuf: String = ""
+ private var lastCursor: Int = -1
+
+ // Does this represent two consecutive tabs?
+ def isConsecutiveTabs(buf: String, cursor: Int) = cursor == lastCursor && buf == lastBuf
+
+ // Longest common prefix
+ def commonPrefix(xs: List[String]) =
+ if (xs.isEmpty) ""
+ else xs.reduceLeft(_ zip _ takeWhile (x => x._1 == x._2) map (_._1) mkString)
+
+ // This is jline's entry point for completion.
+ override def complete(_buf: String, cursor: Int, candidates: java.util.List[java.lang.String]): Int = {
+ val buf = onull(_buf)
+ verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0
+ DBG("complete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity))
+
+ // we don't try lower priority completions unless higher ones return no results.
+ def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Int] = {
+ completionFunction(p) match {
+ case Nil => None
+ case xs =>
+ // modify in place and return the position
+ xs.foreach(x => candidates.add(x))
+
+ // update the last buffer unless this is an alternatives list
+ if (xs contains "") Some(p.cursor)
+ else {
+ val advance = commonPrefix(xs)
+ lastCursor = p.position + advance.length
+ lastBuf = (buf take p.position) + advance
+
+ DBG("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format(p, lastBuf, lastCursor, p.position))
+ Some(p.position)
+ }
+ }
+ }
+
+ def mkDotted = Parsed.dotted(buf, cursor) withVerbosity verbosity
+ def mkUndelimited = Parsed.undelimited(buf, cursor) withVerbosity verbosity
+
+ // a single dot is special cased to completion on the previous result
+ def lastResultCompletion =
+ if (!looksLikeInvocation(buf)) None
+ else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor)
+
+ def regularCompletion = tryCompletion(mkDotted, topLevelFor)
+ def fileCompletion = tryCompletion(mkUndelimited, FileCompletion completionsFor _.buffer)
+
+ (lastResultCompletion orElse regularCompletion orElse fileCompletion) getOrElse cursor
+ }
+ }
+}
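
The commonPrefix helper inside JLineCompletion is self-contained enough to check on its own; the candidate lists below are made up, and the function is copied verbatim from the class above:

    def commonPrefix(xs: List[String]) =
      if (xs.isEmpty) ""
      else xs.reduceLeft(_ zip _ takeWhile (x => x._1 == x._2) map (_._1) mkString)

    commonPrefix(List("foreach", "forall", "formatted"))  // "for" -- the buffer advances this far
    commonPrefix(List("toString"))                        // "toString" -- a unique completion expands fully
    commonPrefix(List("abc", "xyz"))                      // "" -- no shared prefix, only the list is shown
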
diff --git a/core/src/main/scala/spark/repl/SparkCompletionOutput.scala b/core/src/main/scala/spark/repl/SparkCompletionOutput.scala
new file mode 100644
index 0000000000..5ac46e3412
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkCompletionOutput.scala
@@ -0,0 +1,92 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+/** This has a lot of duplication with other methods in Symbols and Types,
+ * but the repl completion utility is very sensitive to precise output. The best
+ * thing would be to abstract an interface for how such things are printed,
+ * as is also in progress with error messages.
+ */
+trait SparkCompletionOutput {
+ self: SparkCompletion =>
+
+ import global._
+ import definitions.{ NothingClass, AnyClass, isTupleType, isFunctionType, isRepeatedParamType }
+
+ /** Reducing fully qualified noise for some common packages.
+ */
+ val typeTransforms = List(
+ "java.lang." -> "",
+ "scala.collection.immutable." -> "immutable.",
+ "scala.collection.mutable." -> "mutable.",
+ "scala.collection.generic." -> "generic."
+ )
+
+ def quietString(tp: String): String =
+ typeTransforms.foldLeft(tp) {
+ case (str, (prefix, replacement)) =>
+ if (str startsWith prefix) replacement + (str stripPrefix prefix)
+ else str
+ }
+
+ class MethodSymbolOutput(method: Symbol) {
+ val pkg = method.ownerChain find (_.isPackageClass) map (_.fullName) getOrElse ""
+
+ def relativize(str: String): String = quietString(str stripPrefix (pkg + "."))
+ def relativize(tp: Type): String = relativize(tp.normalize.toString)
+ def relativize(sym: Symbol): String = relativize(sym.info)
+
+ def braceList(tparams: List[String]) = if (tparams.isEmpty) "" else (tparams map relativize).mkString("[", ", ", "]")
+ def parenList(params: List[Any]) = params.mkString("(", ", ", ")")
+
+ def methodTypeToString(mt: MethodType) =
+ (mt.paramss map paramsString mkString "") + ": " + relativize(mt.finalResultType)
+
+ def typeToString(tp: Type): String = relativize(
+ tp match {
+ case x if isFunctionType(x) => functionString(x)
+ case x if isTupleType(x) => tupleString(x)
+ case x if isRepeatedParamType(x) => typeToString(x.typeArgs.head) + "*"
+ case mt @ MethodType(_, _) => methodTypeToString(mt)
+ case x => x.toString
+ }
+ )
+
+ def tupleString(tp: Type) = parenList(tp.normalize.typeArgs map relativize)
+ def functionString(tp: Type) = tp.normalize.typeArgs match {
+ case List(t, r) => t + " => " + r
+ case xs => parenList(xs.init) + " => " + xs.last
+ }
+
+ def tparamsString(tparams: List[Symbol]) = braceList(tparams map (_.defString))
+ def paramsString(params: List[Symbol]) = {
+ def paramNameString(sym: Symbol) = if (sym.isSynthetic) "" else sym.nameString + ": "
+ def paramString(sym: Symbol) = paramNameString(sym) + typeToString(sym.info.normalize)
+
+ val isImplicit = params.nonEmpty && params.head.isImplicit
+ val strs = (params map paramString) match {
+ case x :: xs if isImplicit => ("implicit " + x) :: xs
+ case xs => xs
+ }
+ parenList(strs)
+ }
+
+ def methodString() =
+ method.keyString + " " + method.nameString + (method.info.normalize match {
+ case PolyType(Nil, resType) => ": " + typeToString(resType) // nullary method
+ case PolyType(tparams, resType) => tparamsString(tparams) + typeToString(resType)
+ case mt @ MethodType(_, _) => methodTypeToString(mt)
+ case x =>
+ DBG("methodString(): %s / %s".format(x.getClass, x))
+ x.toString
+ })
+ }
+}
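
The quietString transforms above are easy to sanity-check by hand; assuming completion is a SparkCompletion instance (which mixes in this trait), these are the rewrites implied by typeTransforms:

    completion.quietString("java.lang.String")                      // "String"
    completion.quietString("scala.collection.immutable.List[Int]")  // "immutable.List[Int]"
    completion.quietString("scala.collection.mutable.ArrayBuffer")  // "mutable.ArrayBuffer"
    completion.quietString("spark.RDD[Int]")                        // unchanged: no prefix matches
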
diff --git a/core/src/main/scala/spark/repl/SparkInteractiveReader.scala b/core/src/main/scala/spark/repl/SparkInteractiveReader.scala
new file mode 100644
index 0000000000..4f5a0a6fa0
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkInteractiveReader.scala
@@ -0,0 +1,60 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Stepan Koltsov
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import scala.util.control.Exception._
+
+/** Reads lines from an input stream */
+trait SparkInteractiveReader {
+ import SparkInteractiveReader._
+ import java.io.IOException
+
+ protected def readOneLine(prompt: String): String
+ val interactive: Boolean
+
+ def readLine(prompt: String): String = {
+ def handler: Catcher[String] = {
+ case e: IOException if restartSystemCall(e) => readLine(prompt)
+ }
+ catching(handler) { readOneLine(prompt) }
+ }
+
+ // override if history is available
+ def history: Option[History] = None
+ def historyList = history map (_.asList) getOrElse Nil
+
+ // override if completion is available
+ def completion: Option[SparkCompletion] = None
+
+ // hack necessary for OSX jvm suspension because read calls are not restarted after SIGTSTP
+ private def restartSystemCall(e: Exception): Boolean =
+ Properties.isMac && (e.getMessage == msgEINTR)
+}
+
+
+object SparkInteractiveReader {
+ val msgEINTR = "Interrupted system call"
+ private val exes = List(classOf[Exception], classOf[NoClassDefFoundError])
+
+ def createDefault(): SparkInteractiveReader = createDefault(null)
+
+  /** Create an interactive reader. Uses <code>SparkJLineReader</code> if the
+   * jline library is available, but otherwise falls back to <code>SparkSimpleReader</code>.
+ */
+ def createDefault(interpreter: SparkInterpreter): SparkInteractiveReader =
+ try new SparkJLineReader(interpreter)
+ catch {
+ case e @ (_: Exception | _: NoClassDefFoundError) =>
+ // println("Failed to create SparkJLineReader(%s): %s".format(interpreter, e))
+ new SparkSimpleReader
+ }
+}
+
diff --git a/core/src/main/scala/spark/repl/SparkInterpreter.scala b/core/src/main/scala/spark/repl/SparkInterpreter.scala
new file mode 100644
index 0000000000..10ea346658
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkInterpreter.scala
@@ -0,0 +1,1395 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Martin Odersky
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+import Predef.{ println => _, _ }
+import java.io.{ File, IOException, PrintWriter, StringWriter, Writer }
+import File.pathSeparator
+import java.lang.{ Class, ClassLoader }
+import java.net.{ MalformedURLException, URL }
+import java.lang.reflect
+import reflect.InvocationTargetException
+import java.util.UUID
+
+import scala.PartialFunction.{ cond, condOpt }
+import scala.tools.util.PathResolver
+import scala.reflect.Manifest
+import scala.collection.mutable
+import scala.collection.mutable.{ ListBuffer, HashSet, HashMap, ArrayBuffer }
+import scala.collection.immutable.Set
+import scala.tools.nsc.util.ScalaClassLoader
+import ScalaClassLoader.URLClassLoader
+import scala.util.control.Exception.{ Catcher, catching, ultimately, unwrapping }
+
+import io.{ PlainFile, VirtualDirectory }
+import reporters.{ ConsoleReporter, Reporter }
+import symtab.{ Flags, Names }
+import util.{ SourceFile, BatchSourceFile, ScriptSourceFile, ClassPath, Chars, stringFromWriter }
+import scala.reflect.NameTransformer
+import scala.tools.nsc.{ InterpreterResults => IR }
+import interpreter._
+import SparkInterpreter._
+
+import spark.HttpServer
+import spark.Utils
+
+/** <p>
+ * An interpreter for Scala code.
+ * </p>
+ * <p>
+ * The main public entry points are <code>compile()</code>,
+ * <code>interpret()</code>, and <code>bind()</code>.
+ * The <code>compile()</code> method loads a
+ * complete Scala file. The <code>interpret()</code> method executes one
+ * line of Scala code at the request of the user. The <code>bind()</code>
+ * method binds an object to a variable that can then be used by later
+ * interpreted code.
+ * </p>
+ * <p>
+ * The overall approach is based on compiling the requested code and then
+ * using a Java classloader and Java reflection to run the code
+ * and access its results.
+ * </p>
+ * <p>
+ * In more detail, a single compiler instance is used
+ * to accumulate all successfully compiled or interpreted Scala code. To
+ * "interpret" a line of code, the compiler generates a fresh object that
+ * includes the line of code and which has public member(s) to export
+ * all variables defined by that code. To extract the result of an
+ * interpreted line to show the user, a second "result object" is created
+ * which imports the variables exported by the above object and then
+ * exports a single member named "scala_repl_result". To accommodate user expressions
+ * that read from variables or methods defined in previous statements, "import"
+ * statements are used.
+ * </p>
+ * <p>
+ * This interpreter shares the strengths and weaknesses of using the
+ * full compiler-to-Java. The main strength is that interpreted code
+ * behaves exactly as does compiled code, including running at full speed.
+ * The main weakness is that redefining classes and methods is not handled
+ * properly, because rebinding at the Java level is technically difficult.
+ * </p>
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ */
+class SparkInterpreter(val settings: Settings, out: PrintWriter) {
+ repl =>
+
+ def println(x: Any) = {
+ out.println(x)
+ out.flush()
+ }
+
+ /** construct an interpreter that reports to Console */
+ def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
+ def this() = this(new Settings())
+
+ val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
+
+  /** Local directory to save .class files to */
+ val outputDir = {
+ val tmp = System.getProperty("java.io.tmpdir")
+ val rootDir = System.getProperty("spark.repl.classdir", tmp)
+ Utils.createTempDir(rootDir)
+ }
+ if (SPARK_DEBUG_REPL) {
+ println("Output directory: " + outputDir)
+ }
+
+ /** Scala compiler virtual directory for outputDir */
+ //val virtualDirectory = new VirtualDirectory("(memory)", None)
+ val virtualDirectory = new PlainFile(outputDir)
+
+ /** Jetty server that will serve our classes to worker nodes */
+ val classServer = new HttpServer(outputDir)
+
+ // Start the classServer and store its URI in a spark system property
+ // (which will be passed to executors so that they can connect to it)
+ classServer.start()
+ System.setProperty("spark.repl.class.uri", classServer.uri)
+ if (SPARK_DEBUG_REPL) {
+ println("Class server started, URI = " + classServer.uri)
+ }
+
+ /** reporter */
+ object reporter extends ConsoleReporter(settings, null, out) {
+ override def printMessage(msg: String) {
+ out println clean(msg)
+ out.flush()
+ }
+ }
+
+ /** We're going to go to some trouble to initialize the compiler asynchronously.
+ * It's critical that nothing call into it until it's been initialized or we will
+ * run into unrecoverable issues, but the perceived repl startup time goes
+ * through the roof if we wait for it. So we initialize it with a future and
+ * use a lazy val to ensure that any attempt to use the compiler object waits
+ * on the future.
+ */
+ private val _compiler: Global = newCompiler(settings, reporter)
+ private def _initialize(): Boolean = {
+ val source = """
+ |// this is assembled to force the loading of approximately the
+ |// classes which will be loaded on the first expression anyway.
+ |class $repl_$init {
+ | val x = "abc".reverse.length + (5 max 5)
+ | scala.runtime.ScalaRunTime.stringOf(x)
+ |}
+ |""".stripMargin
+
+ try {
+ new _compiler.Run() compileSources List(new BatchSourceFile("<init>", source))
+ if (isReplDebug || settings.debug.value)
+ println("Repl compiler initialized.")
+ true
+ }
+ catch {
+ case MissingRequirementError(msg) => println("""
+ |Failed to initialize compiler: %s not found.
+ |** Note that as of 2.8 scala does not assume use of the java classpath.
+ |** For the old behavior pass -usejavacp to scala, or if using a Settings
+        |** object programmatically, settings.usejavacp.value = true.""".stripMargin.format(msg)
+ )
+ false
+ }
+ }
+
+ // set up initialization future
+ private var _isInitialized: () => Boolean = null
+ def initialize() = synchronized {
+ if (_isInitialized == null)
+ _isInitialized = scala.concurrent.ops future _initialize()
+ }
+
+  /** the public compiler accessor, which goes through the initialization future */
+ lazy val compiler: Global = {
+ initialize()
+
+    // blocks until initialization is complete; false means catastrophic failure
+ if (_isInitialized()) _compiler
+ else null
+ }
+
+ import compiler.{ Traverser, CompilationUnit, Symbol, Name, Type }
+ import compiler.{
+ Tree, TermTree, ValOrDefDef, ValDef, DefDef, Assign, ClassDef,
+ ModuleDef, Ident, Select, TypeDef, Import, MemberDef, DocDef,
+ ImportSelector, EmptyTree, NoType }
+ import compiler.{ nme, newTermName, newTypeName }
+ import nme.{
+ INTERPRETER_VAR_PREFIX, INTERPRETER_SYNTHVAR_PREFIX, INTERPRETER_LINE_PREFIX,
+ INTERPRETER_IMPORT_WRAPPER, INTERPRETER_WRAPPER_SUFFIX, USCOREkw
+ }
+
+ import compiler.definitions
+ import definitions.{ EmptyPackage, getMember }
+
+ /** whether to print out result lines */
+ private[repl] var printResults: Boolean = true
+
+ /** Temporarily be quiet */
+ def beQuietDuring[T](operation: => T): T = {
+ val wasPrinting = printResults
+ ultimately(printResults = wasPrinting) {
+ printResults = false
+ operation
+ }
+ }
+
+ /** whether to bind the lastException variable */
+ private var bindLastException = true
+
+ /** Temporarily stop binding lastException */
+ def withoutBindingLastException[T](operation: => T): T = {
+ val wasBinding = bindLastException
+ ultimately(bindLastException = wasBinding) {
+ bindLastException = false
+ operation
+ }
+ }
+
+ /** interpreter settings */
+ lazy val isettings = new SparkInterpreterSettings(this)
+
+ /** Instantiate a compiler. Subclasses can override this to
+ * change the compiler class used by this interpreter. */
+ protected def newCompiler(settings: Settings, reporter: Reporter) = {
+ settings.outputDirs setSingleOutput virtualDirectory
+ new Global(settings, reporter)
+ }
+
+  /** the compiler's classpath, as URLs */
+ lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs
+
+ /* A single class loader is used for all commands interpreted by this Interpreter.
+ It would also be possible to create a new class loader for each command
+ to interpret. The advantages of the current approach are:
+
+ - Expressions are only evaluated one time. This is especially
+ significant for I/O, e.g. "val x = Console.readLine"
+
+ The main disadvantage is:
+
+ - Objects, classes, and methods cannot be rebound. Instead, definitions
+ shadow the old ones, and old code objects refer to the old
+ definitions.
+ */
+ private var _classLoader: ClassLoader = null
+ def resetClassLoader() = _classLoader = makeClassLoader()
+ def classLoader: ClassLoader = {
+ if (_classLoader == null)
+ resetClassLoader()
+
+ _classLoader
+ }
+ private def makeClassLoader(): ClassLoader = {
+ /*
+ val parent =
+ if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath
+ else new URLClassLoader(compilerClasspath, parentClassLoader)
+
+ new AbstractFileClassLoader(virtualDirectory, parent)
+ */
+ val parent =
+ if (parentClassLoader == null)
+ new java.net.URLClassLoader(compilerClasspath.toArray)
+ else
+ new java.net.URLClassLoader(compilerClasspath.toArray,
+ parentClassLoader)
+ val virtualDirUrl = new URL("file://" + virtualDirectory.path + "/")
+ new java.net.URLClassLoader(Array(virtualDirUrl), parent)
+ }
+
+ private def loadByName(s: String): Class[_] = // (classLoader tryToInitializeClass s).get
+ Class.forName(s, true, classLoader)
+
+ private def methodByName(c: Class[_], name: String): reflect.Method =
+ c.getMethod(name, classOf[Object])
+
+ protected def parentClassLoader: ClassLoader = this.getClass.getClassLoader()
+ def getInterpreterClassLoader() = classLoader
+
+ // Set the current Java "context" class loader to this interpreter's class loader
+ def setContextClassLoader() = Thread.currentThread.setContextClassLoader(classLoader)
+
+ /** the previous requests this interpreter has processed */
+ private val prevRequests = new ArrayBuffer[Request]()
+ private val usedNameMap = new HashMap[Name, Request]()
+ private val boundNameMap = new HashMap[Name, Request]()
+ private def allHandlers = prevRequests.toList flatMap (_.handlers)
+ private def allReqAndHandlers = prevRequests.toList flatMap (req => req.handlers map (req -> _))
+
+ def printAllTypeOf = {
+ prevRequests foreach { req =>
+ req.typeOf foreach { case (k, v) => Console.println(k + " => " + v) }
+ }
+ }
+
+ /** Most recent tree handled which wasn't wholly synthetic. */
+ private def mostRecentlyHandledTree: Option[Tree] = {
+ for {
+ req <- prevRequests.reverse
+ handler <- req.handlers.reverse
+ name <- handler.generatesValue
+ if !isSynthVarName(name)
+ } return Some(handler.member)
+
+ None
+ }
+
+ def recordRequest(req: Request) {
+ def tripart[T](set1: Set[T], set2: Set[T]) = {
+ val intersect = set1 intersect set2
+ List(set1 -- intersect, intersect, set2 -- intersect)
+ }
+
+ prevRequests += req
+ req.usedNames foreach (x => usedNameMap(x) = req)
+ req.boundNames foreach (x => boundNameMap(x) = req)
+
+ // XXX temporarily putting this here because of tricky initialization order issues
+ // so right now it's not bound until after you issue a command.
+ if (prevRequests.size == 1)
+ quietBind("settings", "spark.repl.SparkInterpreterSettings", isettings)
+
+ // println("\n s1 = %s\n s2 = %s\n s3 = %s".format(
+ // tripart(usedNameMap.keysIterator.toSet, boundNameMap.keysIterator.toSet): _*
+ // ))
+ }
+
+ private def keyList[T](x: collection.Map[T, _]): List[T] = x.keys.toList sortBy (_.toString)
+ def allUsedNames = keyList(usedNameMap)
+ def allBoundNames = keyList(boundNameMap)
+ def allSeenTypes = prevRequests.toList flatMap (_.typeOf.values.toList) distinct
+ def allValueGeneratingNames = allHandlers flatMap (_.generatesValue)
+ def allImplicits = partialFlatMap(allHandlers) {
+ case x: MemberHandler if x.definesImplicit => x.boundNames
+ }
+
+ /** Generates names pre0, pre1, etc. via calls to apply method */
+ class NameCreator(pre: String) {
+ private var x = -1
+ var mostRecent: String = null
+
+ def apply(): String = {
+ x += 1
+ val name = pre + x.toString
+ // make sure we don't overwrite their unwisely named res3 etc.
+ mostRecent =
+ if (allBoundNames exists (_.toString == name)) apply()
+ else name
+
+ mostRecent
+ }
+ def reset(): Unit = x = -1
+ def didGenerate(name: String) =
+ (name startsWith pre) && ((name drop pre.length) forall (_.isDigit))
+ }
+
+ /** allocate a fresh line name */
+ private lazy val lineNameCreator = new NameCreator(INTERPRETER_LINE_PREFIX)
+
+ /** allocate a fresh var name */
+ private lazy val varNameCreator = new NameCreator(INTERPRETER_VAR_PREFIX)
+
+ /** allocate a fresh internal variable name */
+ private lazy val synthVarNameCreator = new NameCreator(INTERPRETER_SYNTHVAR_PREFIX)
+
+ /** Check if a name looks like it was generated by varNameCreator */
+ private def isGeneratedVarName(name: String): Boolean = varNameCreator didGenerate name
+ private def isSynthVarName(name: String): Boolean = synthVarNameCreator didGenerate name
+ private def isSynthVarName(name: Name): Boolean = synthVarNameCreator didGenerate name.toString
+
+ def getVarName = varNameCreator()
+ def getSynthVarName = synthVarNameCreator()
+
+ /** Truncate a string if it is longer than isettings.maxPrintString */
+ private def truncPrintString(str: String): String = {
+ val maxpr = isettings.maxPrintString
+ val trailer = "..."
+
+ if (maxpr <= 0 || str.length <= maxpr) str
+ else str.substring(0, maxpr-3) + trailer
+ }
+
+ /** Clean up a string for output */
+ private def clean(str: String) = truncPrintString(
+ if (isettings.unwrapStrings && !SPARK_DEBUG_REPL) stripWrapperGunk(str)
+ else str
+ )
+
+ /** Heuristically strip interpreter wrapper prefixes
+ * from an interpreter output string.
+ * MATEI: Copied from interpreter package object
+ */
+ def stripWrapperGunk(str: String): String = {
+ val wrapregex = """(line[0-9]+\$object[$.])?(\$?VAL.?)*(\$iwC?(.this)?[$.])*"""
+ str.replaceAll(wrapregex, "")
+ }
+
+ /** Indent some code by the width of the scala> prompt.
+ * This way, compiler error messages read better.
+ */
+ private final val spaces = List.fill(7)(" ").mkString
+ def indentCode(code: String) = {
+ /** Heuristic to avoid indenting and thereby corrupting """-strings and XML literals. */
+ val noIndent = (code contains "\n") && (List("\"\"\"", "</", "/>") exists (code contains _))
+ stringFromWriter(str =>
+ for (line <- code.lines) {
+ if (!noIndent)
+ str.print(spaces)
+
+ str.print(line + "\n")
+ str.flush()
+ })
+ }
+ def indentString(s: String) = s split "\n" map (spaces + _ + "\n") mkString
+
+ implicit def name2string(name: Name) = name.toString
+
+ /** Compute imports that allow definitions from previous
+ * requests to be visible in a new request. Returns
+ * three pieces of related code:
+ *
+ * 1. An initial code fragment that should go before
+ * the code of the new request.
+ *
+ * 2. A code fragment that should go after the code
+ * of the new request.
+ *
+   * 3. An access path which can be traversed to access
+   * any bindings inside code wrapped by #1 and #2.
+ *
+ * The argument is a set of Names that need to be imported.
+ *
+ * Limitations: This method is not as precise as it could be.
+ * (1) It does not process wildcard imports to see what exactly
+ * they import.
+ * (2) If it imports any names from a request, it imports all
+ * of them, which is not really necessary.
+ * (3) It imports multiple same-named implicits, but only the
+ * last one imported is actually usable.
+ */
+ private case class ComputedImports(prepend: String, append: String, access: String)
+ private def importsCode(wanted: Set[Name]): ComputedImports = {
+ /** Narrow down the list of requests from which imports
+ * should be taken. Removes requests which cannot contribute
+ * useful imports for the specified set of wanted names.
+ */
+ case class ReqAndHandler(req: Request, handler: MemberHandler) { }
+
+ def reqsToUse: List[ReqAndHandler] = {
+ /** Loop through a list of MemberHandlers and select which ones to keep.
+ * 'wanted' is the set of names that need to be imported.
+ */
+ def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
+ val isWanted = wanted contains _
+ // Single symbol imports might be implicits! See bug #1752. Rather than
+ // try to finesse this, we will mimic all imports for now.
+ def keepHandler(handler: MemberHandler) = handler match {
+ case _: ImportHandler => true
+ case x => x.definesImplicit || (x.boundNames exists isWanted)
+ }
+
+ reqs match {
+ case Nil => Nil
+ case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted)
+ case rh :: rest =>
+ val importedNames = rh.handler match { case x: ImportHandler => x.importedNames ; case _ => Nil }
+ import rh.handler._
+ val newWanted = wanted ++ usedNames -- boundNames -- importedNames
+ rh :: select(rest, newWanted)
+ }
+ }
+
+ /** Flatten the handlers out and pair each with the original request */
+ select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse
+ }
+
+ val code, trailingLines, accessPath = new StringBuffer
+ val currentImps = HashSet[Name]()
+
+ // add code for a new object to hold some imports
+ def addWrapper() {
+ /*
+ val impname = INTERPRETER_IMPORT_WRAPPER
+ code append "object %s {\n".format(impname)
+ trailingLines append "}\n"
+ accessPath append ("." + impname)
+ currentImps.clear
+ */
+ val impname = INTERPRETER_IMPORT_WRAPPER
+ code.append("@serializable class " + impname + "C {\n")
+ trailingLines.append("}\nval " + impname + " = new " + impname + "C;\n")
+ accessPath.append("." + impname)
+ currentImps.clear
+ }
+
+ addWrapper()
+
+ // loop through previous requests, adding imports for each one
+ for (ReqAndHandler(req, handler) <- reqsToUse) {
+ handler match {
+ // If the user entered an import, then just use it; add an import wrapping
+ // level if the import might conflict with some other import
+ case x: ImportHandler =>
+ if (x.importsWildcard || (currentImps exists (x.importedNames contains _)))
+ addWrapper()
+
+ code append (x.member.toString + "\n")
+
+          // give wildcard imports an import wrapper all their own
+ if (x.importsWildcard) addWrapper()
+ else currentImps ++= x.importedNames
+
+ // For other requests, import each bound variable.
+ // import them explicitly instead of with _, so that
+ // ambiguity errors will not be generated. Also, quote
+ // the name of the variable, so that we don't need to
+ // handle quoting keywords separately.
+ case x =>
+ for (imv <- x.boundNames) {
+ // MATEI: Commented this check out because it was messing up for case classes
+ // (trying to import them twice within the same wrapper), but that is more likely
+ // due to a miscomputation of names that makes the code think they're unique.
+ // Need to evaluate whether having so many wrappers is a bad thing.
+ /*if (currentImps contains imv) */ addWrapper()
+
+ code.append("val " + req.objectName + "$VAL = " + req.objectName + ".INSTANCE;\n")
+ code.append("import " + req.objectName + "$VAL" + req.accessPath + ".`" + imv + "`;\n")
+
+ //code append ("import %s\n" format (req fullPath imv))
+ currentImps += imv
+ }
+ }
+ }
+ // add one extra wrapper, to prevent warnings in the common case of
+ // redefining the value bound in the last interpreter request.
+ addWrapper()
+ ComputedImports(code.toString, trailingLines.toString, accessPath.toString)
+ }
+
+ /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
+ private def parse(line: String): Option[List[Tree]] = {
+ var justNeedsMore = false
+ reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) {
+ // simple parse: just parse it, nothing else
+ def simpleParse(code: String): List[Tree] = {
+ reporter.reset
+ val unit = new CompilationUnit(new BatchSourceFile("<console>", code))
+ val scanner = new compiler.syntaxAnalyzer.UnitParser(unit)
+
+ scanner.templateStatSeq(false)._2
+ }
+ val trees = simpleParse(line)
+
+ if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop
+ else if (justNeedsMore) None
+ else Some(trees)
+ }
+ }
+
+ /** Compile an nsc SourceFile. Returns true if there are
+ * no compilation errors, or false otherwise.
+ */
+ def compileSources(sources: SourceFile*): Boolean = {
+ reporter.reset
+ new compiler.Run() compileSources sources.toList
+ !reporter.hasErrors
+ }
+
+ /** Compile a string. Returns true if there are no
+ * compilation errors, or false otherwise.
+ */
+ def compileString(code: String): Boolean =
+ compileSources(new BatchSourceFile("<script>", code))
+
+ def compileAndSaveRun(label: String, code: String) = {
+ if (SPARK_DEBUG_REPL)
+ println(code)
+ if (isReplDebug) {
+ parse(code) match {
+ case Some(trees) => trees foreach (t => DBG(compiler.asCompactString(t)))
+ case _ => DBG("Parse error:\n\n" + code)
+ }
+ }
+ val run = new compiler.Run()
+ run.compileSources(List(new BatchSourceFile(label, code)))
+ run
+ }
+
+ /** Build a request from the user. <code>trees</code> is <code>line</code>
+ * after being parsed.
+ */
+ private def buildRequest(line: String, lineName: String, trees: List[Tree]): Request =
+ new Request(line, lineName, trees)
+
+ private def chooseHandler(member: Tree): MemberHandler = member match {
+ case member: DefDef => new DefHandler(member)
+ case member: ValDef => new ValHandler(member)
+ case member@Assign(Ident(_), _) => new AssignHandler(member)
+ case member: ModuleDef => new ModuleHandler(member)
+ case member: ClassDef => new ClassHandler(member)
+ case member: TypeDef => new TypeAliasHandler(member)
+ case member: Import => new ImportHandler(member)
+ case DocDef(_, documented) => chooseHandler(documented)
+ case member => new GenericHandler(member)
+ }
+
+ private def requestFromLine(line: String, synthetic: Boolean): Either[IR.Result, Request] = {
+ val trees = parse(indentCode(line)) match {
+ case None => return Left(IR.Incomplete)
+ case Some(Nil) => return Left(IR.Error) // parse error or empty input
+ case Some(trees) => trees
+ }
+
+ // use synthetic vars to avoid filling up the resXX slots
+ def varName = if (synthetic) getSynthVarName else getVarName
+
+    // Treat a single bare expression specially. This is necessary because it is hard to
+    // modify code at a textual level and hard to submit an AST to the compiler.
+ if (trees.size == 1) trees.head match {
+ case _:Assign => // we don't want to include assignments
+ case _:TermTree | _:Ident | _:Select => // ... but do want these as valdefs.
+ return requestFromLine("val %s =\n%s".format(varName, line), synthetic)
+ case _ =>
+ }
+
+ // figure out what kind of request
+ Right(buildRequest(line, lineNameCreator(), trees))
+ }
+
+ /** <p>
+ * Interpret one line of input. All feedback, including parse errors
+   * and evaluation results, is printed via the supplied compiler's
+ * reporter. Values defined are available for future interpreted
+ * strings.
+ * </p>
+ * <p>
+   * The return value is whether the line was interpreted successfully,
+ * e.g. that there were no parse errors.
+ * </p>
+ *
+ * @param line ...
+ * @return ...
+ */
+ def interpret(line: String): IR.Result = interpret(line, false)
+ def interpret(line: String, synthetic: Boolean): IR.Result = {
+ def loadAndRunReq(req: Request) = {
+ val (result, succeeded) = req.loadAndRun
+ if (printResults || !succeeded)
+ out print clean(result)
+
+ // book-keeping
+ if (succeeded && !synthetic)
+ recordRequest(req)
+
+ if (succeeded) IR.Success
+ else IR.Error
+ }
+
+ if (compiler == null) IR.Error
+ else requestFromLine(line, synthetic) match {
+ case Left(result) => result
+ case Right(req) =>
+ // null indicates a disallowed statement type; otherwise compile and
+ // fail if false (implying e.g. a type error)
+ if (req == null || !req.compile) IR.Error
+ else loadAndRunReq(req)
+ }
+ }
+
+ /** A name creator used for objects created by <code>bind()</code>. */
+ private lazy val newBinder = new NameCreator("binder")
+
+ /** Bind a specified name to a specified value. The name may
+ * later be used by expressions passed to interpret.
+ *
+ * @param name the variable name to bind
+ * @param boundType the type of the variable, as a string
+ * @param value the object value to bind to it
+ * @return an indication of whether the binding succeeded
+ */
+ def bind(name: String, boundType: String, value: Any): IR.Result = {
+ val binderName = newBinder()
+
+ compileString("""
+ |object %s {
+ | var value: %s = _
+ | def set(x: Any) = value = x.asInstanceOf[%s]
+ |}
+ """.stripMargin.format(binderName, boundType, boundType))
+
+ val binderObject = loadByName(binderName)
+ val setterMethod = methodByName(binderObject, "set")
+
+ setterMethod.invoke(null, value.asInstanceOf[AnyRef])
+ interpret("val %s = %s.value".format(name, binderName))
+ }
+
+ def quietBind(name: String, boundType: String, value: Any): IR.Result =
+ beQuietDuring { bind(name, boundType, value) }
+
+ /** Reset this interpreter, forgetting all user-specified requests. */
+ def reset() {
+ //virtualDirectory.clear
+ virtualDirectory.delete
+ virtualDirectory.create
+ resetClassLoader()
+ lineNameCreator.reset()
+ varNameCreator.reset()
+ prevRequests.clear
+ }
+
+ /** <p>
+ * This instance is no longer needed, so release any resources
+ * it is using. The reporter's output gets flushed.
+ * </p>
+ */
+ def close() {
+ reporter.flush
+ classServer.stop()
+ }
+
+ /** A traverser that finds all mentioned identifiers, i.e. things
+ * that need to be imported. It might return extra names.
+ */
+ private class ImportVarsTraverser extends Traverser {
+ val importVars = new HashSet[Name]()
+
+ override def traverse(ast: Tree) = ast match {
+ // XXX this is obviously inadequate but it's going to require some effort
+ // to get right.
+ case Ident(name) if !(name.toString startsWith "x$") => importVars += name
+ case _ => super.traverse(ast)
+ }
+ }
+
+ /** Class to handle one member among all the members included
+ * in a single interpreter request.
+ */
+ private sealed abstract class MemberHandler(val member: Tree) {
+ lazy val usedNames: List[Name] = {
+ val ivt = new ImportVarsTraverser()
+ ivt traverse member
+ ivt.importVars.toList
+ }
+ def boundNames: List[Name] = Nil
+ val definesImplicit = cond(member) {
+ case tree: MemberDef => tree.mods hasFlag Flags.IMPLICIT
+ }
+ def generatesValue: Option[Name] = None
+
+ def extraCodeToEvaluate(req: Request, code: PrintWriter) { }
+ def resultExtractionCode(req: Request, code: PrintWriter) { }
+
+ override def toString = "%s(used = %s)".format(this.getClass.toString split '.' last, usedNames)
+ }
+
+ private class GenericHandler(member: Tree) extends MemberHandler(member)
+
+ private class ValHandler(member: ValDef) extends MemberHandler(member) {
+ lazy val ValDef(mods, vname, _, _) = member
+ lazy val prettyName = NameTransformer.decode(vname)
+ lazy val isLazy = mods hasFlag Flags.LAZY
+
+ override lazy val boundNames = List(vname)
+ override def generatesValue = Some(vname)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ val isInternal = isGeneratedVarName(vname) && req.typeOfEnc(vname) == "Unit"
+ if (!mods.isPublic || isInternal) return
+
+ lazy val extractor = "scala.runtime.ScalaRunTime.stringOf(%s)".format(req fullPath vname)
+
+ // if this is a lazy val we avoid evaluating it here
+ val resultString = if (isLazy) codegenln(false, "<lazy>") else extractor
+ val codeToPrint =
+ """ + "%s: %s = " + %s""".format(prettyName, string2code(req typeOf vname), resultString)
+
+ code print codeToPrint
+ }
+ }
+
+ private class DefHandler(defDef: DefDef) extends MemberHandler(defDef) {
+ lazy val DefDef(mods, name, _, vparamss, _, _) = defDef
+ override lazy val boundNames = List(name)
+ // true if 0-arity
+ override def generatesValue =
+ if (vparamss.isEmpty || vparamss.head.isEmpty) Some(name)
+ else None
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ if (mods.isPublic) code print codegenln(name, ": ", req.typeOf(name))
+ }
+
+ private class AssignHandler(member: Assign) extends MemberHandler(member) {
+ val lhs = member.lhs.asInstanceOf[Ident] // an unfortunate limitation
+ val helperName = newTermName(synthVarNameCreator())
+ override def generatesValue = Some(helperName)
+
+ override def extraCodeToEvaluate(req: Request, code: PrintWriter) =
+ code println """val %s = %s""".format(helperName, lhs)
+
+ /** Print out lhs instead of the generated varName */
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ val lhsType = string2code(req typeOfEnc helperName)
+ val res = string2code(req fullPath helperName)
+ val codeToPrint = """ + "%s: %s = " + %s + "\n" """.format(lhs, lhsType, res)
+
+ code println codeToPrint
+ }
+ }
+
+ private class ModuleHandler(module: ModuleDef) extends MemberHandler(module) {
+ lazy val ModuleDef(mods, name, _) = module
+ override lazy val boundNames = List(name)
+ override def generatesValue = Some(name)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code println codegenln("defined module ", name)
+ }
+
+ private class ClassHandler(classdef: ClassDef) extends MemberHandler(classdef) {
+ lazy val ClassDef(mods, name, _, _) = classdef
+ override lazy val boundNames =
+ name :: (if (mods hasFlag Flags.CASE) List(name.toTermName) else Nil)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code print codegenln("defined %s %s".format(classdef.keyword, name))
+ }
+
+ private class TypeAliasHandler(typeDef: TypeDef) extends MemberHandler(typeDef) {
+ lazy val TypeDef(mods, name, _, _) = typeDef
+ def isAlias() = mods.isPublic && compiler.treeInfo.isAliasTypeDef(typeDef)
+ override lazy val boundNames = if (isAlias) List(name) else Nil
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code println codegenln("defined type alias ", name)
+ }
+
+ private class ImportHandler(imp: Import) extends MemberHandler(imp) {
+ lazy val Import(expr, selectors) = imp
+ def targetType = stringToCompilerType(expr.toString) match {
+ case NoType => None
+ case x => Some(x)
+ }
+
+ private def selectorWild = selectors filter (_.name == USCOREkw) // wildcard imports, e.g. import foo._
+ private def selectorMasked = selectors filter (_.rename == USCOREkw) // masking imports, e.g. import foo.{ bar => _ }
+ private def selectorNames = selectors map (_.name)
+ private def selectorRenames = selectors map (_.rename) filterNot (_ == null)
+
+ /** Whether this import includes a wildcard import */
+ val importsWildcard = selectorWild.nonEmpty
+
+ /** Complete list of names imported by a wildcard */
+ def wildcardImportedNames: List[Name] = (
+ for (tpe <- targetType ; if importsWildcard) yield
+ tpe.nonPrivateMembers filter (x => x.isMethod && x.isPublic) map (_.name) distinct
+ ).toList.flatten
+
+    /** The individual names imported by this statement.
+     *  XXX come back to this and see what can be done with wildcards now that
+     *  we know how to enumerate the identifiers.
+     */
+ val importedNames: List[Name] =
+ selectorRenames filterNot (_ == USCOREkw) flatMap (x => List(x.toTypeName, x.toTermName))
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code println codegenln(imp.toString)
+ }
+
+ /** One line of code submitted by the user for interpretation */
+ private class Request(val line: String, val lineName: String, val trees: List[Tree]) {
+ /** name to use for the object that will compute "line" */
+ def objectName = lineName + INTERPRETER_WRAPPER_SUFFIX
+
+ /** name of the object that retrieves the result from the above object */
+ def resultObjectName = "RequestResult$" + objectName
+
+ /** handlers for each tree in this request */
+ val handlers: List[MemberHandler] = trees map chooseHandler
+
+ /** all (public) names defined by these statements */
+ val boundNames = handlers flatMap (_.boundNames)
+
+ /** list of names used by this expression */
+ val usedNames: List[Name] = handlers flatMap (_.usedNames)
+
+ /** def and val names */
+ def defNames = partialFlatMap(handlers) { case x: DefHandler => x.boundNames }
+ def valueNames = partialFlatMap(handlers) {
+ case x: AssignHandler => List(x.helperName)
+ case x: ValHandler => boundNames
+ case x: ModuleHandler => List(x.name)
+ }
+
+ /** Code to import bound names from previous lines - accessPath is code to
+ * append to objectName to access anything bound by request.
+ */
+ val ComputedImports(importsPreamble, importsTrailer, accessPath) =
+ importsCode(Set.empty ++ usedNames)
+
+ /** Code to access a variable with the specified name */
+ def fullPath(vname: String): String = "%s.`%s`".format(objectName + ".INSTANCE" + accessPath, vname)
+
+ /** Code to access a variable with the specified name */
+ def fullPath(vname: Name): String = fullPath(vname.toString)
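+    // For illustration only (the exact wrapper names are an implementation detail): for a
+    // request named "line1", fullPath("x") expands to something like line1$object.INSTANCE.$iw.$iw.`x`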
+
+ /** the line of code to compute */
+ def toCompute = line
+
+ /** generate the source code for the object that computes this request */
+ def objectSourceCode: String = stringFromWriter { code =>
+ val preamble = """
+ |@serializable class %s {
+ | %s%s
+ """.stripMargin.format(objectName, importsPreamble, indentCode(toCompute))
+ val postamble = importsTrailer + "\n}"
+
+ code println preamble
+ handlers foreach { _.extraCodeToEvaluate(this, code) }
+ code println postamble
+
+      // also emit a companion object exposing a single INSTANCE of the wrapper class
+ code.println("object " + objectName + " {")
+ code.println(" val INSTANCE = new " + objectName + "();")
+ code.println("}")
+ }
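+    // Rough sketch of the generated wrapper for an input such as `val x = 1` (names illustrative):
+    //   @serializable class line1$object {
+    //     <imports preamble>
+    //     val x = 1
+    //     <imports trailer>
+    //   }
+    //   object line1$object { val INSTANCE = new line1$object(); }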
+
+    /** generate source code for the object that retrieves the result
+     *  from objectSourceCode */
+ def resultObjectSourceCode: String = stringFromWriter { code =>
+ /** We only want to generate this code when the result
+ * is a value which can be referred to as-is.
+ */
+ val valueExtractor = handlers.last.generatesValue match {
+ case Some(vname) if typeOf contains vname =>
+ """
+ |lazy val scala_repl_value = {
+ | scala_repl_result
+ | %s
+ |}""".stripMargin.format(fullPath(vname))
+ case _ => ""
+ }
+
+ // first line evaluates object to make sure constructor is run
+ // initial "" so later code can uniformly be: + etc
+ val preamble = """
+ |object %s {
+ | %s
+ | val scala_repl_result: String = {
+ | %s
+ | (""
+ """.stripMargin.format(resultObjectName, valueExtractor, objectName + ".INSTANCE" + accessPath)
+
+ val postamble = """
+ | )
+ | }
+ |}
+ """.stripMargin
+
+ code println preamble
+ if (printResults) {
+ handlers foreach { _.resultExtractionCode(this, code) }
+ }
+ code println postamble
+ }
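+    // Rough sketch of the corresponding extraction object for the same request (names illustrative):
+    //   object RequestResult$line1$object {
+    //     lazy val scala_repl_value = { scala_repl_result; line1$object.INSTANCE.$iw.$iw.`x` }
+    //     val scala_repl_result: String = {
+    //       line1$object.INSTANCE.$iw.$iw   // forces the wrapper's constructor to run
+    //       ("" + "x: Int = " + scala.runtime.ScalaRunTime.stringOf(line1$object.INSTANCE.$iw.$iw.`x`))
+    //     }
+    //   }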
+
+ // compile the object containing the user's code
+ lazy val objRun = compileAndSaveRun("<console>", objectSourceCode)
+
+ // compile the result-extraction object
+ lazy val extractionObjectRun = compileAndSaveRun("<console>", resultObjectSourceCode)
+
+ lazy val loadedResultObject = loadByName(resultObjectName)
+
+ def extractionValue(): Option[AnyRef] = {
+ // ensure it has run
+ extractionObjectRun
+
+ // load it and retrieve the value
+ try Some(loadedResultObject getMethod "scala_repl_value" invoke loadedResultObject)
+ catch { case _: Exception => None }
+ }
+
+ /** Compile the object file. Returns whether the compilation succeeded.
+ * If all goes well, the "types" map is computed. */
+ def compile(): Boolean = {
+      // error counting is wrong, hence the interpreter may overlook failures - so we reset the reporter
+ reporter.reset
+
+ // compile the main object
+ objRun
+
+ // bail on error
+ if (reporter.hasErrors)
+ return false
+
+ // extract and remember types
+ typeOf
+
+ // compile the result-extraction object
+ extractionObjectRun
+
+ // success
+ !reporter.hasErrors
+ }
+
+ def atNextPhase[T](op: => T): T = compiler.atPhase(objRun.typerPhase.next)(op)
+
+ /** The outermost wrapper object */
+ lazy val outerResObjSym: Symbol = getMember(EmptyPackage, newTermName(objectName).toTypeName)
+
+ /** The innermost object inside the wrapper, found by
+ * following accessPath into the outer one. */
+ lazy val resObjSym =
+ accessPath.split("\\.").foldLeft(outerResObjSym) { (sym, name) =>
+ if (name == "") sym else
+ atNextPhase(sym.info member newTermName(name))
+ }
+
+ /* typeOf lookup with encoding */
+ def typeOfEnc(vname: Name) = typeOf(compiler encode vname)
+
+ /** Types of variables defined by this request. */
+ lazy val typeOf: Map[Name, String] = {
+ def getTypes(names: List[Name], nameMap: Name => Name): Map[Name, String] = {
+ names.foldLeft(Map.empty[Name, String]) { (map, name) =>
+ val rawType = atNextPhase(resObjSym.info.member(name).tpe)
+ // the types are all =>T; remove the =>
+ val cleanedType = rawType match {
+ case compiler.PolyType(Nil, rt) => rt
+ case rawType => rawType
+ }
+
+ map + (name -> atNextPhase(cleanedType.toString))
+ }
+ }
+
+ getTypes(valueNames, nme.getterToLocal(_)) ++ getTypes(defNames, identity)
+ }
+
+ /** load and run the code using reflection */
+ def loadAndRun: (String, Boolean) = {
+ val resultValMethod: reflect.Method = loadedResultObject getMethod "scala_repl_result"
+ // XXX if wrapperExceptions isn't type-annotated we crash scalac
+ val wrapperExceptions: List[Class[_ <: Throwable]] =
+ List(classOf[InvocationTargetException], classOf[ExceptionInInitializerError])
+
+      /** We turn off the binding to accommodate ticket #2817 */
+ def onErr: Catcher[(String, Boolean)] = {
+ case t: Throwable if bindLastException =>
+ withoutBindingLastException {
+ quietBind("lastException", "java.lang.Throwable", t)
+ (stringFromWriter(t.printStackTrace(_)), false)
+ }
+ }
+
+ catching(onErr) {
+ unwrapping(wrapperExceptions: _*) {
+ (resultValMethod.invoke(loadedResultObject).toString, true)
+ }
+ }
+ }
+
+ override def toString = "Request(line=%s, %s trees)".format(line, trees.size)
+ }
+
+ /** A container class for methods to be injected into the repl
+ * in power mode.
+ */
+ object power {
+ lazy val compiler: repl.compiler.type = repl.compiler
+ import compiler.{ phaseNames, atPhase, currentRun }
+
+ def mkContext(code: String = "") = compiler.analyzer.rootContext(mkUnit(code))
+ def mkAlias(name: String, what: String) = interpret("type %s = %s".format(name, what))
+ def mkSourceFile(code: String) = new BatchSourceFile("<console>", code)
+ def mkUnit(code: String) = new CompilationUnit(mkSourceFile(code))
+
+ def mkTree(code: String): Tree = mkTrees(code).headOption getOrElse EmptyTree
+ def mkTrees(code: String): List[Tree] = parse(code) getOrElse Nil
+ def mkTypedTrees(code: String*): List[compiler.Tree] = {
+ class TyperRun extends compiler.Run {
+ override def stopPhase(name: String) = name == "superaccessors"
+ }
+
+ reporter.reset
+ val run = new TyperRun
+ run compileSources (code.toList.zipWithIndex map {
+ case (s, i) => new BatchSourceFile("<console %d>".format(i), s)
+ })
+ run.units.toList map (_.body)
+ }
+ def mkTypedTree(code: String) = mkTypedTrees(code).head
+ def mkType(id: String): compiler.Type = stringToCompilerType(id)
+
+ def dump(): String = (
+ ("Names used: " :: allUsedNames) ++
+ ("\nIdentifiers: " :: unqualifiedIds)
+ ) mkString " "
+
+ lazy val allPhases: List[Phase] = phaseNames map (currentRun phaseNamed _)
+ def atAllPhases[T](op: => T): List[(String, T)] = allPhases map (ph => (ph.name, atPhase(ph)(op)))
+ def showAtAllPhases(op: => Any): Unit =
+ atAllPhases(op.toString) foreach { case (ph, op) => Console.println("%15s -> %s".format(ph, op take 240)) }
+ }
+
+ def unleash(): Unit = beQuietDuring {
+ interpret("import scala.tools.nsc._")
+ repl.bind("repl", "spark.repl.SparkInterpreter", this)
+ interpret("val global: repl.compiler.type = repl.compiler")
+ interpret("val power: repl.power.type = repl.power")
+ // interpret("val replVars = repl.replVars")
+ }
+
+ /** Artificial object demonstrating completion */
+ // lazy val replVars = CompletionAware(
+ // Map[String, CompletionAware](
+ // "ids" -> CompletionAware(() => unqualifiedIds, completionAware _),
+ // "synthVars" -> CompletionAware(() => allBoundNames filter isSynthVarName map (_.toString)),
+ // "types" -> CompletionAware(() => allSeenTypes map (_.toString)),
+ // "implicits" -> CompletionAware(() => allImplicits map (_.toString))
+ // )
+ // )
+
+ /** Returns the name of the most recent interpreter result.
+ * Mostly this exists so you can conveniently invoke methods on
+ * the previous result.
+ */
+ def mostRecentVar: String =
+ if (mostRecentlyHandledTree.isEmpty) ""
+ else mostRecentlyHandledTree.get match {
+ case x: ValOrDefDef => x.name
+ case Assign(Ident(name), _) => name
+ case ModuleDef(_, name, _) => name
+ case _ => onull(varNameCreator.mostRecent)
+ }
+
+ private def requestForName(name: Name): Option[Request] =
+ prevRequests.reverse find (_.boundNames contains name)
+
+ private def requestForIdent(line: String): Option[Request] = requestForName(newTermName(line))
+
+ def stringToCompilerType(id: String): compiler.Type = {
+ // if it's a recognized identifier, the type of that; otherwise treat the
+    // String like a value (e.g. scala.collection.Map).
+ def findType = typeForIdent(id) match {
+ case Some(x) => definitions.getClass(newTermName(x)).tpe
+ case _ => definitions.getModule(newTermName(id)).tpe
+ }
+
+ try findType catch { case _: MissingRequirementError => NoType }
+ }
+
+ def typeForIdent(id: String): Option[String] =
+ requestForIdent(id) flatMap (x => x.typeOf get newTermName(id))
+
+ def methodsOf(name: String) =
+ evalExpr[List[String]](methodsCode(name)) map (x => NameTransformer.decode(getOriginalName(x)))
+
+ def completionAware(name: String) = {
+ // XXX working around "object is not a value" crash, i.e.
+ // import java.util.ArrayList ; ArrayList.<tab>
+ clazzForIdent(name) flatMap (_ => evalExpr[Option[CompletionAware]](asCompletionAwareCode(name)))
+ }
+
+ def extractionValueForIdent(id: String): Option[AnyRef] =
+ requestForIdent(id) flatMap (_.extractionValue)
+
+ /** Executes code looking for a manifest of type T.
+ */
+ def manifestFor[T: Manifest] =
+ evalExpr[Manifest[T]]("""manifest[%s]""".format(manifest[T]))
+
+ /** Executes code looking for an implicit value of type T.
+ */
+ def implicitFor[T: Manifest] = {
+ val s = manifest[T].toString
+ evalExpr[Option[T]]("{ def f(implicit x: %s = null): %s = x ; Option(f) }".format(s, s))
+ // We don't use implicitly so as to fail without failing.
+ // evalExpr[T]("""implicitly[%s]""".format(manifest[T]))
+ }
+ /** Executes code looking for an implicit conversion from the type
+ * of the given identifier to CompletionAware.
+ */
+ def completionAwareImplicit[T](id: String) = {
+ val f1string = "%s => %s".format(typeForIdent(id).get, classOf[CompletionAware].getName)
+ val code = """{
+ | def f(implicit x: (%s) = null): %s = x
+ | val f1 = f
+ | if (f1 == null) None else Some(f1(%s))
+ |}""".stripMargin.format(f1string, f1string, id)
+
+ evalExpr[Option[CompletionAware]](code)
+ }
+
+ def clazzForIdent(id: String): Option[Class[_]] =
+ extractionValueForIdent(id) flatMap (x => Option(x) map (_.getClass))
+
+ private def methodsCode(name: String) =
+ "%s.%s(%s)".format(classOf[ReflectionCompletion].getName, "methodsOf", name)
+
+ private def asCompletionAwareCode(name: String) =
+ "%s.%s(%s)".format(classOf[CompletionAware].getName, "unapply", name)
+
+ private def getOriginalName(name: String): String =
+ nme.originalName(newTermName(name)).toString
+
+ case class InterpreterEvalException(msg: String) extends Exception(msg)
+ def evalError(msg: String) = throw InterpreterEvalException(msg)
+
+ /** The user-facing eval in :power mode wraps an Option.
+ */
+ def eval[T: Manifest](line: String): Option[T] =
+ try Some(evalExpr[T](line))
+ catch { case InterpreterEvalException(msg) => out println indentString(msg) ; None }
+
+ def evalExpr[T: Manifest](line: String): T = {
+ // Nothing means the type could not be inferred.
+ if (manifest[T] eq Manifest.Nothing)
+ evalError("Could not infer type: try 'eval[SomeType](%s)' instead".format(line))
+
+ val lhs = getSynthVarName
+ beQuietDuring { interpret("val " + lhs + " = { " + line + " } ") }
+
+ // TODO - can we meaningfully compare the inferred type T with
+ // the internal compiler Type assigned to lhs?
+ // def assignedType = prevRequests.last.typeOf(newTermName(lhs))
+
+ val req = requestFromLine(lhs, true) match {
+ case Left(result) => evalError(result.toString)
+ case Right(req) => req
+ }
+ if (req == null || !req.compile || req.handlers.size != 1)
+ evalError("Eval error.")
+
+ try req.extractionValue.get.asInstanceOf[T] catch {
+ case e: Exception => evalError(e.getMessage)
+ }
+ }
+
+ def interpretExpr[T: Manifest](code: String): Option[T] = beQuietDuring {
+ interpret(code) match {
+ case IR.Success =>
+ try prevRequests.last.extractionValue map (_.asInstanceOf[T])
+ catch { case e: Exception => out println e ; None }
+ case _ => None
+ }
+ }
+
+ /** Another entry point for tab-completion, ids in scope */
+ private def unqualifiedIdNames() = partialFlatMap(allHandlers) {
+ case x: AssignHandler => List(x.helperName)
+ case x: ValHandler => List(x.vname)
+ case x: ModuleHandler => List(x.name)
+ case x: DefHandler => List(x.name)
+ case x: ImportHandler => x.importedNames
+ } filterNot isSynthVarName
+
+ /** Types which have been wildcard imported, such as:
+ * val x = "abc" ; import x._ // type java.lang.String
+ * import java.lang.String._ // object java.lang.String
+ *
+ * Used by tab completion.
+ *
+ * XXX right now this gets import x._ and import java.lang.String._,
+ * but doesn't figure out import String._. There's a lot of ad hoc
+ * scope twiddling which should be swept away in favor of digging
+ * into the compiler scopes.
+ */
+ def wildcardImportedTypes(): List[Type] = {
+ val xs = allHandlers collect { case x: ImportHandler if x.importsWildcard => x.targetType }
+ xs.flatten.reverse.distinct
+ }
+
+ /** Another entry point for tab-completion, ids in scope */
+ def unqualifiedIds() = (unqualifiedIdNames() map (_.toString)).distinct.sorted
+
+ /** For static/object method completion */
+ def getClassObject(path: String): Option[Class[_]] = //classLoader tryToLoadClass path
+ try {
+ Some(Class.forName(path, true, classLoader))
+ } catch {
+ case e: Exception => None
+ }
+
+ /** Parse the ScalaSig to find type aliases */
+ def aliasForType(path: String) = ByteCode.aliasForType(path)
+
+ // Coming soon
+ // implicit def string2liftedcode(s: String): LiftedCode = new LiftedCode(s)
+ // case class LiftedCode(code: String) {
+ // val lifted: String = {
+ // beQuietDuring { interpret(code) }
+ // eval2[String]("({ " + code + " }).toString")
+ // }
+ // def >> : String = lifted
+ // }
+
+ // debugging
+ def isReplDebug = settings.Yrepldebug.value
+ def isCompletionDebug = settings.Ycompletion.value
+ def DBG(s: String) = if (isReplDebug) out println s else ()
+}
+
+/** Utility methods for the Interpreter. */
+object SparkInterpreter {
+
+ import scala.collection.generic.CanBuildFrom
+ def partialFlatMap[A, B, CC[X] <: Traversable[X]]
+ (coll: CC[A])
+ (pf: PartialFunction[A, CC[B]])
+ (implicit bf: CanBuildFrom[CC[A], B, CC[B]]) =
+ {
+ val b = bf(coll)
+ for (x <- coll collect pf)
+ b ++= x
+
+ b.result
+ }
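+  // For example (illustrative): partialFlatMap(List(1, 2, 3)) { case n if n % 2 == 1 => List(n, n) }
+  // applies the partial function only where it is defined and flattens the results, giving List(1, 1, 3, 3).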
+
+ object DebugParam {
+ implicit def tuple2debugparam[T](x: (String, T))(implicit m: Manifest[T]): DebugParam[T] =
+ DebugParam(x._1, x._2)
+
+ implicit def any2debugparam[T](x: T)(implicit m: Manifest[T]): DebugParam[T] =
+ DebugParam("p" + getCount(), x)
+
+ private var counter = 0
+ def getCount() = { counter += 1; counter }
+ }
+ case class DebugParam[T](name: String, param: T)(implicit m: Manifest[T]) {
+ val manifest = m
+ val typeStr = {
+ val str = manifest.toString
+ // I'm sure there are more to be discovered...
+ val regexp1 = """(.*?)\[(.*)\]""".r
+ val regexp2str = """.*\.type#"""
+ val regexp2 = (regexp2str + """(.*)""").r
+
+ (str.replaceAll("""\n""", "")) match {
+ case regexp1(clazz, typeArgs) => "%s[%s]".format(clazz, typeArgs.replaceAll(regexp2str, ""))
+ case regexp2(clazz) => clazz
+ case _ => str
+ }
+ }
+ }
+ def breakIf(assertion: => Boolean, args: DebugParam[_]*): Unit =
+ if (assertion) break(args.toList)
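+  // A typical (hypothetical) debugging use, binding `xs` into the spawned repl when the assertion holds:
+  //   SparkInterpreter.breakIf(xs.size > 1000, "xs" -> xs)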
+
+ // start a repl, binding supplied args
+ def break(args: List[DebugParam[_]]): Unit = {
+ val intLoop = new SparkInterpreterLoop
+ intLoop.settings = new Settings(Console.println)
+ // XXX come back to the dot handling
+ intLoop.settings.classpath.value = "."
+ intLoop.createInterpreter
+ intLoop.in = SparkInteractiveReader.createDefault(intLoop.interpreter)
+
+ // rebind exit so people don't accidentally call System.exit by way of predef
+ intLoop.interpreter.beQuietDuring {
+ intLoop.interpreter.interpret("""def exit = println("Type :quit to resume program execution.")""")
+ for (p <- args) {
+ intLoop.interpreter.bind(p.name, p.typeStr, p.param)
+ Console println "%s: %s".format(p.name, p.typeStr)
+ }
+ }
+ intLoop.repl()
+ intLoop.closeInterpreter
+ }
+
+ def codegenln(leadingPlus: Boolean, xs: String*): String = codegen(leadingPlus, (xs ++ Array("\n")): _*)
+ def codegenln(xs: String*): String = codegenln(true, xs: _*)
+
+ def codegen(xs: String*): String = codegen(true, xs: _*)
+ def codegen(leadingPlus: Boolean, xs: String*): String = {
+ val front = if (leadingPlus) "+ " else ""
+ front + (xs map string2codeQuoted mkString " + ")
+ }
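+  // e.g. (illustrative) codegen("defined module ", "Foo") produces the code fragment: + "defined module " + "Foo"
+  // codegenln additionally appends an escaped newline literal to the generated concatenation.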
+
+ def string2codeQuoted(str: String) = "\"" + string2code(str) + "\""
+
+ /** Convert a string into code that can recreate the string.
+ * This requires replacing all special characters by escape
+ * codes. It does not add the surrounding " marks. */
+ def string2code(str: String): String = {
+ val res = new StringBuilder
+ for (c <- str) c match {
+ case '"' | '\'' | '\\' => res += '\\' ; res += c
+ case _ if c.isControl => res ++= Chars.char2uescape(c)
+ case _ => res += c
+ }
+ res.toString
+ }
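+  // e.g. string2code("a\"b") returns the four-character string a\"b; control characters become \uXXXX escapes.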
+}
+
diff --git a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
new file mode 100644
index 0000000000..d4974009ce
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
@@ -0,0 +1,659 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+import Predef.{ println => _, _ }
+import java.io.{ BufferedReader, FileReader, PrintWriter }
+import java.io.IOException
+
+import scala.tools.nsc.{ InterpreterResults => IR }
+import scala.annotation.tailrec
+import scala.collection.mutable.ListBuffer
+import scala.concurrent.ops
+import util.{ ClassPath }
+import interpreter._
+import io.{ File, Process }
+
+import spark.SparkContext
+
+// Classes to wrap up interpreter commands and their results
+// You can add new commands by adding entries to val commands
+// inside InterpreterLoop.
+trait InterpreterControl {
+ self: SparkInterpreterLoop =>
+
+ // the default result means "keep running, and don't record that line"
+ val defaultResult = Result(true, None)
+
+ // a single interpreter command
+ sealed abstract class Command extends Function1[List[String], Result] {
+ def name: String
+ def help: String
+ def error(msg: String) = {
+ out.println(":" + name + " " + msg + ".")
+ Result(true, None)
+ }
+ def usage(): String
+ }
+
+ case class NoArgs(name: String, help: String, f: () => Result) extends Command {
+ def usage(): String = ":" + name
+ def apply(args: List[String]) = if (args.isEmpty) f() else error("accepts no arguments")
+ }
+
+ case class LineArg(name: String, help: String, f: (String) => Result) extends Command {
+ def usage(): String = ":" + name + " <line>"
+ def apply(args: List[String]) = f(args mkString " ")
+ }
+
+ case class OneArg(name: String, help: String, f: (String) => Result) extends Command {
+ def usage(): String = ":" + name + " <arg>"
+ def apply(args: List[String]) =
+ if (args.size == 1) f(args.head)
+ else error("requires exactly one argument")
+ }
+
+ case class VarArgs(name: String, help: String, f: (List[String]) => Result) extends Command {
+ def usage(): String = ":" + name + " [arg]"
+ def apply(args: List[String]) = f(args)
+ }
+
+ // the result of a single command
+ case class Result(keepRunning: Boolean, lineToRecord: Option[String])
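+  // A hypothetical new command could be appended to `commands` in the loop, for example:
+  //   NoArgs("version", "show the Scala version", () => { out println Properties.versionString; defaultResult })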
+}
+
+/** The
+ * <a href="http://scala-lang.org/" target="_top">Scala</a>
+ * interactive shell. It provides a read-eval-print loop around
+ * the Interpreter class.
+ * After instantiation, clients should call the <code>main()</code> method.
+ *
+ * <p>If no in0 is specified, then input will come from the console, and
+ * the class will attempt to provide input editing features such as
+ * input history.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ * @version 1.2
+ */
+class SparkInterpreterLoop(
+ in0: Option[BufferedReader], val out: PrintWriter, master: Option[String])
+extends InterpreterControl {
+ def this(in0: BufferedReader, out: PrintWriter, master: String) =
+ this(Some(in0), out, Some(master))
+
+ def this(in0: BufferedReader, out: PrintWriter) =
+ this(Some(in0), out, None)
+
+ def this() = this(None, new PrintWriter(Console.out), None)
+
+ /** The input stream from which commands come, set by main() */
+ var in: SparkInteractiveReader = _
+
+ /** The context class loader at the time this object was created */
+ protected val originalClassLoader = Thread.currentThread.getContextClassLoader
+
+ var settings: Settings = _ // set by main()
+ var interpreter: SparkInterpreter = _ // set by createInterpreter()
+
+ // classpath entries added via :cp
+ var addedClasspath: String = ""
+
+ /** A reverse list of commands to replay if the user requests a :replay */
+ var replayCommandStack: List[String] = Nil
+
+ /** A list of commands to replay if the user requests a :replay */
+ def replayCommands = replayCommandStack.reverse
+
+ /** Record a command for replay should the user request a :replay */
+ def addReplay(cmd: String) = replayCommandStack ::= cmd
+
+ /** Close the interpreter and set the var to <code>null</code>. */
+ def closeInterpreter() {
+ if (interpreter ne null) {
+ interpreter.close
+ interpreter = null
+ Thread.currentThread.setContextClassLoader(originalClassLoader)
+ }
+ }
+
+ /** Create a new interpreter. */
+ def createInterpreter() {
+ if (addedClasspath != "")
+ settings.classpath append addedClasspath
+
+ interpreter = new SparkInterpreter(settings, out) {
+ override protected def parentClassLoader =
+ classOf[SparkInterpreterLoop].getClassLoader
+ }
+ interpreter.setContextClassLoader()
+ // interpreter.quietBind("settings", "spark.repl.SparkInterpreterSettings", interpreter.isettings)
+ }
+
+ /** print a friendly help message */
+ def printHelp() = {
+ out println "All commands can be abbreviated - for example :he instead of :help.\n"
+ val cmds = commands map (x => (x.usage, x.help))
+ val width: Int = cmds map { case (x, _) => x.length } max
+ val formatStr = "%-" + width + "s %s"
+ cmds foreach { case (usage, help) => out println formatStr.format(usage, help) }
+ }
+
+ /** Print a welcome message */
+ def printWelcome() {
+ plushln("""Welcome to
+ ____ __
+ / __/__ ___ _____/ /__
+ _\ \/ _ \/ _ `/ __/ '_/
+ /___/ .__/\_,_/_/ /_/\_\ version 0.0
+ /_/
+""")
+
+ import Properties._
+ val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
+ versionString, javaVmName, javaVersion)
+ plushln(welcomeMsg)
+ }
+
+ /** Show the history */
+ def printHistory(xs: List[String]) {
+ val defaultLines = 20
+
+ if (in.history.isEmpty)
+ return println("No history available.")
+
+ val current = in.history.get.index
+ val count = try xs.head.toInt catch { case _: Exception => defaultLines }
+ val lines = in.historyList takeRight count
+ val offset = current - lines.size + 1
+
+ for ((line, index) <- lines.zipWithIndex)
+ println("%d %s".format(index + offset, line))
+ }
+
+ /** Some print conveniences */
+ def println(x: Any) = out println x
+ def plush(x: Any) = { out print x ; out.flush() }
+ def plushln(x: Any) = { out println x ; out.flush() }
+
+ /** Search the history */
+ def searchHistory(_cmdline: String) {
+ val cmdline = _cmdline.toLowerCase
+
+ if (in.history.isEmpty)
+ return println("No history available.")
+
+ val current = in.history.get.index
+ val offset = current - in.historyList.size + 1
+
+ for ((line, index) <- in.historyList.zipWithIndex ; if line.toLowerCase contains cmdline)
+ println("%d %s".format(index + offset, line))
+ }
+
+ /** Prompt to print when awaiting input */
+ val prompt = Properties.shellPromptString
+
+ // most commands do not want to micromanage the Result, but they might want
+  // to print something to the console, so we accommodate Unit and String returns.
+ object CommandImplicits {
+ implicit def u2ir(x: Unit): Result = defaultResult
+ implicit def s2ir(s: String): Result = {
+ out println s
+ defaultResult
+ }
+ }
+
+ /** Standard commands **/
+ val standardCommands: List[Command] = {
+ import CommandImplicits._
+ List(
+ OneArg("cp", "add an entry (jar or directory) to the classpath", addClasspath),
+ NoArgs("help", "print this help message", printHelp),
+ VarArgs("history", "show the history (optional arg: lines to show)", printHistory),
+ LineArg("h?", "search the history", searchHistory),
+ OneArg("load", "load and interpret a Scala file", load),
+ NoArgs("power", "enable power user mode", power),
+ NoArgs("quit", "exit the interpreter", () => Result(false, None)),
+ NoArgs("replay", "reset execution and replay all previous commands", replay),
+ LineArg("sh", "fork a shell and run a command", runShellCmd),
+ NoArgs("silent", "disable/enable automatic printing of results", verbosity)
+ )
+ }
+
+ /** Power user commands */
+ var powerUserOn = false
+ val powerCommands: List[Command] = {
+ import CommandImplicits._
+ List(
+ OneArg("completions", "generate list of completions for a given String", completions),
+ NoArgs("dump", "displays a view of the interpreter's internal state", () => interpreter.power.dump())
+
+ // VarArgs("tree", "displays ASTs for specified identifiers",
+ // (xs: List[String]) => interpreter dumpTrees xs)
+ // LineArg("meta", "given code which produces scala code, executes the results",
+ // (xs: List[String]) => )
+ )
+ }
+
+ /** Available commands */
+ def commands: List[Command] = standardCommands ::: (if (powerUserOn) powerCommands else Nil)
+
+ def initializeSpark() {
+ interpreter.beQuietDuring {
+ command("""
+ spark.repl.Main.interp.out.println("Registering with Mesos...");
+ spark.repl.Main.interp.out.flush();
+ @transient val sc = spark.repl.Main.interp.createSparkContext();
+ sc.waitForRegister();
+ spark.repl.Main.interp.out.println("Spark context available as sc.");
+ spark.repl.Main.interp.out.flush();
+ """)
+ command("import spark.SparkContext._");
+ }
+ plushln("Type in expressions to have them evaluated.")
+ plushln("Type :help for more information.")
+ }
+
+ def createSparkContext(): SparkContext = {
+ val master = this.master match {
+ case Some(m) => m
+ case None => {
+ val prop = System.getenv("MASTER")
+ if (prop != null) prop else "local"
+ }
+ }
+ new SparkContext(master, "Spark shell")
+ }
+
+ /** The main read-eval-print loop for the interpreter. It calls
+ * <code>command()</code> for each line of input, and stops when
+ * <code>command()</code> returns <code>false</code>.
+ */
+ def repl() {
+ def readOneLine() = {
+ out.flush
+ in readLine prompt
+ }
+ // return false if repl should exit
+ def processLine(line: String): Boolean =
+ if (line eq null) false // assume null means EOF
+ else command(line) match {
+ case Result(false, _) => false
+ case Result(_, Some(finalLine)) => addReplay(finalLine) ; true
+ case _ => true
+ }
+
+ while (processLine(readOneLine)) { }
+ }
+
+ /** interpret all lines from a specified file */
+ def interpretAllFrom(file: File) {
+ val oldIn = in
+ val oldReplay = replayCommandStack
+
+ try file applyReader { reader =>
+ in = new SparkSimpleReader(reader, out, false)
+ plushln("Loading " + file + "...")
+ repl()
+ }
+ finally {
+ in = oldIn
+ replayCommandStack = oldReplay
+ }
+ }
+
+ /** create a new interpreter and replay all commands so far */
+ def replay() {
+ closeInterpreter()
+ createInterpreter()
+ for (cmd <- replayCommands) {
+ plushln("Replaying: " + cmd) // flush because maybe cmd will have its own output
+ command(cmd)
+ out.println
+ }
+ }
+
+ /** fork a shell and run a command */
+ def runShellCmd(line: String) {
+ // we assume if they're using :sh they'd appreciate being able to pipeline
+ interpreter.beQuietDuring {
+ interpreter.interpret("import _root_.scala.tools.nsc.io.Process.Pipe._")
+ }
+ val p = Process(line)
+ // only bind non-empty streams
+ def add(name: String, it: Iterator[String]) =
+ if (it.hasNext) interpreter.bind(name, "scala.List[String]", it.toList)
+
+ List(("stdout", p.stdout), ("stderr", p.stderr)) foreach (add _).tupled
+ }
+
+ def withFile(filename: String)(action: File => Unit) {
+ val f = File(filename)
+
+ if (f.exists) action(f)
+ else out.println("That file does not exist")
+ }
+
+ def load(arg: String) = {
+ var shouldReplay: Option[String] = None
+ withFile(arg)(f => {
+ interpretAllFrom(f)
+ shouldReplay = Some(":load " + arg)
+ })
+ Result(true, shouldReplay)
+ }
+
+ def addClasspath(arg: String): Unit = {
+ val f = File(arg).normalize
+ if (f.exists) {
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+ println("Added '%s'. Your new classpath is:\n%s".format(f.path, totalClasspath))
+ replay()
+ }
+ else out.println("The path '" + f + "' doesn't seem to exist.")
+ }
+
+ def completions(arg: String): Unit = {
+ val comp = in.completion getOrElse { return println("Completion unavailable.") }
+ val xs = comp completions arg
+
+ injectAndName(xs)
+ }
+
+ def power() {
+ val powerUserBanner =
+ """** Power User mode enabled - BEEP BOOP **
+ |** scala.tools.nsc._ has been imported **
+ |** New vals! Try repl, global, power **
+ |** New cmds! :help to discover them **
+ |** New defs! Type power.<tab> to reveal **""".stripMargin
+
+ powerUserOn = true
+ interpreter.unleash()
+ injectOne("history", in.historyList)
+ in.completion foreach (x => injectOne("completion", x))
+ out println powerUserBanner
+ }
+
+ def verbosity() = {
+ val old = interpreter.printResults
+ interpreter.printResults = !old
+ out.println("Switched " + (if (old) "off" else "on") + " result printing.")
+ }
+
+ /** Run one command submitted by the user. Two values are returned:
+ * (1) whether to keep running, (2) the line to record for replay,
+ * if any. */
+ def command(line: String): Result = {
+ def withError(msg: String) = {
+ out println msg
+ Result(true, None)
+ }
+ def ambiguous(cmds: List[Command]) = "Ambiguous: did you mean " + cmds.map(":" + _.name).mkString(" or ") + "?"
+
+ // not a command
+ if (!line.startsWith(":")) {
+ // Notice failure to create compiler
+ if (interpreter.compiler == null) return Result(false, None)
+ else return Result(true, interpretStartingWith(line))
+ }
+
+ val tokens = (line drop 1 split """\s+""").toList
+ if (tokens.isEmpty)
+ return withError(ambiguous(commands))
+
+ val (cmd :: args) = tokens
+
+    // this lets us add commands willy-nilly and only requires enough of the command name to disambiguate
+ commands.filter(_.name startsWith cmd) match {
+ case List(x) => x(args)
+ case Nil => withError("Unknown command. Type :help for help.")
+ case xs => withError(ambiguous(xs))
+ }
+ }
+
+ private val CONTINUATION_STRING = " | "
+ private val PROMPT_STRING = "scala> "
+
+ /** If it looks like they're pasting in a scala interpreter
+ * transcript, remove all the formatting we inserted so we
+ * can make some sense of it.
+ */
+ private var pasteStamp: Long = 0
+
+  /** Returns true if enough time has passed since the last read to treat the paste as finished. */
+ def updatePasteStamp(): Boolean = {
+ /* Enough milliseconds between readLines to call it a day. */
+ val PASTE_FINISH = 1000
+
+ val prevStamp = pasteStamp
+ pasteStamp = System.currentTimeMillis
+
+ (pasteStamp - prevStamp > PASTE_FINISH)
+  }
+ /** TODO - we could look for the usage of resXX variables in the transcript.
+ * Right now backreferences to auto-named variables will break.
+ */
+
+ /** The trailing lines complication was an attempt to work around the introduction
+ * of newlines in e.g. email messages of repl sessions. It doesn't work because
+ * an unlucky newline can always leave you with a syntactically valid first line,
+ * which is executed before the next line is considered. So this doesn't actually
+ * accomplish anything, but I'm leaving it in case I decide to try harder.
+ */
+ case class PasteCommand(cmd: String, trailing: ListBuffer[String] = ListBuffer[String]())
+
+ /** Commands start on lines beginning with "scala>" and each successive
+ * line which begins with the continuation string is appended to that command.
+ * Everything else is discarded. When the end of the transcript is spotted,
+ * all the commands are replayed.
+ */
+ @tailrec private def cleanTranscript(lines: List[String], acc: List[PasteCommand]): List[PasteCommand] = lines match {
+ case Nil => acc.reverse
+ case x :: xs if x startsWith PROMPT_STRING =>
+ val first = x stripPrefix PROMPT_STRING
+ val (xs1, xs2) = xs span (_ startsWith CONTINUATION_STRING)
+ val rest = xs1 map (_ stripPrefix CONTINUATION_STRING)
+ val result = (first :: rest).mkString("", "\n", "\n")
+
+ cleanTranscript(xs2, PasteCommand(result) :: acc)
+
+ case ln :: lns =>
+ val newacc = acc match {
+ case Nil => Nil
+ case PasteCommand(cmd, trailing) :: accrest =>
+ PasteCommand(cmd, trailing :+ ln) :: accrest
+ }
+ cleanTranscript(lns, newacc)
+ }
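+  // Sketch of the cleanup: a pasted fragment such as
+  //   scala> val x = 1 +
+  //        |   2
+  //   x: Int = 3
+  // collapses into PasteCommand("val x = 1 +\n  2\n"); the "x: Int = 3" output line is kept only
+  // as trailing text, to be appended below if the command alone parses as incomplete.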
+
+ /** The timestamp is for safety so it doesn't hang looking for the end
+ * of a transcript. Ad hoc parsing can't be too demanding. You can
+ * also use ctrl-D to start it parsing.
+ */
+ @tailrec private def interpretAsPastedTranscript(lines: List[String]) {
+ val line = in.readLine("")
+ val finished = updatePasteStamp()
+
+ if (line == null || finished || line.trim == PROMPT_STRING.trim) {
+ val xs = cleanTranscript(lines.reverse, Nil)
+ println("Replaying %d commands from interpreter transcript." format xs.size)
+ for (PasteCommand(cmd, trailing) <- xs) {
+ out.flush()
+ def runCode(code: String, extraLines: List[String]) {
+ (interpreter interpret code) match {
+ case IR.Incomplete if extraLines.nonEmpty =>
+ runCode(code + "\n" + extraLines.head, extraLines.tail)
+ case _ => ()
+ }
+ }
+ runCode(cmd, trailing.toList)
+ }
+ }
+ else
+ interpretAsPastedTranscript(line :: lines)
+ }
+
+ /** Interpret expressions starting with the first line.
+ * Read lines until a complete compilation unit is available
+ * or until a syntax error has been seen. If a full unit is
+ * read, go ahead and interpret it. Return the full string
+ * to be recorded for replay, if any.
+ */
+ def interpretStartingWith(code: String): Option[String] = {
+    // signal the completion object that non-completion input has been received
+ in.completion foreach (_.resetVerbosity())
+
+ def reallyInterpret = interpreter.interpret(code) match {
+ case IR.Error => None
+ case IR.Success => Some(code)
+ case IR.Incomplete =>
+ if (in.interactive && code.endsWith("\n\n")) {
+ out.println("You typed two blank lines. Starting a new command.")
+ None
+ }
+ else in.readLine(CONTINUATION_STRING) match {
+ case null =>
+ // we know compilation is going to fail since we're at EOF and the
+ // parser thinks the input is still incomplete, but since this is
+ // a file being read non-interactively we want to fail. So we send
+ // it straight to the compiler for the nice error message.
+ interpreter.compileString(code)
+ None
+
+ case line => interpretStartingWith(code + "\n" + line)
+ }
+ }
+
+ /** Here we place ourselves between the user and the interpreter and examine
+ * the input they are ostensibly submitting. We intervene in several cases:
+ *
+ * 1) If the line starts with "scala> " it is assumed to be an interpreter paste.
+ * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
+ * on the previous result.
+ * 3) If the Completion object's execute returns Some(_), we inject that value
+ * and avoid the interpreter, as it's likely not valid scala code.
+ */
+ if (code == "") None
+ else if (code startsWith PROMPT_STRING) {
+ updatePasteStamp()
+ interpretAsPastedTranscript(List(code))
+ None
+ }
+ else if (Completion.looksLikeInvocation(code) && interpreter.mostRecentVar != "") {
+ interpretStartingWith(interpreter.mostRecentVar + code)
+ }
+ else {
+ val result = for (comp <- in.completion ; res <- comp execute code) yield res
+ result match {
+ case Some(res) => injectAndName(res) ; None // completion took responsibility, so do not parse
+ case _ => reallyInterpret
+ }
+ }
+ }
+
+ // runs :load <file> on any files passed via -i
+ def loadFiles(settings: Settings) = settings match {
+ case settings: GenericRunnerSettings =>
+ for (filename <- settings.loadfiles.value) {
+ val cmd = ":load " + filename
+ command(cmd)
+ addReplay(cmd)
+ out.println()
+ }
+ case _ =>
+ }
+
+ def main(settings: Settings) {
+ this.settings = settings
+ createInterpreter()
+
+ // sets in to some kind of reader depending on environmental cues
+ in = in0 match {
+ case Some(in0) => new SparkSimpleReader(in0, out, true)
+ case None =>
+ // the interpreter is passed as an argument to expose tab completion info
+ if (settings.Xnojline.value || Properties.isEmacsShell) new SparkSimpleReader
+ else if (settings.noCompletion.value) SparkInteractiveReader.createDefault()
+ else SparkInteractiveReader.createDefault(interpreter)
+ }
+
+ loadFiles(settings)
+ try {
+      // if the reporter already has errors, the interpreter is broken on startup; go ahead and exit
+ if (interpreter.reporter.hasErrors) return
+
+ printWelcome()
+
+ // this is about the illusion of snappiness. We call initialize()
+ // which spins off a separate thread, then print the prompt and try
+ // our best to look ready. Ideally the user will spend a
+ // couple seconds saying "wow, it starts so fast!" and by the time
+ // they type a command the compiler is ready to roll.
+ interpreter.initialize()
+ initializeSpark()
+ repl()
+ }
+ finally closeInterpreter()
+ }
+
+ private def objClass(x: Any) = x.asInstanceOf[AnyRef].getClass
+ private def objName(x: Any) = {
+ val clazz = objClass(x)
+ val typeParams = clazz.getTypeParameters
+ val basename = clazz.getName
+ val tpString = if (typeParams.isEmpty) "" else "[%s]".format(typeParams map (_ => "_") mkString ", ")
+
+ basename + tpString
+ }
+
+ // injects one value into the repl; returns pair of name and class
+ def injectOne(name: String, obj: Any): Tuple2[String, String] = {
+ val className = objName(obj)
+ interpreter.quietBind(name, className, obj)
+ (name, className)
+ }
+ def injectAndName(obj: Any): Tuple2[String, String] = {
+ val name = interpreter.getVarName
+ val className = objName(obj)
+ interpreter.bind(name, className, obj)
+ (name, className)
+ }
+
+ // injects list of values into the repl; returns summary string
+ def injectDebug(args: List[Any]): String = {
+ val strs =
+ for ((arg, i) <- args.zipWithIndex) yield {
+ val varName = "p" + (i + 1)
+ val (vname, vtype) = injectOne(varName, arg)
+ vname + ": " + vtype
+ }
+
+ if (strs.size == 0) "Set no variables."
+ else "Variables set:\n" + strs.mkString("\n")
+ }
+
+ /** process command-line arguments and do as they request */
+ def main(args: Array[String]) {
+ def error1(msg: String) = out println ("scala: " + msg)
+ val command = new InterpreterCommand(args.toList, error1)
+ def neededHelp(): String =
+ (if (command.settings.help.value) command.usageMsg + "\n" else "") +
+ (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "")
+
+ // if they asked for no help and command is valid, we call the real main
+ neededHelp() match {
+ case "" => if (command.ok) main(command.settings) // else nothing
+ case help => plush(help)
+ }
+ }
+}
+
diff --git a/core/src/main/scala/spark/repl/SparkInterpreterSettings.scala b/core/src/main/scala/spark/repl/SparkInterpreterSettings.scala
new file mode 100644
index 0000000000..ffa477785b
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkInterpreterSettings.scala
@@ -0,0 +1,112 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+/** Settings for the interpreter
+ *
+ * @version 1.0
+ * @author Lex Spoon, 2007/3/24
+ **/
+class SparkInterpreterSettings(repl: SparkInterpreter) {
+ /** A list of paths where :load should look */
+ var loadPath = List(".")
+
+ /** The maximum length of toString to use when printing the result
+ * of an evaluation. 0 means no maximum. If a printout requires
+ * more than this number of characters, then the printout is
+ * truncated.
+ */
+ var maxPrintString = 800
+
+ /** The maximum number of completion candidates to print for tab
+ * completion without requiring confirmation.
+ */
+ var maxAutoprintCompletion = 250
+
+ /** String unwrapping can be disabled if it is causing issues.
+   * Setting this to false means you will see Strings like "$iw.$iw.".
+ */
+ var unwrapStrings = true
+
+ def deprecation_=(x: Boolean) = {
+ val old = repl.settings.deprecation.value
+ repl.settings.deprecation.value = x
+ if (!old && x) println("Enabled -deprecation output.")
+ else if (old && !x) println("Disabled -deprecation output.")
+ }
+ def deprecation: Boolean = repl.settings.deprecation.value
+
+ def allSettings = Map(
+ "maxPrintString" -> maxPrintString,
+ "maxAutoprintCompletion" -> maxAutoprintCompletion,
+ "unwrapStrings" -> unwrapStrings,
+ "deprecation" -> deprecation
+ )
+
+ private def allSettingsString =
+ allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString
+
+ override def toString = """
+ | SparkInterpreterSettings {
+ | %s
+ | }""".stripMargin.format(allSettingsString)
+}
+
+/** Utilities for the SparkInterpreterSettings class
+ *
+ * @version 1.0
+ * @author Lex Spoon, 2007/5/24
+ */
+object SparkInterpreterSettings {
+ /** Source code for the InterpreterSettings class. This is
+ * used so that the interpreter is sure to have the code
+ * available.
+ *
+ * XXX I'm not seeing why this degree of defensiveness is necessary.
+ * If files are missing the repl's not going to work, it's not as if
+ * we have string source backups for anything else.
+ */
+ val sourceCodeForClass =
+"""
+package scala.tools.nsc
+
+/** Settings for the interpreter
+ *
+ * @version 1.0
+ * @author Lex Spoon, 2007/3/24
+ **/
+class SparkInterpreterSettings(repl: Interpreter) {
+ /** A list of paths where :load should look */
+ var loadPath = List(".")
+
+ /** The maximum length of toString to use when printing the result
+ * of an evaluation. 0 means no maximum. If a printout requires
+ * more than this number of characters, then the printout is
+ * truncated.
+ */
+ var maxPrintString = 2400
+
+ def deprecation_=(x: Boolean) = {
+ val old = repl.settings.deprecation.value
+ repl.settings.deprecation.value = x
+ if (!old && x) println("Enabled -deprecation output.")
+ else if (old && !x) println("Disabled -deprecation output.")
+ }
+ def deprecation: Boolean = repl.settings.deprecation.value
+
+ override def toString =
+ "SparkInterpreterSettings {\n" +
+// " loadPath = " + loadPath + "\n" +
+ " maxPrintString = " + maxPrintString + "\n" +
+ "}"
+}
+
+"""
+
+}
diff --git a/core/src/main/scala/spark/repl/SparkJLineReader.scala b/core/src/main/scala/spark/repl/SparkJLineReader.scala
new file mode 100644
index 0000000000..9d761c06fc
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkJLineReader.scala
@@ -0,0 +1,38 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Stepan Koltsov
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import java.io.File
+import jline.{ ConsoleReader, ArgumentCompletor, History => JHistory }
+
+/** Reads from the console using JLine */
+class SparkJLineReader(interpreter: SparkInterpreter) extends SparkInteractiveReader {
+ def this() = this(null)
+
+ override lazy val history = Some(History(consoleReader))
+ override lazy val completion = Option(interpreter) map (x => new SparkCompletion(x))
+
+ val consoleReader = {
+ val r = new jline.ConsoleReader()
+ r setHistory (History().jhistory)
+ r setBellEnabled false
+ completion foreach { c =>
+ r addCompletor c.jline
+ r setAutoprintThreshhold 250
+ }
+
+ r
+ }
+
+ def readOneLine(prompt: String) = consoleReader readLine prompt
+ val interactive = true
+}
+
diff --git a/core/src/main/scala/spark/repl/SparkSimpleReader.scala b/core/src/main/scala/spark/repl/SparkSimpleReader.scala
new file mode 100644
index 0000000000..2b24c4bf63
--- /dev/null
+++ b/core/src/main/scala/spark/repl/SparkSimpleReader.scala
@@ -0,0 +1,33 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Stepan Koltsov
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import java.io.{ BufferedReader, PrintWriter }
+import io.{ Path, File, Directory }
+
+/** Reads using standard JDK API */
+class SparkSimpleReader(
+ in: BufferedReader,
+ out: PrintWriter,
+ val interactive: Boolean)
+extends SparkInteractiveReader {
+ def this() = this(Console.in, new PrintWriter(Console.out), true)
+ def this(in: File, out: PrintWriter, interactive: Boolean) = this(in.bufferedReader(), out, interactive)
+
+ def close() = in.close()
+ def readOneLine(prompt: String): String = {
+ if (interactive) {
+ out.print(prompt)
+ out.flush()
+ }
+ in.readLine()
+ }
+}