diff options
-rw-r--r-- | core/lib/kryo-1.04-mod/kryo-1.04-mod.jar | bin | 86081 -> 0 bytes | |||
-rw-r--r-- | core/lib/kryo-1.04-mod/minlog-1.2.jar | bin | 2595 -> 0 bytes | |||
-rw-r--r-- | core/lib/kryo-1.04-mod/objenesis-1.2.jar | bin | 36034 -> 0 bytes | |||
-rw-r--r-- | core/lib/kryo-1.04-mod/reflectasm-1.01.jar | bin | 8135 -> 0 bytes | |||
-rw-r--r-- | core/lib/mesos.jar | bin | 36686 -> 126006 bytes | |||
-rw-r--r-- | core/src/main/scala/spark/Executor.scala | 61 | ||||
-rw-r--r-- | core/src/main/scala/spark/HadoopRDD.scala | 2 | ||||
-rw-r--r-- | core/src/main/scala/spark/Job.scala | 6 | ||||
-rw-r--r-- | core/src/main/scala/spark/KryoSerializer.scala | 44 | ||||
-rw-r--r-- | core/src/main/scala/spark/MapOutputTracker.scala | 1 | ||||
-rw-r--r-- | core/src/main/scala/spark/MesosScheduler.scala | 153 | ||||
-rw-r--r-- | core/src/main/scala/spark/ParallelCollection.scala | 2 | ||||
-rw-r--r-- | core/src/main/scala/spark/SimpleJob.scala | 56 | ||||
-rw-r--r-- | core/src/main/scala/spark/Task.scala | 2 | ||||
-rw-r--r-- | core/src/test/scala/spark/KryoSerializerSuite.scala | 135 | ||||
-rw-r--r-- | project/SparkBuild.scala | 4 | ||||
-rwxr-xr-x | run | 3 |
17 files changed, 365 insertions, 104 deletions
diff --git a/core/lib/kryo-1.04-mod/kryo-1.04-mod.jar b/core/lib/kryo-1.04-mod/kryo-1.04-mod.jar Binary files differdeleted file mode 100644 index 815c1c8d94..0000000000 --- a/core/lib/kryo-1.04-mod/kryo-1.04-mod.jar +++ /dev/null diff --git a/core/lib/kryo-1.04-mod/minlog-1.2.jar b/core/lib/kryo-1.04-mod/minlog-1.2.jar Binary files differdeleted file mode 100644 index 2fcada1b7e..0000000000 --- a/core/lib/kryo-1.04-mod/minlog-1.2.jar +++ /dev/null diff --git a/core/lib/kryo-1.04-mod/objenesis-1.2.jar b/core/lib/kryo-1.04-mod/objenesis-1.2.jar Binary files differdeleted file mode 100644 index 45cb641683..0000000000 --- a/core/lib/kryo-1.04-mod/objenesis-1.2.jar +++ /dev/null diff --git a/core/lib/kryo-1.04-mod/reflectasm-1.01.jar b/core/lib/kryo-1.04-mod/reflectasm-1.01.jar Binary files differdeleted file mode 100644 index 09179ca473..0000000000 --- a/core/lib/kryo-1.04-mod/reflectasm-1.01.jar +++ /dev/null diff --git a/core/lib/mesos.jar b/core/lib/mesos.jar Binary files differindex eb01ce8a15..f1fde967c4 100644 --- a/core/lib/mesos.jar +++ b/core/lib/mesos.jar diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala index d4d80845c5..a2af70989c 100644 --- a/core/src/main/scala/spark/Executor.scala +++ b/core/src/main/scala/spark/Executor.scala @@ -7,22 +7,24 @@ import java.util.concurrent._ import scala.actors.remote.RemoteActor import scala.collection.mutable.ArrayBuffer -import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver} -import mesos.{TaskDescription, TaskState, TaskStatus} +import com.google.protobuf.ByteString + +import org.apache.mesos._ +import org.apache.mesos.Protos._ import spark.broadcast._ /** * The Mesos executor for Spark. */ -class Executor extends mesos.Executor with Logging { +class Executor extends org.apache.mesos.Executor with Logging { var classLoader: ClassLoader = null var threadPool: ExecutorService = null var env: SparkEnv = null override def init(d: ExecutorDriver, args: ExecutorArgs) { // Read spark.* system properties from executor arg - val props = Utils.deserialize[Array[(String, String)]](args.getData) + val props = Utils.deserialize[Array[(String, String)]](args.getData.toByteArray) for ((key, value) <- props) System.setProperty(key, value) @@ -44,40 +46,47 @@ class Executor extends mesos.Executor with Logging { 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) } - override def launchTask(d: ExecutorDriver, desc: TaskDescription) { - // Pull taskId and arg out of TaskDescription because it won't be a - // valid pointer after this method call (TODO: fix this in C++/SWIG) - val taskId = desc.getTaskId - val arg = desc.getArg - threadPool.execute(new TaskRunner(taskId, arg, d)) + override def launchTask(d: ExecutorDriver, task: TaskDescription) { + threadPool.execute(new TaskRunner(task, d)) } - class TaskRunner(taskId: Int, arg: Array[Byte], d: ExecutorDriver) + class TaskRunner(desc: TaskDescription, d: ExecutorDriver) extends Runnable { override def run() = { - logInfo("Running task ID " + taskId) + val tid = desc.getTaskId.getValue + logInfo("Running task ID " + tid) + d.sendStatusUpdate(TaskStatus.newBuilder() + .setTaskId(desc.getTaskId) + .setState(TaskState.TASK_RUNNING) + .build()) try { SparkEnv.set(env) Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear - val task = Utils.deserialize[Task[Any]](arg, classLoader) + val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader) for (gen <- task.generation) // Update generation if any is set env.mapOutputTracker.updateGeneration(gen) - val value = task.run(taskId) + val value = task.run(tid.toInt) val accumUpdates = Accumulators.values val result = new TaskResult(value, accumUpdates) - d.sendStatusUpdate(new TaskStatus( - taskId, TaskState.TASK_FINISHED, Utils.serialize(result))) - logInfo("Finished task ID " + taskId) + d.sendStatusUpdate(TaskStatus.newBuilder() + .setTaskId(desc.getTaskId) + .setState(TaskState.TASK_FINISHED) + .setData(ByteString.copyFrom(Utils.serialize(result))) + .build()) + logInfo("Finished task ID " + tid) } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason - d.sendStatusUpdate(new TaskStatus( - taskId, TaskState.TASK_FAILED, Utils.serialize(reason))) + d.sendStatusUpdate(TaskStatus.newBuilder() + .setTaskId(desc.getTaskId) + .setState(TaskState.TASK_FAILED) + .setData(ByteString.copyFrom(Utils.serialize(reason))) + .build()) } case t: Throwable => { // TODO: Handle errors in tasks less dramatically - logError("Exception in task ID " + taskId, t) + logError("Exception in task ID " + tid, t) System.exit(1) } } @@ -131,6 +140,18 @@ class Executor extends mesos.Executor with Logging { val out = new FileOutputStream(localPath) Utils.copyStream(in, out, true) } + + override def error(d: ExecutorDriver, code: Int, message: String) { + logError("Error from Mesos: %s (code %d)".format(message, code)) + } + + override def killTask(d: ExecutorDriver, t: TaskID) { + logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") + } + + override def shutdown(d: ExecutorDriver) {} + + override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {} } /** diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/HadoopRDD.scala index 5d8a2d0e35..c87fa844c3 100644 --- a/core/src/main/scala/spark/HadoopRDD.scala +++ b/core/src/main/scala/spark/HadoopRDD.scala @@ -1,7 +1,5 @@ package spark -import mesos.SlaveOffer - import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala index 6abbcbce51..acff8ce561 100644 --- a/core/src/main/scala/spark/Job.scala +++ b/core/src/main/scala/spark/Job.scala @@ -1,14 +1,14 @@ package spark -import mesos._ +import org.apache.mesos._ +import org.apache.mesos.Protos._ /** * Class representing a parallel job in MesosScheduler. Schedules the * job by implementing various callbacks. */ abstract class Job(jobId: Int) { - def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int) - : Option[TaskDescription] + def slaveOffer(s: SlaveOffer, availableCpus: Double): Option[TaskDescription] def statusUpdate(t: TaskStatus): Unit diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 6528297350..bbd5f807a3 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -9,6 +9,7 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} +import de.javakaffee.kryoserializers.KryoReflectionFactorySupport /** * Zig-zag encoder used to write object sizes to serialization streams. @@ -126,19 +127,18 @@ class KryoSerializer extends Serializer with Logging { } def createKryo(): Kryo = { - val kryo = new Kryo() + val kryo = new KryoReflectionFactorySupport() // Register some commonly used classes val toRegister: Seq[AnyRef] = Seq( // Arrays Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")), - Array(new java.lang.Object), Array(1.toByte), + Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'), // Specialized Tuple2s ("", ""), (1, 1), (1.0, 1.0), (1L, 1L), (1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1), // Scala collections - List(1), immutable.Map(1 -> 1), immutable.HashMap(1 -> 1), - mutable.Map(1 -> 1), mutable.HashMap(1 -> 1), mutable.ArrayBuffer(1), + List(1), mutable.ArrayBuffer(1), // Options and Either Some(1), Left(1), Right(1), // Higher-dimensional tuples @@ -151,15 +151,37 @@ class KryoSerializer extends Serializer with Logging { // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. - kryo.register(None.getClass, new KSerializer { + class SingletonSerializer(obj: AnyRef) extends KSerializer { override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {} - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = None.asInstanceOf[T] - }) - kryo.register(Nil.getClass, new KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {} - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = Nil.asInstanceOf[T] - }) + override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T] + } + kryo.register(None.getClass, new SingletonSerializer(None)) + kryo.register(Nil.getClass, new SingletonSerializer(Nil)) + + // Register maps with a special serializer since they have complex internal structure + class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any]) + extends KSerializer { + override def writeObjectData(buf: ByteBuffer, obj: AnyRef) { + val map = obj.asInstanceOf[scala.collection.Map[Any, Any]] + kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer]) + for ((k, v) <- map) { + kryo.writeClassAndObject(buf, k) + kryo.writeClassAndObject(buf, v) + } + } + override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = { + val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue + val elems = new Array[(Any, Any)](size) + for (i <- 0 until size) + elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf)) + buildMap(elems).asInstanceOf[T] + } + } + kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _)) + // TODO: add support for immutable maps too; this is more annoying because there are many + // subclasses of immutable.Map for small maps (with <= 4 entries) + // Allow the user to register their own classes by setting spark.kryo.registrator val regCls = System.getProperty("spark.kryo.registrator") if (regCls != null) { logInfo("Running user registrator: " + regCls) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index aff73cd4ad..3064936758 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -102,7 +102,6 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]] - println("Got locations: " + fetched.mkString(", ")) serverUris.put(shuffleId, fetched) fetching.synchronized { fetching -= shuffleId diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala index 9776963c5f..9ca316d953 100644 --- a/core/src/main/scala/spark/MesosScheduler.scala +++ b/core/src/main/scala/spark/MesosScheduler.scala @@ -12,8 +12,11 @@ import scala.collection.mutable.Map import scala.collection.mutable.Queue import scala.collection.JavaConversions._ -import mesos.{Scheduler => MScheduler} -import mesos._ +import com.google.protobuf.ByteString + +import org.apache.mesos.{Scheduler => MScheduler} +import org.apache.mesos._ +import org.apache.mesos.Protos._ /** * The main Scheduler implementation, which runs jobs on Mesos. Clients should @@ -30,15 +33,25 @@ extends MScheduler with DAGScheduler with Logging "SPARK_LIBRARY_PATH" ) + // Memory used by each executor (in megabytes) + val EXECUTOR_MEMORY = { + if (System.getenv("SPARK_MEM") != null) + memoryStringToMb(System.getenv("SPARK_MEM")) + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + else + 512 + } + // Lock used to wait for scheduler to be registered private var isRegistered = false private val registeredLock = new Object() - private var activeJobs = new HashMap[Int, Job] - private var activeJobsQueue = new Queue[Job] + private val activeJobs = new HashMap[Int, Job] + private val activeJobsQueue = new Queue[Job] - private var taskIdToJobId = new HashMap[Int, Int] - private var jobTasks = new HashMap[Int, HashSet[Int]] + private val taskIdToJobId = new HashMap[String, Int] + private val taskIdToSlaveId = new HashMap[String, String] + private val jobTasks = new HashMap[Int, HashSet[String]] // Incrementing job and task IDs private var nextJobId = 0 @@ -47,6 +60,9 @@ extends MScheduler with DAGScheduler with Logging // Driver for talking to Mesos var driver: SchedulerDriver = null + // Which nodes we have executors on + private val slavesWithExecutors = new HashSet[String] + // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -59,10 +75,10 @@ extends MScheduler with DAGScheduler with Logging return id } - def newTaskId(): Int = { - val id = nextTaskId; + def newTaskId(): TaskID = { + val id = "" + nextTaskId; nextTaskId += 1; - return id + return TaskID.newBuilder().setValue(id).build() } override def start() { @@ -76,7 +92,13 @@ extends MScheduler with DAGScheduler with Logging override def run { val sched = MesosScheduler.this sched.driver = new MesosSchedulerDriver(sched, master) - sched.driver.run() + try { + val ret = sched.driver.run() + logInfo("driver.run() returned with code " + ret) + } catch { + case e: Exception => + logError("driver.run() failed", e) + } } }.start } @@ -92,13 +114,28 @@ extends MScheduler with DAGScheduler with Logging "or the SparkContext constructor") } val execScript = new File(sparkHome, "spark-executor").getCanonicalPath - val params = new JHashMap[String, String] + val params = Params.newBuilder() for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) { if (System.getenv(key) != null) { - params("env." + key) = System.getenv(key) + params.addParam(Param.newBuilder() + .setKey("env." + key) + .setValue(System.getenv(key)) + .build()) } } - new ExecutorInfo(execScript, createExecArg(), params) + val memory = Resource.newBuilder() + .setName("mem") + .setType(Resource.Type.SCALAR) + .setScalar(Resource.Scalar.newBuilder() + .setValue(EXECUTOR_MEMORY).build()) + .build() + ExecutorInfo.newBuilder() + .setExecutorId(ExecutorID.newBuilder().setValue("default").build()) + .setUri(execScript) + .setData(ByteString.copyFrom(createExecArg())) + .setParams(params.build()) + .addResources(memory) + .build() } @@ -121,11 +158,12 @@ extends MScheduler with DAGScheduler with Logging activeJobs -= job.getId activeJobsQueue.dequeueAll(x => (x == job)) taskIdToJobId --= jobTasks(job.getId) + taskIdToSlaveId --= jobTasks(job.getId) jobTasks.remove(job.getId) } } - override def registered(d: SchedulerDriver, frameworkId: String) { + override def registered(d: SchedulerDriver, frameworkId: FrameworkID) { logInfo("Registered as framework ID " + frameworkId) registeredLock.synchronized { isRegistered = true @@ -145,30 +183,32 @@ extends MScheduler with DAGScheduler with Logging * our active jobs for tasks in FIFO order. We fill each node with tasks in * a round-robin manner so that tasks are balanced across the cluster. */ - override def resourceOffer( - d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) { + override def resourceOffer(d: SchedulerDriver, oid: OfferID, offers: JList[SlaveOffer]) { synchronized { val tasks = new JArrayList[TaskDescription] - val availableCpus = offers.map(_.getParams.get("cpus").toInt) - val availableMem = offers.map(_.getParams.get("mem").toInt) + val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus")) + val enoughMem = offers.map(o => { + val mem = getResource(o.getResourcesList(), "mem") + val slaveId = o.getSlaveId.getValue + mem > EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId) + }) var launchedTask = false for (job <- activeJobsQueue) { do { launchedTask = false - for (i <- 0 until offers.size.toInt) { - try { - job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match { - case Some(task) => - tasks.add(task) - taskIdToJobId(task.getTaskId) = job.getId - jobTasks(job.getId) += task.getTaskId - availableCpus(i) -= task.getParams.get("cpus").toInt - availableMem(i) -= task.getParams.get("mem").toInt - launchedTask = true - case None => {} - } - } catch { - case e: Exception => logError("Exception in resourceOffer", e) + for (i <- 0 until offers.size if enoughMem(i)) { + job.slaveOffer(offers(i), availableCpus(i)) match { + case Some(task) => + tasks.add(task) + val tid = task.getTaskId.getValue + val sid = offers(i).getSlaveId.getValue + taskIdToJobId(tid) = job.getId + jobTasks(job.getId) += tid + taskIdToSlaveId(tid) = sid + slavesWithExecutors += sid + availableCpus(i) -= getResource(task.getResourcesList(), "cpus") + launchedTask = true + case None => {} } } } while (launchedTask) @@ -179,6 +219,13 @@ extends MScheduler with DAGScheduler with Logging } } + // Helper function to pull out a resource from a Mesos Resources protobuf + def getResource(res: JList[Resource], name: String): Double = { + for (r <- res if r.getName == name) + return r.getScalar.getValue + throw new IllegalArgumentException("No resource called " + name + " in " + res) + } + // Check whether a Mesos task state represents a finished task def isFinished(state: TaskState) = { state == TaskState.TASK_FINISHED || @@ -190,19 +237,24 @@ extends MScheduler with DAGScheduler with Logging override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { synchronized { try { - taskIdToJobId.get(status.getTaskId) match { + val tid = status.getTaskId.getValue + if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + // We lost the executor on this slave, so remember that it's gone + slavesWithExecutors -= taskIdToSlaveId(tid) + } + taskIdToJobId.get(tid) match { case Some(jobId) => if (activeJobs.contains(jobId)) { activeJobs(jobId).statusUpdate(status) } if (isFinished(status.getState)) { - taskIdToJobId.remove(status.getTaskId) + taskIdToJobId.remove(tid) if (jobTasks.contains(jobId)) - jobTasks(jobId) -= status.getTaskId + jobTasks(jobId) -= tid + taskIdToSlaveId.remove(tid) } case None => - logInfo("Ignoring update from TID " + status.getTaskId + - " because its job is gone") + logInfo("Ignoring update from TID " + tid + " because its job is gone") } } catch { case e: Exception => logError("Exception in statusUpdate", e) @@ -293,4 +345,31 @@ extends MScheduler with DAGScheduler with Logging // Serialize the map as an array of (String, String) pairs return Utils.serialize(props.toArray) } + + override def frameworkMessage(d: SchedulerDriver, s: SlaveID, e: ExecutorID, b: Array[Byte]) {} + + override def slaveLost(d: SchedulerDriver, s: SlaveID) { + slavesWithExecutors.remove(s.getValue) + } + + override def offerRescinded(d: SchedulerDriver, o: OfferID) {} + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a + * number of megabytes. This is used to figure out how much memory to claim + * from Mesos based on the SPARK_MEM environment variable. + */ + def memoryStringToMb(str: String): Int = { + val lower = str.toLowerCase + if (lower.endsWith("k")) + (lower.substring(0, lower.length-1).toLong / 1024).toInt + else if (lower.endsWith("m")) + lower.substring(0, lower.length-1).toInt + else if (lower.endsWith("g")) + lower.substring(0, lower.length-1).toInt * 1024 + else if (lower.endsWith("t")) + lower.substring(0, lower.length-1).toInt * 1024 * 1024 + else // no suffix, so it's just a number in bytes + (lower.toLong / 1024 / 1024).toInt + } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 36121766f5..a2e271b028 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -1,7 +1,5 @@ package spark -import mesos.SlaveOffer - import java.util.concurrent.atomic.AtomicLong @serializable class ParallelCollectionSplit[T: ClassManifest]( diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala index aa1610fb89..2001205878 100644 --- a/core/src/main/scala/spark/SimpleJob.scala +++ b/core/src/main/scala/spark/SimpleJob.scala @@ -5,7 +5,10 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import mesos._ +import com.google.protobuf.ByteString + +import org.apache.mesos._ +import org.apache.mesos.Protos._ /** @@ -18,9 +21,8 @@ extends Job(jobId) with Logging // Maximum time to wait to run a task in a preferred location (in ms) val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong - // CPUs and memory to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt - val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 @@ -31,7 +33,7 @@ extends Job(jobId) with Logging val launched = new Array[Boolean](numTasks) val finished = new Array[Boolean](numTasks) val numFailures = new Array[Int](numTasks) - val tidToIndex = HashMap[Int, Int]() + val tidToIndex = HashMap[String, Int]() var tasksLaunched = 0 var tasksFinished = 0 @@ -126,13 +128,11 @@ extends Job(jobId) with Logging } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int) - : Option[TaskDescription] = { - if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK && - availableMem >= MEM_PER_TASK) { + def slaveOffer(offer: SlaveOffer, availableCpus: Double): Option[TaskDescription] = { + if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) - val host = offer.getHost + val host = offer.getHostname findTask(host, localOnly) match { case Some(index) => { // Found a task; do some bookkeeping and return a Mesos task for it @@ -143,23 +143,31 @@ extends Job(jobId) with Logging val prefStr = if(preferred) "preferred" else "non-preferred" val message = "Starting task %d:%d as TID %s on slave %s: %s (%s)".format( - jobId, index, taskId, offer.getSlaveId, host, prefStr) + jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr) logInfo(message) // Do various bookkeeping - tidToIndex(taskId) = index + tidToIndex(taskId.getValue) = index launched(index) = true tasksLaunched += 1 if (preferred) lastPreferredLaunchTime = time // Create and return the Mesos task object - val params = new JHashMap[String, String] - params.put("cpus", CPUS_PER_TASK.toString) - params.put("mem", MEM_PER_TASK.toString) + val cpuRes = Resource.newBuilder() + .setName("cpus") + .setType(Resource.Type.SCALAR) + .setScalar(Resource.Scalar.newBuilder() + .setValue(CPUS_PER_TASK).build()) + .build() val serializedTask = Utils.serialize(task) logDebug("Serialized size: " + serializedTask.size) val taskName = "task %d:%d".format(jobId, index) - return Some(new TaskDescription( - taskId, offer.getSlaveId, taskName, params, serializedTask)) + return Some(TaskDescription.newBuilder() + .setTaskId(taskId) + .setSlaveId(offer.getSlaveId) + .setName(taskName) + .addResources(cpuRes) + .setData(ByteString.copyFrom(serializedTask)) + .build()) } case _ => } @@ -182,14 +190,14 @@ extends Job(jobId) with Logging } def taskFinished(status: TaskStatus) { - val tid = status.getTaskId + val tid = status.getTaskId.getValue val index = tidToIndex(tid) if (!finished(index)) { tasksFinished += 1 - logInfo("Finished TID %d (progress: %d/%d)".format( + logInfo("Finished TID %s (progress: %d/%d)".format( tid, tasksFinished, numTasks)) // Deserialize task result - val result = Utils.deserialize[TaskResult[_]](status.getData) + val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray) sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(index) = true @@ -202,16 +210,16 @@ extends Job(jobId) with Logging } def taskLost(status: TaskStatus) { - val tid = status.getTaskId + val tid = status.getTaskId.getValue val index = tidToIndex(tid) if (!finished(index)) { - logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index)) + logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index)) launched(index) = false tasksLaunched -= 1 // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. - if (status.getData != null && status.getData.length > 0) { - val reason = Utils.deserialize[TaskEndReason](status.getData) + if (status.getData != null && status.getData.size > 0) { + val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray) reason match { case fetchFailed: FetchFailed => logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri) diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala index 70547445ac..03274167e1 100644 --- a/core/src/main/scala/spark/Task.scala +++ b/core/src/main/scala/spark/Task.scala @@ -1,7 +1,5 @@ package spark -import mesos._ - @serializable class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) { } diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala new file mode 100644 index 0000000000..078071209a --- /dev/null +++ b/core/src/test/scala/spark/KryoSerializerSuite.scala @@ -0,0 +1,135 @@ +package spark + +import scala.collection.mutable +import scala.collection.immutable + +import org.scalatest.FunSuite +import com.esotericsoftware.kryo._ + +import SparkContext._ + +class KryoSerializerSuite extends FunSuite { + test("basic types") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T): Unit = + assert(ser.deserialize[T](ser.serialize(t)) === t) + check(1) + check(1L) + check(1.0f) + check(1.0) + check(1.toByte) + check(1.toShort) + check("") + check("hello") + check(Integer.MAX_VALUE) + check(Integer.MIN_VALUE) + check(java.lang.Long.MAX_VALUE) + check(java.lang.Long.MIN_VALUE) + check[String](null) + check(Array(1, 2, 3)) + check(Array(1L, 2L, 3L)) + check(Array(1.0, 2.0, 3.0)) + check(Array(1.0f, 2.9f, 3.9f)) + check(Array("aaa", "bbb", "ccc")) + check(Array("aaa", "bbb", null)) + check(Array(true, false, true)) + check(Array('a', 'b', 'c')) + check(Array[Int]()) + } + + test("pairs") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T): Unit = + assert(ser.deserialize[T](ser.serialize(t)) === t) + check((1, 1)) + check((1, 1L)) + check((1L, 1)) + check((1L, 1L)) + check((1.0, 1)) + check((1, 1.0)) + check((1.0, 1.0)) + check((1.0, 1L)) + check((1L, 1.0)) + check((1.0, 1L)) + check(("x", 1)) + check(("x", 1.0)) + check(("x", 1L)) + check((1, "x")) + check((1.0, "x")) + check((1L, "x")) + check(("x", "x")) + } + + test("Scala data structures") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T): Unit = + assert(ser.deserialize[T](ser.serialize(t)) === t) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + } + + test("custom registrator") { + import spark.test._ + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + + val ser = (new KryoSerializer).newInstance() + def check[T](t: T): Unit = + assert(ser.deserialize[T](ser.serialize(t)) === t) + + check(CaseClass(17, "hello")) + + val c1 = new ClassWithNoArgConstructor + c1.x = 32 + check(c1) + + val c2 = new ClassWithoutNoArgConstructor(47) + check(c2) + + val hashMap = new java.util.HashMap[String, String] + hashMap.put("foo", "bar") + check(hashMap) + + System.clearProperty("spark.kryo.registrator") + } +} + +package test { + case class CaseClass(i: Int, s: String) {} + + class ClassWithNoArgConstructor { + var x: Int = 0 + override def equals(other: Any) = other match { + case c: ClassWithNoArgConstructor => x == c.x + case _ => false + } + } + + class ClassWithoutNoArgConstructor(val x: Int) { + override def equals(other: Any) = other match { + case c: ClassWithoutNoArgConstructor => x == c.x + case _ => false + } + } + + class MyRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + k.register(classOf[CaseClass]) + k.register(classOf[ClassWithNoArgConstructor]) + k.register(classOf[ClassWithoutNoArgConstructor]) + k.register(classOf[java.util.HashMap[_, _]]) + } + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3f37da5139..858751618e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -39,7 +39,9 @@ object SparkBuild extends Build { "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "com.ning" % "compress-lzf" % "0.7.0", "org.apache.hadoop" % "hadoop-core" % "0.20.2", - "asm" % "asm-all" % "3.3.1" + "asm" % "asm-all" % "3.3.1", + "com.google.protobuf" % "protobuf-java" % "2.3.0", + "de.javakaffee" % "kryo-serializers" % "0.9" )) ++ DepJarPlugin.depJarSettings def replSettings = sharedSettings ++ @@ -22,8 +22,9 @@ if [ "x$MESOS_HOME" != "x" ] ; then fi if [ "x$SPARK_MEM" == "x" ] ; then - SPARK_MEM="300m" + SPARK_MEM="512m" fi +export SPARK_MEM # So that the process sees it and can report it to Mesos # Set JAVA_OPTS to be able to load native libraries and to set heap size JAVA_OPTS="$SPARK_JAVA_OPTS" |