From 85a35c68401e171df0b72b172a689d8c4e412199 Mon Sep 17 00:00:00 2001 From: Christoph Grothaus Date: Fri, 15 Feb 2013 14:11:34 +0100 Subject: Fix SPARK-698. From ExecutorRunner, launch java directly instead via the run scripts. --- .../scala/spark/deploy/worker/ExecutorRunner.scala | 43 ++++++++++++++++++++-- run | 3 ++ run2.cmd | 3 ++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index de11771c8e..214c44fc88 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -75,9 +75,45 @@ private[spark] class ExecutorRunner( def buildCommandSeq(): Seq[String] = { val command = appDesc.command - val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run" - val runScript = new File(sparkHome, script).getCanonicalPath - Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables) + val runner = if (getEnvOrEmpty("JAVA_HOME") == "") { + "java" + } else { + getEnvOrEmpty("JAVA_HOME") + "/bin/java" + } + // SPARK-698: do not call the run.cmd script, as process.destroy() + // fails to kill a process tree on Windows + Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ + command.arguments.map(substituteVariables) + } + + /* + * Attention: this must always be aligned with the environment variables in the run scripts and the + * way the JAVA_OPTS are assembled there. + */ + def buildJavaOpts(): Seq[String] = { + val _javaLibPath = if (getEnvOrEmpty("SPARK_LIBRARY_PATH") == "") { + "" + } else { + "-Djava.library.path=" + getEnvOrEmpty("SPARK_LIBRARY_PATH") + } + + Seq("-cp", + getEnvOrEmpty("CLASSPATH"), + // SPARK_JAVA_OPTS is overwritten with SPARK_DAEMON_JAVA_OPTS for running the worker + getEnvOrEmpty("SPARK_NONDAEMON_JAVA_OPTS"), + _javaLibPath, + "-Xms" + memory.toString + "M", + "-Xmx" + memory.toString + "M") + .filter(_ != "") + } + + def getEnvOrEmpty(key: String): String = { + val result = System.getenv(key) + if (result == null) { + "" + } else { + result + } } /** Spawn a thread that will redirect a given stream to a file */ @@ -113,7 +149,6 @@ private[spark] class ExecutorRunner( for ((key, value) <- appDesc.command.environment) { env.put(key, value) } - env.put("SPARK_MEM", memory.toString + "m") // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command env.put("SPARK_LAUNCH_WITH_SCALA", "0") diff --git a/run b/run index 82b1da005a..b5f693f1fa 100755 --- a/run +++ b/run @@ -22,6 +22,8 @@ fi # values for that; it doesn't need a lot if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} + # Backup current SPARK_JAVA_OPTS for use in ExecutorRunner.scala + SPARK_NONDAEMON_JAVA_OPTS=$SPARK_JAVA_OPTS SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default fi @@ -70,6 +72,7 @@ if [ -e $FWDIR/conf/java-opts ] ; then JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`" fi export JAVA_OPTS +# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" diff --git a/run2.cmd b/run2.cmd index c913a5195e..a93bbad0b9 100644 --- a/run2.cmd +++ b/run2.cmd @@ -22,6 +22,8 @@ if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1 if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1 if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY% +rem Backup current SPARK_JAVA_OPTS for use in ExecutorRunner.scala +if "%RUNNING_DAEMON%"=="1" set SPARK_NONDAEMON_JAVA_OPTS=%SPARK_JAVA_OPTS% if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% rem Check that SCALA_HOME has been specified @@ -42,6 +44,7 @@ rem Set JAVA_OPTS to be able to load native libraries and to set heap size set JAVA_OPTS=%SPARK_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM% rem Load extra JAVA_OPTS from conf/java-opts, if it exists if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd" +rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! set CORE_DIR=%FWDIR%core set REPL_DIR=%FWDIR%repl -- cgit v1.2.3 From f39f2b7636f52568a556987c8b7f7393299b0351 Mon Sep 17 00:00:00 2001 From: Christoph Grothaus Date: Sun, 24 Feb 2013 21:24:30 +0100 Subject: Incorporate feedback from mateiz: - we do not need getEnvOrEmpty - Instead of saving SPARK_NONDAEMON_JAVA_OPTS, it would be better to modify the scripts to use a different variable name for the JAVA_OPTS they do eventually use --- .../scala/spark/deploy/worker/ExecutorRunner.scala | 24 +++++++--------------- run | 9 ++++---- run2.cmd | 8 ++++---- 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 214c44fc88..38216ce62f 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -75,10 +75,10 @@ private[spark] class ExecutorRunner( def buildCommandSeq(): Seq[String] = { val command = appDesc.command - val runner = if (getEnvOrEmpty("JAVA_HOME") == "") { + val runner = if (System.getenv("JAVA_HOME") == null) { "java" } else { - getEnvOrEmpty("JAVA_HOME") + "/bin/java" + System.getenv("JAVA_HOME") + "/bin/java" } // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows @@ -91,29 +91,19 @@ private[spark] class ExecutorRunner( * way the JAVA_OPTS are assembled there. */ def buildJavaOpts(): Seq[String] = { - val _javaLibPath = if (getEnvOrEmpty("SPARK_LIBRARY_PATH") == "") { + val _javaLibPath = if (System.getenv("SPARK_LIBRARY_PATH") == null) { "" } else { - "-Djava.library.path=" + getEnvOrEmpty("SPARK_LIBRARY_PATH") + "-Djava.library.path=" + System.getenv("SPARK_LIBRARY_PATH") } Seq("-cp", - getEnvOrEmpty("CLASSPATH"), - // SPARK_JAVA_OPTS is overwritten with SPARK_DAEMON_JAVA_OPTS for running the worker - getEnvOrEmpty("SPARK_NONDAEMON_JAVA_OPTS"), + System.getenv("CLASSPATH"), + System.getenv("SPARK_JAVA_OPTS"), _javaLibPath, "-Xms" + memory.toString + "M", "-Xmx" + memory.toString + "M") - .filter(_ != "") - } - - def getEnvOrEmpty(key: String): String = { - val result = System.getenv(key) - if (result == null) { - "" - } else { - result - } + .filter(_ != null) } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/run b/run index b5f693f1fa..e1482dafbe 100755 --- a/run +++ b/run @@ -22,9 +22,10 @@ fi # values for that; it doesn't need a lot if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} - # Backup current SPARK_JAVA_OPTS for use in ExecutorRunner.scala - SPARK_NONDAEMON_JAVA_OPTS=$SPARK_JAVA_OPTS - SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default + # Do not overwrite SPARK_JAVA_OPTS environment variable in this script + OUR_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default +else + OUR_JAVA_OPTS=$SPARK_JAVA_OPTS fi if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then @@ -64,7 +65,7 @@ fi export SPARK_MEM # Set JAVA_OPTS to be able to load native libraries and to set heap size -JAVA_OPTS="$SPARK_JAVA_OPTS" +JAVA_OPTS="$OUR_JAVA_OPTS" JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH" JAVA_OPTS+=" -Xms$SPARK_MEM -Xmx$SPARK_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists diff --git a/run2.cmd b/run2.cmd index a93bbad0b9..8648c0380a 100644 --- a/run2.cmd +++ b/run2.cmd @@ -22,9 +22,9 @@ if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1 if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1 if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY% -rem Backup current SPARK_JAVA_OPTS for use in ExecutorRunner.scala -if "%RUNNING_DAEMON%"=="1" set SPARK_NONDAEMON_JAVA_OPTS=%SPARK_JAVA_OPTS% -if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% +rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script +if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% +if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% rem Check that SCALA_HOME has been specified if not "x%SCALA_HOME%"=="x" goto scala_exists @@ -41,7 +41,7 @@ rem variable so that our process sees it and can report it to Mesos if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m rem Set JAVA_OPTS to be able to load native libraries and to set heap size -set JAVA_OPTS=%SPARK_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM% +set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM% rem Load extra JAVA_OPTS from conf/java-opts, if it exists if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd" rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! -- cgit v1.2.3 From 4aa1205202f26663f59347f25a7d1f03c755545d Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 12:37:29 -0600 Subject: adding typesafe repo to streaming resolvers so that akka-zeromq is found --- project/SparkBuild.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b0b6e21681..44c8058e9d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -162,6 +162,9 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", + resolvers ++= Seq( + "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" + ), libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", "com.github.sgroschupf" % "zkclient" % "0.1", -- cgit v1.2.3 From c1c3682c984c83f75352fc22dcadd3e46058cfaf Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 12:40:44 -0600 Subject: adding checkpoint dir to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 155e785b01..6c9ffa5426 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ streaming-tests.log dependency-reduced-pom.xml .ensime .ensime_lucene +checkpoint -- cgit v1.2.3 From c07087364bac672ed7ded6dfeef00bab628c2f9b Mon Sep 17 00:00:00 2001 From: Harold Lim Date: Mon, 4 Mar 2013 16:37:27 -0500 Subject: Made changes to the SparkContext to have a DynamicVariable for setting local properties that can be passed down the stack. Added an implementation of the fair scheduler --- core/src/main/scala/spark/SparkContext.scala | 38 ++- .../main/scala/spark/scheduler/DAGScheduler.scala | 36 ++- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 5 +- core/src/main/scala/spark/scheduler/Stage.scala | 5 +- core/src/main/scala/spark/scheduler/TaskSet.scala | 4 +- .../cluster/fair/FairClusterScheduler.scala | 341 +++++++++++++++++++++ .../cluster/fair/FairTaskSetManager.scala | 130 ++++++++ 7 files changed, 530 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4957a54c1b..bd2261cf0d 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,11 +3,13 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.URI +import java.util.Properties import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ +import scala.util.DynamicVariable import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration @@ -72,6 +74,11 @@ class SparkContext( if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") } + + //Set the default task scheduler + if (System.getProperty("spark.cluster.taskscheduler") == null) { + System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.ClusterScheduler") + } private val isLocal = (master == "local" || master.startsWith("local[")) @@ -112,7 +119,7 @@ class SparkContext( } } executorEnvs ++= environment - + // Create and start the scheduler private var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format @@ -137,7 +144,7 @@ class SparkContext( new LocalScheduler(threads.toInt, maxFailures.toInt, this) case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) + val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) scheduler.initialize(backend) scheduler @@ -153,7 +160,7 @@ class SparkContext( memoryPerSlaveInt, sparkMemEnvInt)) } - val scheduler = new ClusterScheduler(this) + val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() @@ -169,7 +176,7 @@ class SparkContext( logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) } MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) + val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { @@ -206,6 +213,20 @@ class SparkContext( } private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new DynamicVariable[Properties](null) + + def initLocalProperties() { + localProperties.value = new Properties() + } + + def addLocalProperties(key: String, value: String) { + if(localProperties.value == null) { + localProperties.value = new Properties() + } + localProperties.value.setProperty(key,value) + } // Methods for creating RDDs @@ -578,7 +599,7 @@ class SparkContext( val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result @@ -649,7 +670,7 @@ class SparkContext( val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) } - + /** * Run a job that can return approximate results. */ @@ -657,12 +678,11 @@ class SparkContext( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], - timeout: Long - ): PartialResult[R] = { + timeout: Long): PartialResult[R] = { val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout) + val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") result } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index c54dce51d7..2ad73f3232 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -4,6 +4,7 @@ import cluster.TaskInfo import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit +import java.util.Properties import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -128,11 +129,11 @@ class DAGScheduler( * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int, properties: Properties): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority) + val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority, properties) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -143,7 +144,7 @@ class DAGScheduler( * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { + private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int, properties: Properties): Stage = { if (shuffleDep != None) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown @@ -151,7 +152,7 @@ class DAGScheduler( mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) } val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority) + val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority, properties), priority, properties) idToStage(id) = stage stageToInfos(stage) = StageInfo(stage) stage @@ -161,7 +162,7 @@ class DAGScheduler( * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided priority if they haven't already been created with a lower priority. */ - private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], priority: Int, properties: Properties): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(r: RDD[_]) { @@ -172,7 +173,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - parents += getShuffleMapStage(shufDep, priority) + parents += getShuffleMapStage(shufDep, priority, properties) case _ => visit(dep.rdd) } @@ -193,7 +194,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) + val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties) if (!mapStage.isAvailable) { missing += mapStage } @@ -221,13 +222,14 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit) + resultHandler: (Int, U) => Unit, + properties: Properties = null) : (JobSubmitted, JobWaiter[U]) = { assert(partitions.size > 0) val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter) + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) return (toSubmit, waiter) } @@ -237,13 +239,13 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit) + resultHandler: (Int, U) => Unit, properties: Properties = null) { if (partitions.size == 0) { return } val (toSubmit, waiter) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler) + finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) eventQueue.put(toSubmit) waiter.awaitResult() match { case JobSucceeded => {} @@ -258,13 +260,13 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], callSite: String, - timeout: Long) + timeout: Long, properties: Properties = null) : PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) + eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener, properties)) return listener.awaitResult() // Will throw an exception if the job fails } @@ -274,9 +276,9 @@ class DAGScheduler( */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => val runId = nextRunId.getAndIncrement() - val finalStage = newStage(finalRDD, None, runId) + val finalStage = newStage(finalRDD, None, runId, properties) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + @@ -458,7 +460,7 @@ class DAGScheduler( myPending ++= tasks logDebug("New pending tasks: " + myPending) taskSched.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, stage.properties)) if (!stage.submissionTime.isDefined) { stage.submissionTime = Some(System.currentTimeMillis()) } @@ -663,7 +665,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) + val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties) if (!mapStage.isAvailable) { visitedStages += mapStage visit(mapStage.rdd) diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index ed0b9bf178..79588891e7 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -1,5 +1,8 @@ package spark.scheduler + +import java.util.Properties + import spark.scheduler.cluster.TaskInfo import scala.collection.mutable.Map @@ -20,7 +23,7 @@ private[spark] case class JobSubmitted( partitions: Array[Int], allowLocal: Boolean, callSite: String, - listener: JobListener) + listener: JobListener, properties: Properties) extends DAGSchedulerEvent private[spark] case class CompletionEvent( diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 552061e46b..97afa27a60 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -1,10 +1,12 @@ package spark.scheduler import java.net.URI +import java.util.Properties import spark._ import spark.storage.BlockManagerId + /** * A stage is a set of independent tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run @@ -24,7 +26,8 @@ private[spark] class Stage( val rdd: RDD[_], val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage val parents: List[Stage], - val priority: Int) + val priority: Int, + val properties: Properties = null) extends Logging { val isShuffleMap = shuffleDep != None diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala index a3002ca477..2498e8a5aa 100644 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -1,10 +1,12 @@ package spark.scheduler +import java.util.Properties + /** * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. */ -private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { +private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int, val properties: Properties) { val id: String = stageId + "." + attempt override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala new file mode 100644 index 0000000000..37d98ccb2a --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala @@ -0,0 +1,341 @@ +package spark.scheduler.cluster.fair + +import java.io.{File, FileInputStream, FileOutputStream} +import java.util.{TimerTask, Timer} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.control.Breaks._ +import scala.xml._ + +import spark._ +import spark.TaskState.TaskState +import spark.scheduler._ +import spark.scheduler.cluster._ +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import scala.io.Source + +/** + * An implementation of a fair TaskScheduler, for running tasks on a cluster. Clients should first call + * start(), then submit task sets through the runTasks method. + * + * The current implementation makes the following assumptions: A pool has a fixed configuration of weight. + * Within a pool, it just uses FIFO. + * Also, currently we assume that pools are statically defined + * We currently don't support min shares + */ +private[spark] class FairClusterScheduler(override val sc: SparkContext) + extends ClusterScheduler(sc) + with Logging { + + + val schedulerAllocFile = System.getProperty("mapred.fairscheduler.allocation.file","unspecified") + + val poolNameToPool= new HashMap[String, Pool] + var pools = new ArrayBuffer[Pool] + + loadPoolProperties() + + def loadPoolProperties() { + //first check if the file exists + val file = new File(schedulerAllocFile) + if(!file.exists()) { + //if file does not exist, we just create 1 pool, default + val pool = new Pool("default",100) + pools += pool + poolNameToPool("default") = pool + logInfo("Created a default pool with weight = 100") + } + else { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ "pool")) { + if((poolNode \ "weight").text != ""){ + val pool = new Pool((poolNode \ "@name").text,(poolNode \ "weight").text.toInt) + pools += pool + poolNameToPool((poolNode \ "@name").text) = pool + logInfo("Created pool "+ pool.name +"with weight = "+pool.weight) + } else { + val pool = new Pool((poolNode \ "@name").text,100) + pools += pool + poolNameToPool((poolNode \ "@name").text) = pool + logInfo("Created pool "+ pool.name +"with weight = 100") + } + } + if(!poolNameToPool.contains("default")) { + val pool = new Pool("default", 100) + pools += pool + poolNameToPool("default") = pool + logInfo("Created a default pool with weight = 100") + } + + } + } + + def taskFinished(manager: TaskSetManager) { + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + + this.synchronized { + //have to check that poolName exists + if(poolNameToPool.contains(poolName)) + { + poolNameToPool(poolName).numRunningTasks -= 1 + } + else + { + poolNameToPool("default").numRunningTasks -= 1 + } + } + } + + override def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + + + var poolName = "default" + if(taskSet.properties != null) + poolName = taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + + this.synchronized { + if(poolNameToPool.contains(poolName)) + { + val manager = new FairTaskSetManager(this, taskSet) + poolNameToPool(poolName).activeTaskSetsQueue += manager + activeTaskSets(taskSet.id) = manager + //activeTaskSetsQueue += manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool "+poolName) + } + else //If the pool name does not exists, where do we put them? We put them in default + { + val manager = new FairTaskSetManager(this, taskSet) + poolNameToPool("default").activeTaskSetsQueue += manager + activeTaskSets(taskSet.id) = manager + //activeTaskSetsQueue += manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool default") + } + if (hasReceivedTask == false) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true; + + } + backend.reviveOffers() + } + + override def taskSetFinished(manager: TaskSetManager) { + + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + + + this.synchronized { + //have to check that poolName exists + if(poolNameToPool.contains(poolName)) + { + poolNameToPool(poolName).activeTaskSetsQueue -= manager + } + else + { + poolNameToPool("default").activeTaskSetsQueue -= manager + } + //activeTaskSetsQueue -= manager + activeTaskSets -= manager.taskSet.id + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) + } + //backend.reviveOffers() + } + + /** + * This is the comparison function used for sorting to determine which + * pool to allocate next based on fairness. + * The algorithm is as follows: we sort by the pool's running tasks to weight ratio + * (pools number running tast / pool's weight) + */ + def poolFairCompFn(pool1: Pool, pool2: Pool): Boolean = { + val tasksToWeightRatio1 = pool1.numRunningTasks.toDouble / pool1.weight.toDouble + val tasksToWeightRatio2 = pool2.numRunningTasks.toDouble / pool2.weight.toDouble + var res = Math.signum(tasksToWeightRatio1 - tasksToWeightRatio2) + if (res == 0) { + //Jobs are tied in fairness ratio. We break the tie by name + res = pool1.name.compareTo(pool2.name) + } + if (res < 0) + return true + else + return false + } + + /** + * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * sets for tasks in order of priority. We fill each node with tasks in a fair manner so + * that tasks are balanced across the cluster. + */ + override def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { + synchronized { + SparkEnv.set(sc.env) + // Mark each slave as alive and remember its hostname + for (o <- offers) { + executorIdToHost(o.executorId) = o.hostname + if (!executorsByHost.contains(o.hostname)) { + executorsByHost(o.hostname) = new HashSet() + } + } + // Build a list of tasks to assign to each slave + val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + var launchedTask = false + + for (i <- 0 until offers.size) { //we loop through the list of offers + val execId = offers(i).executorId + val host = offers(i).hostname + var breakOut = false + while(availableCpus(i) > 0 && !breakOut) { + breakable{ + launchedTask = false + for (pool <- pools.sortWith(poolFairCompFn)) { //we loop through the list of pools + if(!pool.activeTaskSetsQueue.isEmpty) { + //sort the tasksetmanager in the pool + pool.activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId)) + for(manager <- pool.activeTaskSetsQueue) { //we loop through the activeTaskSets in this pool +// val manager = pool.activeTaskSetsQueue.head + //Make an offer + manager.slaveOffer(execId, host, availableCpus(i)) match { + case Some(task) => + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + pool.numRunningTasks += 1 + launchedTask = true + logInfo("launched task for pool"+pool.name); + break + case None => {} + } + } + } + } + //If there is not one pool that can assign the task then we have to exit the outer loop and continue to the next offer + if(!launchedTask){ + breakOut = true + } + } + } + } + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks + } + } + + override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + var taskSetToUpdate: Option[TaskSetManager] = None + var failedExecutor: Option[String] = None + var taskFailed = false + synchronized { + try { + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) + } + } + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (activeTaskSets.contains(taskSetId)) { + taskSetToUpdate = Some(activeTaskSets(taskSetId)) + } + if (TaskState.isFinished(state)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + taskIdToExecutorId.remove(tid) + } + if (state == TaskState.FAILED) { + taskFailed = true + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock + if (taskSetToUpdate != None) { + taskSetToUpdate.get.statusUpdate(tid, state, serializedData) + } + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) + backend.reviveOffers() + } + if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost + backend.reviveOffers() + } + } + + // Check for speculatable tasks in all our active jobs. + override def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + for (pool <- pools) { + for (ts <- pool.activeTaskSetsQueue) { + shouldRevive |= ts.checkSpeculatableTasks() + } + } + } + if (shouldRevive) { + backend.reviveOffers() + } + } + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + for (pool <- pools) { + pool.activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } + } + +} + +/** + * An internal representation of a pool. It contains an ArrayBuffer of TaskSets and also weight and minshare + */ +class Pool(val name: String, val weight: Int) +{ + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + var numRunningTasks: Int = 0 +} diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala new file mode 100644 index 0000000000..4b0277d2d5 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala @@ -0,0 +1,130 @@ +package spark.scheduler.cluster.fair + +import scala.collection.mutable.ArrayBuffer + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import spark.TaskState.TaskState +import java.nio.ByteBuffer + +/** + * Schedules the tasks within a single TaskSet in the FairClusterScheduler. + */ +private[spark] class FairTaskSetManager(sched: FairClusterScheduler, override val taskSet: TaskSet) extends TaskSetManager(sched, taskSet) with Logging { + + // Add a task to all the pending-task lists that it should be on. + private def addPendingTask(index: Int) { + val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + if (locations.size == 0) { + pendingTasksWithNoPrefs += index + } else { + for (host <- locations) { + val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + list += index + } + } + allPendingTasks += index + } + + override def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + sched.taskFinished(this) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + override def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markFailed() + //Bookkeeping necessary for the pools in the scheduler + sched.taskFinished(this) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. + if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + return + + case ef: ExceptionFailure => + val key = ef.exception.toString + val now = System.currentTimeMillis + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } +} \ No newline at end of file -- cgit v1.2.3 From 54ed7c4af4591ebfec31bd168f830ef3ac01a41f Mon Sep 17 00:00:00 2001 From: Harold Lim Date: Mon, 4 Mar 2013 16:57:46 -0500 Subject: Changed the name of the system property to set the allocation xml --- .../main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala index 37d98ccb2a..591736faa2 100644 --- a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala @@ -31,7 +31,7 @@ private[spark] class FairClusterScheduler(override val sc: SparkContext) with Logging { - val schedulerAllocFile = System.getProperty("mapred.fairscheduler.allocation.file","unspecified") + val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") val poolNameToPool= new HashMap[String, Pool] var pools = new ArrayBuffer[Pool] -- cgit v1.2.3 From b5325182a3e92ff80185850c39cf70680fca46b7 Mon Sep 17 00:00:00 2001 From: Harold Lim Date: Fri, 8 Mar 2013 15:15:59 -0500 Subject: Updated/Refactored the Fair Task Scheduler. It does not inherit ClusterScheduler anymore. Rather, ClusterScheduler internally uses TaskSetQueuesManager that handles the scheduling of taskset queues. This is the class that should be extended to support other scheduling policies --- core/src/main/scala/spark/SparkContext.scala | 19 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 43 ++- .../cluster/FIFOTaskSetQueuesManager.scala | 63 ++++ .../cluster/FairTaskSetQueuesManager.scala | 183 +++++++++++ .../spark/scheduler/cluster/TaskSetManager.scala | 2 + .../scheduler/cluster/TaskSetQueuesManager.scala | 19 ++ .../cluster/fair/FairClusterScheduler.scala | 341 --------------------- .../cluster/fair/FairTaskSetManager.scala | 130 -------- 8 files changed, 312 insertions(+), 488 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bd2261cf0d..f6ee399898 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -39,7 +39,7 @@ import spark.partial.PartialResult import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} import spark.scheduler._ import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} +import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler, TaskSetQueuesManager} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.BlockManagerUI import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -77,7 +77,7 @@ class SparkContext( //Set the default task scheduler if (System.getProperty("spark.cluster.taskscheduler") == null) { - System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.ClusterScheduler") + System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.FIFOTaskSetQueuesManager") } private val isLocal = (master == "local" || master.startsWith("local[")) @@ -144,9 +144,10 @@ class SparkContext( new LocalScheduler(threads.toInt, maxFailures.toInt, this) case SPARK_REGEX(sparkUrl) => - val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] + val scheduler = new ClusterScheduler(this)//Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - scheduler.initialize(backend) + val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] + scheduler.initialize(backend, taskSetQueuesManager) scheduler case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => @@ -160,12 +161,13 @@ class SparkContext( memoryPerSlaveInt, sparkMemEnvInt)) } - val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] + val scheduler = new ClusterScheduler(this)//Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - scheduler.initialize(backend) + val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] + scheduler.initialize(backend, taskSetQueuesManager) backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() } @@ -176,7 +178,7 @@ class SparkContext( logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) } MesosNativeLibrary.load() - val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] + val scheduler = new ClusterScheduler(this)//Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { @@ -184,7 +186,8 @@ class SparkContext( } else { new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) } - scheduler.initialize(backend) + val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] + scheduler.initialize(backend, taskSetQueuesManager) scheduler } } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 26fdef101b..0b5bf7a86c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -27,7 +27,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong val activeTaskSets = new HashMap[String, TaskSetManager] - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + // var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] @@ -61,13 +61,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var backend: SchedulerBackend = null val mapOutputTracker = SparkEnv.get.mapOutputTracker + + var taskSetQueuesManager: TaskSetQueuesManager = null override def setListener(listener: TaskSchedulerListener) { this.listener = listener } - def initialize(context: SchedulerBackend) { + def initialize(context: SchedulerBackend, taskSetQueuesManager: TaskSetQueuesManager) { backend = context + this.taskSetQueuesManager = taskSetQueuesManager } def newTaskId(): Long = nextTaskId.getAndIncrement() @@ -99,7 +102,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) this.synchronized { val manager = new TaskSetManager(this, taskSet) activeTaskSets(taskSet.id) = manager - activeTaskSetsQueue += manager + taskSetQueuesManager.addTaskSetManager(manager) taskSetTaskIds(taskSet.id) = new HashSet[Long]() if (hasReceivedTask == false) { @@ -122,13 +125,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def taskSetFinished(manager: TaskSetManager) { this.synchronized { activeTaskSets -= manager.taskSet.id - activeTaskSetsQueue -= manager + taskSetQueuesManager.removeTaskSetManager(manager) taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id) } } + def taskFinished(manager: TaskSetManager) { + this.synchronized { + taskSetQueuesManager.taskFinished(manager) + } + } + /** * Called by cluster manager to offer resources on slaves. We respond by asking our active task * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so @@ -144,8 +153,26 @@ private[spark] class ClusterScheduler(val sc: SparkContext) executorsByHost(o.hostname) = new HashSet() } } + // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val taskSetIds = taskSetQueuesManager.receiveOffer(tasks, offers) + //We populate the necessary bookkeeping structures + for (i <- 0 until offers.size) { + val execId = offers(i).executorId + val host = offers(i).hostname + for(j <- 0 until tasks(i).size) { + val tid = tasks(i)(j).taskId + val taskSetid = taskSetIds(i)(j) + taskIdToTaskSetId(tid) = taskSetid + taskSetTaskIds(taskSetid) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + } + } + + /*val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray var launchedTask = false for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { @@ -170,7 +197,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } } while (launchedTask) - } + }*/ if (tasks.size > 0) { hasLaunchedTask = true } @@ -264,9 +291,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def checkSpeculatableTasks() { var shouldRevive = false synchronized { - for (ts <- activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } + shouldRevive = taskSetQueuesManager.checkSpeculatableTasks() } if (shouldRevive) { backend.reviveOffers() @@ -309,6 +334,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) executorsByHost -= host } executorIdToHost -= executorId - activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + taskSetQueuesManager.removeExecutor(executorId, host) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala new file mode 100644 index 0000000000..99a9c94222 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala @@ -0,0 +1,63 @@ +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer + +import spark.Logging + +/** + * A FIFO Implementation of the TaskSetQueuesManager + */ +private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with Logging { + + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + + override def addTaskSetManager(manager: TaskSetManager) { + activeTaskSetsQueue += manager + } + + override def removeTaskSetManager(manager: TaskSetManager) { + activeTaskSetsQueue -= manager + } + + override def taskFinished(manager: TaskSetManager) { + //do nothing + } + + override def removeExecutor(executorId: String, host: String) { + activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } + + override def receiveOffer(tasks: Seq[ArrayBuffer[TaskDescription]], offers: Seq[WorkerOffer]): Seq[Seq[String]] = { + val taskSetIds = offers.map(o => new ArrayBuffer[String](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + var launchedTask = false + for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + do { + launchedTask = false + for (i <- 0 until offers.size) { + val execId = offers(i).executorId + val host = offers(i).hostname + manager.slaveOffer(execId, host, availableCpus(i)) match { + case Some(task) => + tasks(i) += task + taskSetIds(i) += manager.taskSet.id + availableCpus(i) -= 1 + launchedTask = true + + case None => {} + } + } + } while (launchedTask) + } + return taskSetIds + } + + override def checkSpeculatableTasks(): Boolean = { + var shouldRevive = false + for (ts <- activeTaskSetsQueue) { + shouldRevive |= ts.checkSpeculatableTasks() + } + return shouldRevive + } + +} \ No newline at end of file diff --git a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala new file mode 100644 index 0000000000..ca308a5229 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala @@ -0,0 +1,183 @@ +package spark.scheduler.cluster + +import java.io.{File, FileInputStream, FileOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.control.Breaks._ +import scala.xml._ + +import spark.Logging + +/** + * A Fair Implementation of the TaskSetQueuesManager + * + * The current implementation makes the following assumptions: A pool has a fixed configuration of weight. + * Within a pool, it just uses FIFO. + * Also, currently we assume that pools are statically defined + * We currently don't support min shares + */ +private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with Logging { + + val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") + val poolNameToPool= new HashMap[String, Pool] + var pools = new ArrayBuffer[Pool] + + loadPoolProperties() + + def loadPoolProperties() { + //first check if the file exists + val file = new File(schedulerAllocFile) + if(!file.exists()) { + //if file does not exist, we just create 1 pool, default + val pool = new Pool("default",100) + pools += pool + poolNameToPool("default") = pool + logInfo("Created a default pool with weight = 100") + } + else { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ "pool")) { + if((poolNode \ "weight").text != ""){ + val pool = new Pool((poolNode \ "@name").text,(poolNode \ "weight").text.toInt) + pools += pool + poolNameToPool((poolNode \ "@name").text) = pool + logInfo("Created pool "+ pool.name +"with weight = "+pool.weight) + } else { + val pool = new Pool((poolNode \ "@name").text,100) + pools += pool + poolNameToPool((poolNode \ "@name").text) = pool + logInfo("Created pool "+ pool.name +"with weight = 100") + } + } + if(!poolNameToPool.contains("default")) { + val pool = new Pool("default", 100) + pools += pool + poolNameToPool("default") = pool + logInfo("Created a default pool with weight = 100") + } + + } + } + + override def addTaskSetManager(manager: TaskSetManager) { + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + if(poolNameToPool.contains(poolName)) + poolNameToPool(poolName).activeTaskSetsQueue += manager + else + poolNameToPool("default").activeTaskSetsQueue += manager + logInfo("Added task set " + manager.taskSet.id + " tasks to pool "+poolName) + + } + + override def removeTaskSetManager(manager: TaskSetManager) { + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + if(poolNameToPool.contains(poolName)) + poolNameToPool(poolName).activeTaskSetsQueue -= manager + else + poolNameToPool("default").activeTaskSetsQueue -= manager + } + + override def taskFinished(manager: TaskSetManager) { + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + if(poolNameToPool.contains(poolName)) + poolNameToPool(poolName).numRunningTasks -= 1 + else + poolNameToPool("default").numRunningTasks -= 1 + } + + override def removeExecutor(executorId: String, host: String) { + for (pool <- pools) { + pool.activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } + } + + /** + * This is the comparison function used for sorting to determine which + * pool to allocate next based on fairness. + * The algorithm is as follows: we sort by the pool's running tasks to weight ratio + * (pools number running tast / pool's weight) + */ + def poolFairCompFn(pool1: Pool, pool2: Pool): Boolean = { + val tasksToWeightRatio1 = pool1.numRunningTasks.toDouble / pool1.weight.toDouble + val tasksToWeightRatio2 = pool2.numRunningTasks.toDouble / pool2.weight.toDouble + var res = Math.signum(tasksToWeightRatio1 - tasksToWeightRatio2) + if (res == 0) { + //Jobs are tied in fairness ratio. We break the tie by name + res = pool1.name.compareTo(pool2.name) + } + if (res < 0) + return true + else + return false + } + + override def receiveOffer(tasks: Seq[ArrayBuffer[TaskDescription]], offers: Seq[WorkerOffer]): Seq[Seq[String]] = { + val taskSetIds = offers.map(o => new ArrayBuffer[String](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + var launchedTask = false + + for (i <- 0 until offers.size) { //we loop through the list of offers + val execId = offers(i).executorId + val host = offers(i).hostname + var breakOut = false + while(availableCpus(i) > 0 && !breakOut) { + breakable{ + launchedTask = false + for (pool <- pools.sortWith(poolFairCompFn)) { //we loop through the list of pools + if(!pool.activeTaskSetsQueue.isEmpty) { + //sort the tasksetmanager in the pool + pool.activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId)) + for(manager <- pool.activeTaskSetsQueue) { //we loop through the activeTaskSets in this pool + //Make an offer + manager.slaveOffer(execId, host, availableCpus(i)) match { + case Some(task) => + tasks(i) += task + taskSetIds(i) += manager.taskSet.id + availableCpus(i) -= 1 + pool.numRunningTasks += 1 + launchedTask = true + logInfo("launched task for pool"+pool.name); + break + case None => {} + } + } + } + } + //If there is not one pool that can assign the task then we have to exit the outer loop and continue to the next offer + if(!launchedTask){ + breakOut = true + } + } + } + } + return taskSetIds + } + + override def checkSpeculatableTasks(): Boolean = { + var shouldRevive = false + for (pool <- pools) { + for (ts <- pool.activeTaskSetsQueue) { + shouldRevive |= ts.checkSpeculatableTasks() + } + } + return shouldRevive + } +} + +/** + * An internal representation of a pool. It contains an ArrayBuffer of TaskSets and also weight + */ +class Pool(val name: String, val weight: Int) +{ + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + var numRunningTasks: Int = 0 +} \ No newline at end of file diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index c9f2c48804..015092b60b 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -253,6 +253,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } val index = info.index info.markSuccessful() + sched.taskFinished(this) if (!finished(index)) { tasksFinished += 1 logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( @@ -281,6 +282,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } val index = info.index info.markFailed() + sched.taskFinished(this) if (!finished(index)) { logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) copiesRunning(index) -= 1 diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala new file mode 100644 index 0000000000..b0c30e9e8b --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala @@ -0,0 +1,19 @@ +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer + +/** + * An interface for managing TaskSet queue/s that allows plugging different policy for + * offering tasks to resources + * + */ +private[spark] trait TaskSetQueuesManager { + def addTaskSetManager(manager: TaskSetManager): Unit + def removeTaskSetManager(manager: TaskSetManager): Unit + def taskFinished(manager: TaskSetManager): Unit + def removeExecutor(executorId: String, host: String): Unit + //The receiveOffers function, accepts tasks and offers. It populates the tasks to the actual task from TaskSet + //It returns a list of TaskSet ID that corresponds to each assigned tasks + def receiveOffer(tasks: Seq[ArrayBuffer[TaskDescription]], offers: Seq[WorkerOffer]): Seq[Seq[String]] + def checkSpeculatableTasks(): Boolean +} \ No newline at end of file diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala deleted file mode 100644 index 591736faa2..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala +++ /dev/null @@ -1,341 +0,0 @@ -package spark.scheduler.cluster.fair - -import java.io.{File, FileInputStream, FileOutputStream} -import java.util.{TimerTask, Timer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.util.control.Breaks._ -import scala.xml._ - -import spark._ -import spark.TaskState.TaskState -import spark.scheduler._ -import spark.scheduler.cluster._ -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import scala.io.Source - -/** - * An implementation of a fair TaskScheduler, for running tasks on a cluster. Clients should first call - * start(), then submit task sets through the runTasks method. - * - * The current implementation makes the following assumptions: A pool has a fixed configuration of weight. - * Within a pool, it just uses FIFO. - * Also, currently we assume that pools are statically defined - * We currently don't support min shares - */ -private[spark] class FairClusterScheduler(override val sc: SparkContext) - extends ClusterScheduler(sc) - with Logging { - - - val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") - - val poolNameToPool= new HashMap[String, Pool] - var pools = new ArrayBuffer[Pool] - - loadPoolProperties() - - def loadPoolProperties() { - //first check if the file exists - val file = new File(schedulerAllocFile) - if(!file.exists()) { - //if file does not exist, we just create 1 pool, default - val pool = new Pool("default",100) - pools += pool - poolNameToPool("default") = pool - logInfo("Created a default pool with weight = 100") - } - else { - val xml = XML.loadFile(file) - for (poolNode <- (xml \\ "pool")) { - if((poolNode \ "weight").text != ""){ - val pool = new Pool((poolNode \ "@name").text,(poolNode \ "weight").text.toInt) - pools += pool - poolNameToPool((poolNode \ "@name").text) = pool - logInfo("Created pool "+ pool.name +"with weight = "+pool.weight) - } else { - val pool = new Pool((poolNode \ "@name").text,100) - pools += pool - poolNameToPool((poolNode \ "@name").text) = pool - logInfo("Created pool "+ pool.name +"with weight = 100") - } - } - if(!poolNameToPool.contains("default")) { - val pool = new Pool("default", 100) - pools += pool - poolNameToPool("default") = pool - logInfo("Created a default pool with weight = 100") - } - - } - } - - def taskFinished(manager: TaskSetManager) { - var poolName = "default" - if(manager.taskSet.properties != null) - poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") - - this.synchronized { - //have to check that poolName exists - if(poolNameToPool.contains(poolName)) - { - poolNameToPool(poolName).numRunningTasks -= 1 - } - else - { - poolNameToPool("default").numRunningTasks -= 1 - } - } - } - - override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - - - var poolName = "default" - if(taskSet.properties != null) - poolName = taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") - - this.synchronized { - if(poolNameToPool.contains(poolName)) - { - val manager = new FairTaskSetManager(this, taskSet) - poolNameToPool(poolName).activeTaskSetsQueue += manager - activeTaskSets(taskSet.id) = manager - //activeTaskSetsQueue += manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool "+poolName) - } - else //If the pool name does not exists, where do we put them? We put them in default - { - val manager = new FairTaskSetManager(this, taskSet) - poolNameToPool("default").activeTaskSetsQueue += manager - activeTaskSets(taskSet.id) = manager - //activeTaskSetsQueue += manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool default") - } - if (hasReceivedTask == false) { - starvationTimer.scheduleAtFixedRate(new TimerTask() { - override def run() { - if (!hasLaunchedTask) { - logWarning("Initial job has not accepted any resources; " + - "check your cluster UI to ensure that workers are registered") - } else { - this.cancel() - } - } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) - } - hasReceivedTask = true; - - } - backend.reviveOffers() - } - - override def taskSetFinished(manager: TaskSetManager) { - - var poolName = "default" - if(manager.taskSet.properties != null) - poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") - - - this.synchronized { - //have to check that poolName exists - if(poolNameToPool.contains(poolName)) - { - poolNameToPool(poolName).activeTaskSetsQueue -= manager - } - else - { - poolNameToPool("default").activeTaskSetsQueue -= manager - } - //activeTaskSetsQueue -= manager - activeTaskSets -= manager.taskSet.id - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } - //backend.reviveOffers() - } - - /** - * This is the comparison function used for sorting to determine which - * pool to allocate next based on fairness. - * The algorithm is as follows: we sort by the pool's running tasks to weight ratio - * (pools number running tast / pool's weight) - */ - def poolFairCompFn(pool1: Pool, pool2: Pool): Boolean = { - val tasksToWeightRatio1 = pool1.numRunningTasks.toDouble / pool1.weight.toDouble - val tasksToWeightRatio2 = pool2.numRunningTasks.toDouble / pool2.weight.toDouble - var res = Math.signum(tasksToWeightRatio1 - tasksToWeightRatio2) - if (res == 0) { - //Jobs are tied in fairness ratio. We break the tie by name - res = pool1.name.compareTo(pool2.name) - } - if (res < 0) - return true - else - return false - } - - /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task - * sets for tasks in order of priority. We fill each node with tasks in a fair manner so - * that tasks are balanced across the cluster. - */ - override def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { - synchronized { - SparkEnv.set(sc.env) - // Mark each slave as alive and remember its hostname - for (o <- offers) { - executorIdToHost(o.executorId) = o.hostname - if (!executorsByHost.contains(o.hostname)) { - executorsByHost(o.hostname) = new HashSet() - } - } - // Build a list of tasks to assign to each slave - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - var launchedTask = false - - for (i <- 0 until offers.size) { //we loop through the list of offers - val execId = offers(i).executorId - val host = offers(i).hostname - var breakOut = false - while(availableCpus(i) > 0 && !breakOut) { - breakable{ - launchedTask = false - for (pool <- pools.sortWith(poolFairCompFn)) { //we loop through the list of pools - if(!pool.activeTaskSetsQueue.isEmpty) { - //sort the tasksetmanager in the pool - pool.activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId)) - for(manager <- pool.activeTaskSetsQueue) { //we loop through the activeTaskSets in this pool -// val manager = pool.activeTaskSetsQueue.head - //Make an offer - manager.slaveOffer(execId, host, availableCpus(i)) match { - case Some(task) => - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - pool.numRunningTasks += 1 - launchedTask = true - logInfo("launched task for pool"+pool.name); - break - case None => {} - } - } - } - } - //If there is not one pool that can assign the task then we have to exit the outer loop and continue to the next offer - if(!launchedTask){ - breakOut = true - } - } - } - } - if (tasks.size > 0) { - hasLaunchedTask = true - } - return tasks - } - } - - override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - var taskSetToUpdate: Option[TaskSetManager] = None - var failedExecutor: Option[String] = None - var taskFailed = false - synchronized { - try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) - failedExecutor = Some(execId) - } - } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (activeTaskSets.contains(taskSetId)) { - taskSetToUpdate = Some(activeTaskSets(taskSetId)) - } - if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - taskIdToExecutorId.remove(tid) - } - if (state == TaskState.FAILED) { - taskFailed = true - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock - if (taskSetToUpdate != None) { - taskSetToUpdate.get.statusUpdate(tid, state, serializedData) - } - if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) - backend.reviveOffers() - } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - backend.reviveOffers() - } - } - - // Check for speculatable tasks in all our active jobs. - override def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - for (pool <- pools) { - for (ts <- pool.activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } - } - } - if (shouldRevive) { - backend.reviveOffers() - } - } - - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) - execs -= executorId - if (execs.isEmpty) { - executorsByHost -= host - } - executorIdToHost -= executorId - for (pool <- pools) { - pool.activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) - } - } - -} - -/** - * An internal representation of a pool. It contains an ArrayBuffer of TaskSets and also weight and minshare - */ -class Pool(val name: String, val weight: Int) -{ - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] - var numRunningTasks: Int = 0 -} diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala deleted file mode 100644 index 4b0277d2d5..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala +++ /dev/null @@ -1,130 +0,0 @@ -package spark.scheduler.cluster.fair - -import scala.collection.mutable.ArrayBuffer - -import spark._ -import spark.scheduler._ -import spark.scheduler.cluster._ -import spark.TaskState.TaskState -import java.nio.ByteBuffer - -/** - * Schedules the tasks within a single TaskSet in the FairClusterScheduler. - */ -private[spark] class FairTaskSetManager(sched: FairClusterScheduler, override val taskSet: TaskSet) extends TaskSetManager(sched, taskSet) with Logging { - - // Add a task to all the pending-task lists that it should be on. - private def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (locations.size == 0) { - pendingTasksWithNoPrefs += index - } else { - for (host <- locations) { - val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) - list += index - } - } - allPendingTasks += index - } - - override def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markSuccessful() - sched.taskFinished(this) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( - tid, info.duration, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - override def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - //Bookkeeping necessary for the pools in the scheduler - sched.taskFinished(this) - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. - if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - return - - case ef: ExceptionFailure => - val key = ef.exception.toString - val now = System.currentTimeMillis - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } -} \ No newline at end of file -- cgit v1.2.3 From f5b1fecb9fc2e7d389810bfff5298f42b313f208 Mon Sep 17 00:00:00 2001 From: Harold Lim Date: Fri, 8 Mar 2013 15:38:32 -0500 Subject: Cleaned up the code --- .../spark/scheduler/cluster/ClusterScheduler.scala | 29 +--------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 0b5bf7a86c..5e960eb59d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -27,7 +27,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong val activeTaskSets = new HashMap[String, TaskSetManager] - // var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] @@ -171,33 +170,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) executorsByHost(host) += execId } } - - /*val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - var launchedTask = false - for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).hostname - manager.slaveOffer(execId, host, availableCpus(i)) match { - case Some(task) => - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - - case None => {} - } - } - } while (launchedTask) - }*/ + if (tasks.size > 0) { hasLaunchedTask = true } -- cgit v1.2.3 From 0b64e5f1ac0492aac6fca383c7877fbfce7d4cf1 Mon Sep 17 00:00:00 2001 From: Harold Lim Date: Fri, 8 Mar 2013 15:44:36 -0500 Subject: Removed some commented code --- core/src/main/scala/spark/SparkContext.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index f6ee399898..6eccb501c7 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -144,7 +144,7 @@ class SparkContext( new LocalScheduler(threads.toInt, maxFailures.toInt, this) case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this)//Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] + val scheduler = new ClusterScheduler(this) val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] scheduler.initialize(backend, taskSetQueuesManager) @@ -161,7 +161,7 @@ class SparkContext( memoryPerSlaveInt, sparkMemEnvInt)) } - val scheduler = new ClusterScheduler(this)//Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] + val scheduler = new ClusterScheduler(this) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() @@ -178,7 +178,7 @@ class SparkContext( logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) } MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this)//Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] + val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { -- cgit v1.2.3 From 42822cf95de71039988e22d8690ba6a4bd639227 Mon Sep 17 00:00:00 2001 From: seanm Date: Wed, 13 Mar 2013 11:40:42 -0600 Subject: changing streaming resolver for akka --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 44c8058e9d..7e65979a5d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -163,7 +163,7 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", resolvers ++= Seq( - "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" + "Akka Repository" at "http://repo.akka.io/releases/" ), libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", -- cgit v1.2.3 From cfa8e769a86664722f47182fa572179e8beadcb7 Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 11 Mar 2013 17:16:15 -0600 Subject: KafkaInputDStream improvements. Allows more Kafka configurability --- .../scala/spark/streaming/StreamingContext.scala | 22 +++++++++- .../streaming/dstream/KafkaInputDStream.scala | 48 ++++++++++++++-------- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 25c67b279b..4e1732adf5 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -199,7 +199,7 @@ class StreamingContext private ( } /** - * Create an input stream that pulls messages form a Kafka Broker. + * Create an input stream that pulls messages from a Kafka Broker. * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -216,7 +216,25 @@ class StreamingContext private ( initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel) + val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); + kafkaStream[T](kafkaParams, topics, initialOffsets, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * @param storageLevel Storage level to use for storing the received objects + */ + def kafkaStream[T: ClassManifest]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ): DStream[T] = { + val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, initialOffsets, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index dc7139cc27..f769fc1cc3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -12,6 +12,8 @@ import kafka.message.{Message, MessageSet, MessageAndMetadata} import kafka.serializer.StringDecoder import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ +import kafka.utils.ZKStringSerializer +import org.I0Itec.zkclient._ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ @@ -23,8 +25,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part /** * Input stream that pulls messages from a Kafka Broker. * - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param initialOffsets Optional initial offsets for each of the partitions to consume. @@ -34,8 +35,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part private[streaming] class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, - zkQuorum: String, - groupId: String, + kafkaParams: Map[String, String], topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel @@ -43,19 +43,16 @@ class KafkaInputDStream[T: ClassManifest]( def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel) + new KafkaReceiver(kafkaParams, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(zkQuorum: String, groupId: String, +class KafkaReceiver(kafkaParams: Map[String, String], topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel) extends NetworkReceiver[Any] { - // Timeout for establishing a connection to Zookeper in ms. - val ZK_TIMEOUT = 10000 - // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka @@ -72,20 +69,24 @@ class KafkaReceiver(zkQuorum: String, groupId: String, // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) logInfo("Initial offsets: " + initialOffsets.toString) - // Zookeper connection properties + // Kafka connection properties val props = new Properties() - props.put("zk.connect", zkQuorum) - props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) - props.put("groupid", groupId) + kafkaParams.foreach(param => props.put(param._1, param._2)) // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zkQuorum) + logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zkQuorum) + logInfo("Connected to " + kafkaParams("zk.connect")) + + // When autooffset.reset is 'smallest', it is our responsibility to try and whack the + // consumer group zk node. + if (kafkaParams.get("autooffset.reset").exists(_ == "smallest")) { + tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) + } // If specified, set the topic offset setOffsets(initialOffsets) @@ -97,7 +98,6 @@ class KafkaReceiver(zkQuorum: String, groupId: String, topicMessageStreams.values.foreach { streams => streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } } - } // Overwrites the offets in Zookeper. @@ -122,4 +122,18 @@ class KafkaReceiver(zkQuorum: String, groupId: String, } } } + + // Handles cleanup of consumer group znode. Lifted with love from Kafka's + // ConsumerConsole.scala tryCleanupZookeeper() + private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { + try { + val dir = "/consumers/" + groupId + logInfo("Cleaning up temporary zookeeper data under " + dir + ".") + val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) + zk.deleteRecursive(dir) + zk.close() + } catch { + case _ => // swallow + } + } } -- cgit v1.2.3 From d06928321194b11e082986cd2bb2737d9bc3b698 Mon Sep 17 00:00:00 2001 From: seanm Date: Thu, 14 Mar 2013 23:25:35 -0600 Subject: fixing memory leak in kafka MessageHandler --- .../src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index f769fc1cc3..d674b6ee87 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -114,11 +114,8 @@ class KafkaReceiver(kafkaParams: Map[String, String], private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { logInfo("Starting MessageHandler.") - stream.takeWhile { msgAndMetadata => + for (msgAndMetadata <- stream) { blockGenerator += msgAndMetadata.message - // Keep on handling messages - - true } } } -- cgit v1.2.3 From 33fa1e7e4aca4d9e0edf65d2b768b569305fd044 Mon Sep 17 00:00:00 2001 From: seanm Date: Thu, 14 Mar 2013 23:32:52 -0600 Subject: removing dependency on ZookeeperConsumerConnector + purging last relic of kafka reliability that never solidified (ie- setOffsets) --- .../scala/spark/streaming/StreamingContext.scala | 9 ++----- .../streaming/api/java/JavaStreamingContext.scala | 28 ---------------------- .../streaming/dstream/KafkaInputDStream.scala | 28 ++++------------------ 3 files changed, 6 insertions(+), 59 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4e1732adf5..bb7f216ca7 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -204,8 +204,6 @@ class StreamingContext private ( * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ @@ -213,11 +211,10 @@ class StreamingContext private ( zkQuorum: String, groupId: String, topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); - kafkaStream[T](kafkaParams, topics, initialOffsets, storageLevel) + kafkaStream[T](kafkaParams, topics, storageLevel) } /** @@ -225,16 +222,14 @@ class StreamingContext private ( * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. * @param storageLevel Storage level to use for storing the received objects */ def kafkaStream[T: ClassManifest]( kafkaParams: Map[String, String], topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, initialOffsets, storageLevel) + val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index f3b40b5b88..2373f4824a 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -84,39 +84,12 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. - */ - def kafkaStream[T]( - zkQuorum: String, - groupId: String, - topics: JMap[String, JInt], - initialOffsets: JMap[KafkaPartitionKey, JLong]) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T]( - zkQuorum, - groupId, - Map(topics.mapValues(_.intValue()).toSeq: _*), - Map(initialOffsets.mapValues(_.longValue()).toSeq: _*)) - } - - /** - * Create an input stream that pulls messages form a Kafka Broker. - * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). - * @param groupId The group id for this consumer. - * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( zkQuorum: String, groupId: String, topics: JMap[String, JInt], - initialOffsets: JMap[KafkaPartitionKey, JLong], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = @@ -125,7 +98,6 @@ class JavaStreamingContext(val ssc: StreamingContext) { zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), - Map(initialOffsets.mapValues(_.longValue()).toSeq: _*), storageLevel) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index d674b6ee87..c6da1a7f70 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -19,17 +19,12 @@ import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ -// Key for a specific Kafka Partition: (broker, topic, group, part) -case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) - /** * Input stream that pulls messages from a Kafka Broker. * * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. - * @param initialOffsets Optional initial offsets for each of the partitions to consume. - * By default the value is pulled from zookeper. * @param storageLevel RDD storage level. */ private[streaming] @@ -37,26 +32,25 @@ class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], - initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(kafkaParams, topics, initialOffsets, storageLevel) + new KafkaReceiver(kafkaParams, topics, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] class KafkaReceiver(kafkaParams: Map[String, String], - topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + topics: Map[String, Int], storageLevel: StorageLevel) extends NetworkReceiver[Any] { // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka - var consumerConnector : ZookeeperConsumerConnector = null + var consumerConnector : ConsumerConnector = null def onStop() { blockGenerator.stop() @@ -70,7 +64,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid")) - logInfo("Initial offsets: " + initialOffsets.toString) // Kafka connection properties val props = new Properties() @@ -79,7 +72,7 @@ class KafkaReceiver(kafkaParams: Map[String, String], // Create the connection to the cluster logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect")) val consumerConfig = new ConsumerConfig(props) - consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + kafkaParams("zk.connect")) // When autooffset.reset is 'smallest', it is our responsibility to try and whack the @@ -88,9 +81,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) } - // If specified, set the topic offset - setOffsets(initialOffsets) - // Create Threads for each Topic/Message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) @@ -100,16 +90,6 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Overwrites the offets in Zookeper. - private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) { - offsets.foreach { case(key, offset) => - val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) - val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, - topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) - } - } - // Handles Kafka Messages private class MessageHandler(stream: KafkaStream[String]) extends Runnable { def run() { -- cgit v1.2.3 From 5892393140eb024a32585b6d5b51146ddde8f63a Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Sat, 16 Mar 2013 11:13:38 +0800 Subject: refactor fair scheduler implementation 1.Chage "pool" properties to be the memeber of ActiveJob 2.Abstract the Schedulable of Pool and TaskSetManager 3.Abstract the FIFO and FS comparator algorithm 4.Miscellaneous changing of class define and construction --- .../src/main/scala/spark/scheduler/ActiveJob.scala | 5 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 32 +-- core/src/main/scala/spark/scheduler/Stage.scala | 4 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 48 ++-- .../cluster/FIFOTaskSetQueuesManager.scala | 37 ++- .../cluster/FairTaskSetQueuesManager.scala | 248 ++++++++++----------- .../main/scala/spark/scheduler/cluster/Pool.scala | 92 ++++++++ .../spark/scheduler/cluster/Schedulable.scala | 21 ++ .../scheduler/cluster/SchedulingAlgorithm.scala | 69 ++++++ .../spark/scheduler/cluster/SchedulingMode.scala | 8 + .../spark/scheduler/cluster/TaskDescription.scala | 1 + .../spark/scheduler/cluster/TaskSetManager.scala | 38 +++- .../scheduler/cluster/TaskSetQueuesManager.scala | 6 +- 13 files changed, 415 insertions(+), 194 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/cluster/Pool.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/Schedulable.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala index 5a4e9a582d..b6d3c2c089 100644 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala @@ -2,6 +2,8 @@ package spark.scheduler import spark.TaskContext +import java.util.Properties + /** * Tracks information about an active job in the DAGScheduler. */ @@ -11,7 +13,8 @@ private[spark] class ActiveJob( val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: String, - val listener: JobListener) { + val listener: JobListener, + val properties: Properties) { val numPartitions = partitions.length val finished = Array.fill[Boolean](numPartitions)(false) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 2ad73f3232..717cc27739 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -89,6 +89,8 @@ class DAGScheduler( // stray messages to detect. val failedGeneration = new HashMap[String, Long] + val idToActiveJob = new HashMap[Int, ActiveJob] + val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures @@ -129,11 +131,11 @@ class DAGScheduler( * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int, properties: Properties): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority, properties) + val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -144,7 +146,7 @@ class DAGScheduler( * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int, properties: Properties): Stage = { + private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { if (shuffleDep != None) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown @@ -152,7 +154,7 @@ class DAGScheduler( mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) } val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority, properties), priority, properties) + val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority) idToStage(id) = stage stageToInfos(stage) = StageInfo(stage) stage @@ -162,7 +164,7 @@ class DAGScheduler( * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided priority if they haven't already been created with a lower priority. */ - private def getParentStages(rdd: RDD[_], priority: Int, properties: Properties): List[Stage] = { + private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(r: RDD[_]) { @@ -173,7 +175,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - parents += getShuffleMapStage(shufDep, priority, properties) + parents += getShuffleMapStage(shufDep, priority) case _ => visit(dep.rdd) } @@ -194,7 +196,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties) + val mapStage = getShuffleMapStage(shufDep, stage.priority) if (!mapStage.isAvailable) { missing += mapStage } @@ -239,7 +241,8 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit, properties: Properties = null) + resultHandler: (Int, U) => Unit, + properties: Properties = null) { if (partitions.size == 0) { return @@ -260,7 +263,8 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], callSite: String, - timeout: Long, properties: Properties = null) + timeout: Long, + properties: Properties = null) : PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) @@ -278,8 +282,8 @@ class DAGScheduler( event match { case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => val runId = nextRunId.getAndIncrement() - val finalStage = newStage(finalRDD, None, runId, properties) - val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) + val finalStage = newStage(finalRDD, None, runId) + val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") @@ -290,6 +294,7 @@ class DAGScheduler( // Compute very short actions like first() or take() with no parent stages locally. runLocally(job) } else { + idToActiveJob(runId) = job activeJobs += job resultStageToJob(finalStage) = job submitStage(finalStage) @@ -459,8 +464,9 @@ class DAGScheduler( logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) + val properties = idToActiveJob(stage.priority).properties taskSched.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, stage.properties)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, properties)) if (!stage.submissionTime.isDefined) { stage.submissionTime = Some(System.currentTimeMillis()) } @@ -665,7 +671,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties) + val mapStage = getShuffleMapStage(shufDep, stage.priority) if (!mapStage.isAvailable) { visitedStages += mapStage visit(mapStage.rdd) diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 97afa27a60..bc54cd601d 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -1,7 +1,6 @@ package spark.scheduler import java.net.URI -import java.util.Properties import spark._ import spark.storage.BlockManagerId @@ -26,8 +25,7 @@ private[spark] class Stage( val rdd: RDD[_], val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage val parents: List[Stage], - val priority: Int, - val properties: Properties = null) + val priority: Int) extends Logging { val isShuffleMap = shuffleDep != None diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 5e960eb59d..092b0a0cfc 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -143,7 +143,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) * that tasks are balanced across the cluster. */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { - synchronized { + synchronized { + SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { @@ -152,25 +153,33 @@ private[spark] class ClusterScheduler(val sc: SparkContext) executorsByHost(o.hostname) = new HashSet() } } - // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val taskSetIds = taskSetQueuesManager.receiveOffer(tasks, offers) - //We populate the necessary bookkeeping structures - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).hostname - for(j <- 0 until tasks(i).size) { - val tid = tasks(i)(j).taskId - val taskSetid = taskSetIds(i)(j) - taskIdToTaskSetId(tid) = taskSetid - taskSetTaskIds(taskSetid) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - } + val availableCpus = offers.map(o => o.cores).toArray + for (i <- 0 until offers.size) + { + var launchedTask = true + val execId = offers(i).executorId + val host = offers(i).hostname + while (availableCpus(i) > 0 && launchedTask) + { + launchedTask = false + taskSetQueuesManager.receiveOffer(execId,host,availableCpus(i)) match { + case Some(task) => + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = task.taskSetId + taskSetTaskIds(task.taskSetId) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + + case None => {} + } + } } - if (tasks.size > 0) { hasLaunchedTask = true } @@ -219,10 +228,11 @@ private[spark] class ClusterScheduler(val sc: SparkContext) taskSetToUpdate.get.statusUpdate(tid, state, serializedData) } if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + listener.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost backend.reviveOffers() } @@ -289,7 +299,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Call listener.executorLost without holding the lock on this to prevent deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + listener.executorLost(failedExecutor.get) backend.reviveOffers() } } diff --git a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala index 99a9c94222..868b11c8d6 100644 --- a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala @@ -10,6 +10,7 @@ import spark.Logging private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with Logging { var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + val tasksetSchedulingAlgorithm = new FIFOSchedulingAlgorithm() override def addTaskSetManager(manager: TaskSetManager) { activeTaskSetsQueue += manager @@ -27,31 +28,19 @@ private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) } - override def receiveOffer(tasks: Seq[ArrayBuffer[TaskDescription]], offers: Seq[WorkerOffer]): Seq[Seq[String]] = { - val taskSetIds = offers.map(o => new ArrayBuffer[String](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - var launchedTask = false - for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).hostname - manager.slaveOffer(execId, host, availableCpus(i)) match { - case Some(task) => - tasks(i) += task - taskSetIds(i) += manager.taskSet.id - availableCpus(i) -= 1 - launchedTask = true - - case None => {} - } - } - } while (launchedTask) + override def receiveOffer(execId:String, host:String,avaiableCpus:Double):Option[TaskDescription] = + { + for(manager <- activeTaskSetsQueue.sortWith(tasksetSchedulingAlgorithm.comparator)) + { + val task = manager.slaveOffer(execId,host,avaiableCpus) + if (task != None) + { + return task + } } - return taskSetIds + return None } - + override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false for (ts <- activeTaskSetsQueue) { @@ -60,4 +49,4 @@ private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with return shouldRevive } -} \ No newline at end of file +} diff --git a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala index ca308a5229..4e26cedfda 100644 --- a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala @@ -10,174 +10,166 @@ import scala.util.control.Breaks._ import scala.xml._ import spark.Logging +import spark.scheduler.cluster.SchedulingMode.SchedulingMode /** * A Fair Implementation of the TaskSetQueuesManager * - * The current implementation makes the following assumptions: A pool has a fixed configuration of weight. - * Within a pool, it just uses FIFO. - * Also, currently we assume that pools are statically defined - * We currently don't support min shares + * Currently we support minShare,weight for fair scheduler between pools + * Within a pool, it supports FIFO or FS + * Also, currently we could allocate pools dynamically + * */ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with Logging { val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") val poolNameToPool= new HashMap[String, Pool] var pools = new ArrayBuffer[Pool] + val poolScheduleAlgorithm = new FairSchedulingAlgorithm() + val POOL_FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" + val POOL_DEFAULT_POOL_NAME = "default" + val POOL_MINIMUM_SHARES_PROPERTY = "minShares" + val POOL_SCHEDULING_MODE_PROPERTY = "schedulingMode" + val POOL_WEIGHT_PROPERTY = "weight" + val POOL_POOL_NAME_PROPERTY = "@name" + val POOL_POOLS_PROPERTY = "pool" + val POOL_DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO + val POOL_DEFAULT_MINIMUM_SHARES = 2 + val POOL_DEFAULT_WEIGHT = 1 loadPoolProperties() def loadPoolProperties() { //first check if the file exists val file = new File(schedulerAllocFile) - if(!file.exists()) { - //if file does not exist, we just create 1 pool, default - val pool = new Pool("default",100) - pools += pool - poolNameToPool("default") = pool - logInfo("Created a default pool with weight = 100") - } - else { + if(file.exists()) + { val xml = XML.loadFile(file) - for (poolNode <- (xml \\ "pool")) { - if((poolNode \ "weight").text != ""){ - val pool = new Pool((poolNode \ "@name").text,(poolNode \ "weight").text.toInt) - pools += pool - poolNameToPool((poolNode \ "@name").text) = pool - logInfo("Created pool "+ pool.name +"with weight = "+pool.weight) - } else { - val pool = new Pool((poolNode \ "@name").text,100) - pools += pool - poolNameToPool((poolNode \ "@name").text) = pool - logInfo("Created pool "+ pool.name +"with weight = 100") + for (poolNode <- (xml \\ POOL_POOLS_PROPERTY)) { + + val poolName = (poolNode \ POOL_POOL_NAME_PROPERTY).text + var schedulingMode = POOL_DEFAULT_SCHEDULING_MODE + var minShares = POOL_DEFAULT_MINIMUM_SHARES + var weight = POOL_DEFAULT_WEIGHT + + + val xmlSchedulingMode = (poolNode \ POOL_SCHEDULING_MODE_PROPERTY).text + if( xmlSchedulingMode != "") + { + try + { + schedulingMode = SchedulingMode.withName(xmlSchedulingMode) + } + catch{ + case e:Exception => logInfo("Error xml schedulingMode, using default schedulingMode") + } } - } - if(!poolNameToPool.contains("default")) { - val pool = new Pool("default", 100) + + val xmlMinShares = (poolNode \ POOL_MINIMUM_SHARES_PROPERTY).text + if(xmlMinShares != "") + { + minShares = xmlMinShares.toInt + } + + val xmlWeight = (poolNode \ POOL_WEIGHT_PROPERTY).text + if(xmlWeight != "") + { + weight = xmlWeight.toInt + } + + val pool = new Pool(poolName,schedulingMode,minShares,weight) pools += pool - poolNameToPool("default") = pool - logInfo("Created a default pool with weight = 100") + poolNameToPool(poolName) = pool + logInfo("Create new pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(poolName,schedulingMode,minShares,weight)) } - - } + } + + if(!poolNameToPool.contains(POOL_DEFAULT_POOL_NAME)) + { + val pool = new Pool(POOL_DEFAULT_POOL_NAME, POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) + pools += pool + poolNameToPool(POOL_DEFAULT_POOL_NAME) = pool + logInfo("Create default pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(POOL_DEFAULT_POOL_NAME,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) + } } override def addTaskSetManager(manager: TaskSetManager) { - var poolName = "default" - if(manager.taskSet.properties != null) - poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") - if(poolNameToPool.contains(poolName)) - poolNameToPool(poolName).activeTaskSetsQueue += manager - else - poolNameToPool("default").activeTaskSetsQueue += manager - logInfo("Added task set " + manager.taskSet.id + " tasks to pool "+poolName) - + var poolName = POOL_DEFAULT_POOL_NAME + if(manager.taskSet.properties != null) + { + poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) + if(!poolNameToPool.contains(poolName)) + { + //we will create a new pool that user has configured in app,but not contained in xml file + val pool = new Pool(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) + pools += pool + poolNameToPool(poolName) = pool + logInfo("Create pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) + } + } + poolNameToPool(poolName).addTaskSetManager(manager) + logInfo("Added task set " + manager.taskSet.id + " tasks to pool "+poolName) } override def removeTaskSetManager(manager: TaskSetManager) { - var poolName = "default" - if(manager.taskSet.properties != null) - poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") - if(poolNameToPool.contains(poolName)) - poolNameToPool(poolName).activeTaskSetsQueue -= manager - else - poolNameToPool("default").activeTaskSetsQueue -= manager + + var poolName = POOL_DEFAULT_POOL_NAME + if(manager.taskSet.properties != null) + { + poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) + } + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id,poolName)) + val pool = poolNameToPool(poolName) + pool.removeTaskSetManager(manager) + pool.setRunningTasks(pool.getRunningTasks() - manager.getRunningTasks()) + } override def taskFinished(manager: TaskSetManager) { - var poolName = "default" - if(manager.taskSet.properties != null) - poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") - if(poolNameToPool.contains(poolName)) - poolNameToPool(poolName).numRunningTasks -= 1 - else - poolNameToPool("default").numRunningTasks -= 1 + var poolName = POOL_DEFAULT_POOL_NAME + if(manager.taskSet.properties != null) + { + poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) + } + val pool = poolNameToPool(poolName) + pool.setRunningTasks(pool.getRunningTasks() - 1) + manager.setRunningTasks(manager.getRunningTasks() - 1) } override def removeExecutor(executorId: String, host: String) { for (pool <- pools) { - pool.activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + pool.removeExecutor(executorId,host) } } - /** - * This is the comparison function used for sorting to determine which - * pool to allocate next based on fairness. - * The algorithm is as follows: we sort by the pool's running tasks to weight ratio - * (pools number running tast / pool's weight) - */ - def poolFairCompFn(pool1: Pool, pool2: Pool): Boolean = { - val tasksToWeightRatio1 = pool1.numRunningTasks.toDouble / pool1.weight.toDouble - val tasksToWeightRatio2 = pool2.numRunningTasks.toDouble / pool2.weight.toDouble - var res = Math.signum(tasksToWeightRatio1 - tasksToWeightRatio2) - if (res == 0) { - //Jobs are tied in fairness ratio. We break the tie by name - res = pool1.name.compareTo(pool2.name) - } - if (res < 0) - return true - else - return false + override def receiveOffer(execId: String,host:String,avaiableCpus:Double):Option[TaskDescription] = + { + + val sortedPools = pools.sortWith(poolScheduleAlgorithm.comparator) + for(pool <- sortedPools) + { + logDebug("poolName:%s,tasksetNum:%d,minShares:%d,runningTasks:%d".format(pool.poolName,pool.activeTaskSetsQueue.length,pool.getMinShare(),pool.getRunningTasks())) + } + for (pool <- sortedPools) + { + val task = pool.receiveOffer(execId,host,avaiableCpus) + if(task != None) + { + pool.setRunningTasks(pool.getRunningTasks() + 1) + return task + } + } + return None } - override def receiveOffer(tasks: Seq[ArrayBuffer[TaskDescription]], offers: Seq[WorkerOffer]): Seq[Seq[String]] = { - val taskSetIds = offers.map(o => new ArrayBuffer[String](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - var launchedTask = false - - for (i <- 0 until offers.size) { //we loop through the list of offers - val execId = offers(i).executorId - val host = offers(i).hostname - var breakOut = false - while(availableCpus(i) > 0 && !breakOut) { - breakable{ - launchedTask = false - for (pool <- pools.sortWith(poolFairCompFn)) { //we loop through the list of pools - if(!pool.activeTaskSetsQueue.isEmpty) { - //sort the tasksetmanager in the pool - pool.activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId)) - for(manager <- pool.activeTaskSetsQueue) { //we loop through the activeTaskSets in this pool - //Make an offer - manager.slaveOffer(execId, host, availableCpus(i)) match { - case Some(task) => - tasks(i) += task - taskSetIds(i) += manager.taskSet.id - availableCpus(i) -= 1 - pool.numRunningTasks += 1 - launchedTask = true - logInfo("launched task for pool"+pool.name); - break - case None => {} - } - } - } - } - //If there is not one pool that can assign the task then we have to exit the outer loop and continue to the next offer - if(!launchedTask){ - breakOut = true - } - } - } - } - return taskSetIds - } - - override def checkSpeculatableTasks(): Boolean = { + override def checkSpeculatableTasks(): Boolean = + { var shouldRevive = false - for (pool <- pools) { - for (ts <- pool.activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } + for (pool <- pools) + { + shouldRevive |= pool.checkSpeculatableTasks() } return shouldRevive } -} -/** - * An internal representation of a pool. It contains an ArrayBuffer of TaskSets and also weight - */ -class Pool(val name: String, val weight: Int) -{ - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] - var numRunningTasks: Int = 0 -} \ No newline at end of file + } diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala new file mode 100644 index 0000000000..7b58a99582 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -0,0 +1,92 @@ +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer + +import spark.Logging +import spark.scheduler.cluster.SchedulingMode.SchedulingMode +/** + * An interface for + * + */ +private[spark] class Pool(val poolName: String, schedulingMode: SchedulingMode,val minShare:Int, val weight:Int) extends Schedulable with Logging { + + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + var numRunningTasks: Int = 0 + var taskSetSchedulingAlgorithm: SchedulingAlgorithm = + { + schedulingMode match + { + case SchedulingMode.FAIR => + val schedule = new FairSchedulingAlgorithm() + schedule + case SchedulingMode.FIFO => + val schedule = new FIFOSchedulingAlgorithm() + schedule + } + } + + override def getMinShare():Int = + { + return minShare + } + + override def getRunningTasks():Int = + { + return numRunningTasks + } + + def setRunningTasks(taskNum : Int) + { + numRunningTasks = taskNum + } + + override def getWeight(): Int = + { + return weight + } + + def addTaskSetManager(manager:TaskSetManager) + { + activeTaskSetsQueue += manager + } + + def removeTaskSetManager(manager:TaskSetManager) + { + activeTaskSetsQueue -= manager + } + + def removeExecutor(executorId: String, host: String) + { + activeTaskSetsQueue.foreach(_.executorLost(executorId,host)) + } + + def checkSpeculatableTasks(): Boolean = + { + var shouldRevive = false + for(ts <- activeTaskSetsQueue) + { + shouldRevive |= ts.checkSpeculatableTasks() + } + return shouldRevive + } + + def receiveOffer(execId:String,host:String,availableCpus:Double):Option[TaskDescription] = + { + val sortedActiveTasksSetQueue = activeTaskSetsQueue.sortWith(taskSetSchedulingAlgorithm.comparator) + for(manager <- sortedActiveTasksSetQueue) + { + + logDebug("taskSetId:%s,taskNum:%d,minShares:%d,weight:%d,runningTasks:%d".format(manager.taskSet.id,manager.numTasks,manager.getMinShare(),manager.getWeight(),manager.getRunningTasks())) + } + for(manager <- sortedActiveTasksSetQueue) + { + val task = manager.slaveOffer(execId,host,availableCpus) + if (task != None) + { + manager.setRunningTasks(manager.getRunningTasks() + 1) + return task + } + } + return None + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala new file mode 100644 index 0000000000..837f9c4983 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -0,0 +1,21 @@ +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer + +/** + * An interface for schedulable entities, there are two type Schedulable entities(Pools and TaskSetManagers) + */ +private[spark] trait Schedulable { + + def getMinShare(): Int + def getRunningTasks(): Int + def getPriority(): Int = + { + return 0 + } + def getWeight(): Int + def getStageId(): Int = + { + return 0 + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala new file mode 100644 index 0000000000..f8919e7374 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -0,0 +1,69 @@ +package spark.scheduler.cluster + +/** + * An interface for sort algorithm + * FIFO: FIFO algorithm for TaskSetManagers + * FS: FS algorithm for Pools, and FIFO or FS for TaskSetManagers + */ +private[spark] trait SchedulingAlgorithm { + def comparator(s1: Schedulable,s2: Schedulable): Boolean +} + +private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm +{ + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = + { + val priority1 = s1.getPriority() + val priority2 = s2.getPriority() + var res = Math.signum(priority1 - priority2) + if (res == 0) + { + val stageId1 = s1.getStageId() + val stageId2 = s2.getStageId() + res = Math.signum(stageId1 - stageId2) + } + if (res < 0) + { + return true + } + else + { + return false + } + } +} + +private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm +{ + def comparator(s1: Schedulable, s2:Schedulable): Boolean = + { + val minShare1 = s1.getMinShare() + val minShare2 = s2.getMinShare() + val s1Needy = s1.getRunningTasks() < minShare1 + val s2Needy = s2.getRunningTasks() < minShare2 + val minShareRatio1 = s1.getRunningTasks().toDouble / Math.max(minShare1,1.0).toDouble + val minShareRatio2 = s2.getRunningTasks().toDouble / Math.max(minShare2,1.0).toDouble + val taskToWeightRatio1 = s1.getRunningTasks().toDouble / s1.getWeight().toDouble + val taskToWeightRatio2 = s2.getRunningTasks().toDouble / s2.getWeight().toDouble + var res:Boolean = true + + if(s1Needy && !s2Needy) + { + res = true + } + else if(!s1Needy && s2Needy) + { + res = false + } + else if (s1Needy && s2Needy) + { + res = minShareRatio1 <= minShareRatio2 + } + else + { + res = taskToWeightRatio1 <= taskToWeightRatio2 + } + return res + } +} + diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala new file mode 100644 index 0000000000..6be4f3cd84 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala @@ -0,0 +1,8 @@ +package spark.scheduler.cluster + +object SchedulingMode extends Enumeration("FAIR","FIFO") +{ + type SchedulingMode = Value + + val FAIR,FIFO = Value +} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index b41e951be9..cdd004c94b 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -5,6 +5,7 @@ import spark.util.SerializableBuffer private[spark] class TaskDescription( val taskId: Long, + val taskSetId: String, val executorId: String, val name: String, _serializedTask: ByteBuffer) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 015092b60b..723c3b46bd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging { +private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Schedulable with Logging { // Maximum time to wait to run a task in a preferred location (in ms) val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong @@ -28,6 +28,9 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 + val TASKSET_MINIMIUM_SHARES = 1 + + val TASKSET_WEIGHT = 1 // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble @@ -43,6 +46,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val numFailures = new Array[Int](numTasks) val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksFinished = 0 + var numRunningTasks =0; // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis @@ -96,6 +100,36 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe addPendingTask(i) } + override def getMinShare(): Int = + { + return TASKSET_MINIMIUM_SHARES + } + + override def getRunningTasks(): Int = + { + return numRunningTasks + } + + def setRunningTasks(taskNum :Int) + { + numRunningTasks = taskNum + } + + override def getPriority(): Int = + { + return priority + } + + override def getWeight(): Int = + { + return TASKSET_WEIGHT + } + + override def getStageId(): Int = + { + return taskSet.stageId + } + // Add a task to all the pending-task lists that it should be on. private def addPendingTask(index: Int) { val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive @@ -222,7 +256,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) + return Some(new TaskDescription(taskId,taskSet.id,execId, taskName, serializedTask)) } case _ => } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala index b0c30e9e8b..c117ee7a85 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala @@ -12,8 +12,6 @@ private[spark] trait TaskSetQueuesManager { def removeTaskSetManager(manager: TaskSetManager): Unit def taskFinished(manager: TaskSetManager): Unit def removeExecutor(executorId: String, host: String): Unit - //The receiveOffers function, accepts tasks and offers. It populates the tasks to the actual task from TaskSet - //It returns a list of TaskSet ID that corresponds to each assigned tasks - def receiveOffer(tasks: Seq[ArrayBuffer[TaskDescription]], offers: Seq[WorkerOffer]): Seq[Seq[String]] + def receiveOffer(execId: String, host:String, avaiableCpus:Double):Option[TaskDescription] def checkSpeculatableTasks(): Boolean -} \ No newline at end of file +} -- cgit v1.2.3 From d61978d0abad30a148680c8a63df33e40e469525 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 15 Mar 2013 23:36:52 -0600 Subject: keeping JavaStreamingContext in sync with StreamingContext + adding comments for better clarity --- .../main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 7 +++---- .../src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 6 ++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 2373f4824a..7a8864614c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -80,6 +80,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed @@ -87,16 +88,14 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( - zkQuorum: String, - groupId: String, + kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - zkQuorum, - groupId, + kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index c6da1a7f70..85693808d1 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -100,8 +100,10 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Handles cleanup of consumer group znode. Lifted with love from Kafka's - // ConsumerConsole.scala tryCleanupZookeeper() + // Delete consumer group from zookeeper. This effectivly resets the group so we can consume from the beginning again. + // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' + // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest': + // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { try { val dir = "/consumers/" + groupId -- cgit v1.2.3 From d1d9bdaabe24cc60097f843e0bef92e57b404941 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Sat, 23 Mar 2013 07:25:30 +0800 Subject: Just update typo and comments --- core/src/main/scala/spark/scheduler/cluster/Pool.scala | 3 +-- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala index 7b58a99582..68e1d2a75a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer import spark.Logging import spark.scheduler.cluster.SchedulingMode.SchedulingMode /** - * An interface for - * + * An Schedulable entity that represent collection of TaskSetManager */ private[spark] class Pool(val poolName: String, schedulingMode: SchedulingMode,val minShare:Int, val weight:Int) extends Schedulable with Logging { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 723c3b46bd..064593f486 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -28,7 +28,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 - val TASKSET_MINIMIUM_SHARES = 1 + val TASKSET_MINIMUM_SHARES = 1 val TASKSET_WEIGHT = 1 // Quantile of tasks at which to start speculation @@ -102,7 +102,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe override def getMinShare(): Int = { - return TASKSET_MINIMIUM_SHARES + return TASKSET_MINIMUM_SHARES } override def getRunningTasks(): Int = -- cgit v1.2.3 From 329ef34c2e04d28c2ad150cf6674d6e86d7511ce Mon Sep 17 00:00:00 2001 From: seanm Date: Tue, 26 Mar 2013 23:56:15 -0600 Subject: fixing autooffset.reset behavior when set to 'largest' --- .../main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 85693808d1..17a5be3420 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -75,9 +75,9 @@ class KafkaReceiver(kafkaParams: Map[String, String], consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + kafkaParams("zk.connect")) - // When autooffset.reset is 'smallest', it is our responsibility to try and whack the + // When autooffset.reset is defined, it is our responsibility to try and whack the // consumer group zk node. - if (kafkaParams.get("autooffset.reset").exists(_ == "smallest")) { + if (kafkaParams.contains("autooffset.reset")) { tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid")) } @@ -100,9 +100,11 @@ class KafkaReceiver(kafkaParams: Map[String, String], } } - // Delete consumer group from zookeeper. This effectivly resets the group so we can consume from the beginning again. + // It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because + // Kafka 0.7.2 only honors this param when the group is not in zookeeper. + // // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas' - // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest': + // ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest'/'largest': // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { try { -- cgit v1.2.3 From def3d1c84a3e0d1371239e9358294a4b4ad46b9f Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Fri, 29 Mar 2013 08:20:35 +0800 Subject: 1.remove redundant spacing in source code 2.replace get/set functions with val and var defination --- .../src/main/scala/spark/scheduler/ActiveJob.scala | 2 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 4 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 13 ++--- .../cluster/FIFOTaskSetQueuesManager.scala | 13 ++--- .../cluster/FairTaskSetQueuesManager.scala | 65 +++++++++++----------- .../main/scala/spark/scheduler/cluster/Pool.scala | 43 +++++--------- .../spark/scheduler/cluster/Schedulable.scala | 21 ++----- .../scheduler/cluster/SchedulingAlgorithm.scala | 34 +++++------ .../spark/scheduler/cluster/SchedulingMode.scala | 2 +- .../spark/scheduler/cluster/TaskSetManager.scala | 37 ++---------- .../scheduler/cluster/TaskSetQueuesManager.scala | 1 - 11 files changed, 91 insertions(+), 144 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala index b6d3c2c089..105eaecb22 100644 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala @@ -13,7 +13,7 @@ private[spark] class ActiveJob( val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], val callSite: String, - val listener: JobListener, + val listener: JobListener, val properties: Properties) { val numPartitions = partitions.length diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 717cc27739..0a64a4f041 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -241,7 +241,7 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit, + resultHandler: (Int, U) => Unit, properties: Properties = null) { if (partitions.size == 0) { @@ -263,7 +263,7 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], callSite: String, - timeout: Long, + timeout: Long, properties: Properties = null) : PartialResult[R] = { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 092b0a0cfc..be0d480aa0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -60,7 +60,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var backend: SchedulerBackend = null val mapOutputTracker = SparkEnv.get.mapOutputTracker - + var taskSetQueuesManager: TaskSetQueuesManager = null override def setListener(listener: TaskSchedulerListener) { @@ -131,11 +131,11 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def taskFinished(manager: TaskSetManager) { + def taskFinished(manager: TaskSetManager) { this.synchronized { - taskSetQueuesManager.taskFinished(manager) + taskSetQueuesManager.taskFinished(manager) } - } + } /** * Called by cluster manager to offer resources on slaves. We respond by asking our active task @@ -144,7 +144,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { synchronized { - SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { @@ -228,7 +227,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) taskSetToUpdate.get.statusUpdate(tid, state, serializedData) } if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + listener.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -299,7 +298,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Call listener.executorLost without holding the lock on this to prevent deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + listener.executorLost(failedExecutor.get) backend.reviveOffers() } } diff --git a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala index 868b11c8d6..5949ee773f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala @@ -8,26 +8,26 @@ import spark.Logging * A FIFO Implementation of the TaskSetQueuesManager */ private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with Logging { - + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val tasksetSchedulingAlgorithm = new FIFOSchedulingAlgorithm() - + override def addTaskSetManager(manager: TaskSetManager) { activeTaskSetsQueue += manager } - + override def removeTaskSetManager(manager: TaskSetManager) { activeTaskSetsQueue -= manager } - + override def taskFinished(manager: TaskSetManager) { //do nothing } - + override def removeExecutor(executorId: String, host: String) { activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) } - + override def receiveOffer(execId:String, host:String,avaiableCpus:Double):Option[TaskDescription] = { for(manager <- activeTaskSetsQueue.sortWith(tasksetSchedulingAlgorithm.comparator)) @@ -48,5 +48,4 @@ private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with } return shouldRevive } - } diff --git a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala index 4e26cedfda..0609600f35 100644 --- a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala @@ -14,15 +14,14 @@ import spark.scheduler.cluster.SchedulingMode.SchedulingMode /** * A Fair Implementation of the TaskSetQueuesManager - * + * * Currently we support minShare,weight for fair scheduler between pools * Within a pool, it supports FIFO or FS * Also, currently we could allocate pools dynamically - * */ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with Logging { - - val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") + + val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") val poolNameToPool= new HashMap[String, Pool] var pools = new ArrayBuffer[Pool] val poolScheduleAlgorithm = new FairSchedulingAlgorithm() @@ -36,9 +35,9 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with val POOL_DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO val POOL_DEFAULT_MINIMUM_SHARES = 2 val POOL_DEFAULT_WEIGHT = 1 - + loadPoolProperties() - + def loadPoolProperties() { //first check if the file exists val file = new File(schedulerAllocFile) @@ -51,26 +50,25 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with var schedulingMode = POOL_DEFAULT_SCHEDULING_MODE var minShares = POOL_DEFAULT_MINIMUM_SHARES var weight = POOL_DEFAULT_WEIGHT - - + val xmlSchedulingMode = (poolNode \ POOL_SCHEDULING_MODE_PROPERTY).text if( xmlSchedulingMode != "") { - try + try { schedulingMode = SchedulingMode.withName(xmlSchedulingMode) } catch{ - case e:Exception => logInfo("Error xml schedulingMode, using default schedulingMode") + case e:Exception => logInfo("Error xml schedulingMode, using default schedulingMode") } } - + val xmlMinShares = (poolNode \ POOL_MINIMUM_SHARES_PROPERTY).text if(xmlMinShares != "") { minShares = xmlMinShares.toInt } - + val xmlWeight = (poolNode \ POOL_WEIGHT_PROPERTY).text if(xmlWeight != "") { @@ -84,15 +82,15 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with } } - if(!poolNameToPool.contains(POOL_DEFAULT_POOL_NAME)) + if(!poolNameToPool.contains(POOL_DEFAULT_POOL_NAME)) { val pool = new Pool(POOL_DEFAULT_POOL_NAME, POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) pools += pool poolNameToPool(POOL_DEFAULT_POOL_NAME) = pool logInfo("Create default pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(POOL_DEFAULT_POOL_NAME,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) - } + } } - + override def addTaskSetManager(manager: TaskSetManager) { var poolName = POOL_DEFAULT_POOL_NAME if(manager.taskSet.properties != null) @@ -100,19 +98,19 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) if(!poolNameToPool.contains(poolName)) { - //we will create a new pool that user has configured in app,but not contained in xml file + //we will create a new pool that user has configured in app instead of being defined in xml file val pool = new Pool(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) pools += pool poolNameToPool(poolName) = pool - logInfo("Create pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) + logInfo("Create pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) } } poolNameToPool(poolName).addTaskSetManager(manager) - logInfo("Added task set " + manager.taskSet.id + " tasks to pool "+poolName) + logInfo("Added task set " + manager.taskSet.id + " tasks to pool "+poolName) } - + override def removeTaskSetManager(manager: TaskSetManager) { - + var poolName = POOL_DEFAULT_POOL_NAME if(manager.taskSet.properties != null) { @@ -121,10 +119,9 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id,poolName)) val pool = poolNameToPool(poolName) pool.removeTaskSetManager(manager) - pool.setRunningTasks(pool.getRunningTasks() - manager.getRunningTasks()) - + pool.runningTasks -= manager.runningTasks } - + override def taskFinished(manager: TaskSetManager) { var poolName = POOL_DEFAULT_POOL_NAME if(manager.taskSet.properties != null) @@ -132,40 +129,40 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) } val pool = poolNameToPool(poolName) - pool.setRunningTasks(pool.getRunningTasks() - 1) - manager.setRunningTasks(manager.getRunningTasks() - 1) + pool.runningTasks -= 1 + manager.runningTasks -=1 } - + override def removeExecutor(executorId: String, host: String) { for (pool <- pools) { - pool.removeExecutor(executorId,host) - } + pool.removeExecutor(executorId,host) + } } - + override def receiveOffer(execId: String,host:String,avaiableCpus:Double):Option[TaskDescription] = { val sortedPools = pools.sortWith(poolScheduleAlgorithm.comparator) for(pool <- sortedPools) { - logDebug("poolName:%s,tasksetNum:%d,minShares:%d,runningTasks:%d".format(pool.poolName,pool.activeTaskSetsQueue.length,pool.getMinShare(),pool.getRunningTasks())) + logDebug("poolName:%s,tasksetNum:%d,minShares:%d,runningTasks:%d".format(pool.poolName,pool.activeTaskSetsQueue.length,pool.minShare,pool.runningTasks)) } for (pool <- sortedPools) { val task = pool.receiveOffer(execId,host,avaiableCpus) if(task != None) { - pool.setRunningTasks(pool.getRunningTasks() + 1) + pool.runningTasks += 1 return task } } return None } - - override def checkSpeculatableTasks(): Boolean = + + override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (pool <- pools) + for (pool <- pools) { shouldRevive |= pool.checkSpeculatableTasks() } diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala index 68e1d2a75a..8fdca5d2b4 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -7,13 +7,21 @@ import spark.scheduler.cluster.SchedulingMode.SchedulingMode /** * An Schedulable entity that represent collection of TaskSetManager */ -private[spark] class Pool(val poolName: String, schedulingMode: SchedulingMode,val minShare:Int, val weight:Int) extends Schedulable with Logging { - +private[spark] class Pool(val poolName: String,val schedulingMode: SchedulingMode, initMinShare:Int, initWeight:Int) extends Schedulable with Logging +{ + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] - var numRunningTasks: Int = 0 - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = + + var weight = initWeight + var minShare = initMinShare + var runningTasks = 0 + + val priority = 0 + val stageId = 0 + + var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { - schedulingMode match + schedulingMode match { case SchedulingMode.FAIR => val schedule = new FairSchedulingAlgorithm() @@ -23,26 +31,6 @@ private[spark] class Pool(val poolName: String, schedulingMode: SchedulingMode,v schedule } } - - override def getMinShare():Int = - { - return minShare - } - - override def getRunningTasks():Int = - { - return numRunningTasks - } - - def setRunningTasks(taskNum : Int) - { - numRunningTasks = taskNum - } - - override def getWeight(): Int = - { - return weight - } def addTaskSetManager(manager:TaskSetManager) { @@ -74,15 +62,14 @@ private[spark] class Pool(val poolName: String, schedulingMode: SchedulingMode,v val sortedActiveTasksSetQueue = activeTaskSetsQueue.sortWith(taskSetSchedulingAlgorithm.comparator) for(manager <- sortedActiveTasksSetQueue) { - - logDebug("taskSetId:%s,taskNum:%d,minShares:%d,weight:%d,runningTasks:%d".format(manager.taskSet.id,manager.numTasks,manager.getMinShare(),manager.getWeight(),manager.getRunningTasks())) + logDebug("poolname:%s,taskSetId:%s,taskNum:%d,minShares:%d,weight:%d,runningTasks:%d".format(poolName,manager.taskSet.id,manager.numTasks,manager.minShare,manager.weight,manager.runningTasks)) } for(manager <- sortedActiveTasksSetQueue) { val task = manager.slaveOffer(execId,host,availableCpus) if (task != None) { - manager.setRunningTasks(manager.getRunningTasks() + 1) + manager.runningTasks += 1 return task } } diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala index 837f9c4983..6f4f104f42 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -1,21 +1,12 @@ package spark.scheduler.cluster -import scala.collection.mutable.ArrayBuffer - /** - * An interface for schedulable entities, there are two type Schedulable entities(Pools and TaskSetManagers) + * An interface for schedulable entities, there are two type Schedulable entities(Pools and TaskSetManagers) */ private[spark] trait Schedulable { - - def getMinShare(): Int - def getRunningTasks(): Int - def getPriority(): Int = - { - return 0 - } - def getWeight(): Int - def getStageId(): Int = - { - return 0 - } + def weight:Int + def minShare:Int + def runningTasks:Int + def priority:Int + def stageId:Int } diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index f8919e7374..2f8123587f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -1,7 +1,7 @@ package spark.scheduler.cluster /** - * An interface for sort algorithm + * An interface for sort algorithm * FIFO: FIFO algorithm for TaskSetManagers * FS: FS algorithm for Pools, and FIFO or FS for TaskSetManagers */ @@ -13,13 +13,13 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { - val priority1 = s1.getPriority() - val priority2 = s2.getPriority() + val priority1 = s1.priority + val priority2 = s2.priority var res = Math.signum(priority1 - priority2) if (res == 0) { - val stageId1 = s1.getStageId() - val stageId2 = s2.getStageId() + val stageId1 = s1.stageId + val stageId2 = s2.stageId res = Math.signum(stageId1 - stageId2) } if (res < 0) @@ -29,7 +29,7 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm else { return false - } + } } } @@ -37,16 +37,18 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { def comparator(s1: Schedulable, s2:Schedulable): Boolean = { - val minShare1 = s1.getMinShare() - val minShare2 = s2.getMinShare() - val s1Needy = s1.getRunningTasks() < minShare1 - val s2Needy = s2.getRunningTasks() < minShare2 - val minShareRatio1 = s1.getRunningTasks().toDouble / Math.max(minShare1,1.0).toDouble - val minShareRatio2 = s2.getRunningTasks().toDouble / Math.max(minShare2,1.0).toDouble - val taskToWeightRatio1 = s1.getRunningTasks().toDouble / s1.getWeight().toDouble - val taskToWeightRatio2 = s2.getRunningTasks().toDouble / s2.getWeight().toDouble + val minShare1 = s1.minShare + val minShare2 = s2.minShare + val runningTasks1 = s1.runningTasks + val runningTasks2 = s2.runningTasks + val s1Needy = runningTasks1 < minShare1 + val s2Needy = runningTasks2 < minShare2 + val minShareRatio1 = runningTasks1.toDouble / Math.max(minShare1,1.0).toDouble + val minShareRatio2 = runningTasks2.toDouble / Math.max(minShare2,1.0).toDouble + val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble + val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true - + if(s1Needy && !s2Needy) { res = true @@ -57,7 +59,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm } else if (s1Needy && s2Needy) { - res = minShareRatio1 <= minShareRatio2 + res = minShareRatio1 <= minShareRatio2 } else { diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala index 6be4f3cd84..480af2c1a3 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala @@ -1,4 +1,4 @@ -package spark.scheduler.cluster +package spark.scheduler.cluster object SchedulingMode extends Enumeration("FAIR","FIFO") { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 064593f486..ddc4fa6642 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -29,7 +29,6 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val MAX_TASK_FAILURES = 4 val TASKSET_MINIMUM_SHARES = 1 - val TASKSET_WEIGHT = 1 // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble @@ -38,7 +37,12 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Serializer for closures and tasks. val ser = SparkEnv.get.closureSerializer.newInstance() + var weight = TASKSET_WEIGHT + var minShare = TASKSET_MINIMUM_SHARES + var runningTasks = 0 val priority = taskSet.priority + val stageId = taskSet.stageId + val tasks = taskSet.tasks val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -46,7 +50,6 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val numFailures = new Array[Int](numTasks) val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksFinished = 0 - var numRunningTasks =0; // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis @@ -100,36 +103,6 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe addPendingTask(i) } - override def getMinShare(): Int = - { - return TASKSET_MINIMUM_SHARES - } - - override def getRunningTasks(): Int = - { - return numRunningTasks - } - - def setRunningTasks(taskNum :Int) - { - numRunningTasks = taskNum - } - - override def getPriority(): Int = - { - return priority - } - - override def getWeight(): Int = - { - return TASKSET_WEIGHT - } - - override def getStageId(): Int = - { - return taskSet.stageId - } - // Add a task to all the pending-task lists that it should be on. private def addPendingTask(index: Int) { val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala index c117ee7a85..86971d47e6 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala @@ -5,7 +5,6 @@ import scala.collection.mutable.ArrayBuffer /** * An interface for managing TaskSet queue/s that allows plugging different policy for * offering tasks to resources - * */ private[spark] trait TaskSetQueuesManager { def addTaskSetManager(manager: TaskSetManager): Unit -- cgit v1.2.3 From 1a28f92711cb59ad99bc9e3dd84a5990181e572b Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Fri, 29 Mar 2013 08:34:28 +0800 Subject: change some typo and some spacing --- core/src/main/scala/spark/SparkContext.scala | 12 ++++++------ core/src/main/scala/spark/scheduler/DAGScheduler.scala | 4 ++-- core/src/main/scala/spark/scheduler/Stage.scala | 5 ++--- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 6eccb501c7..ed5f686379 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -74,7 +74,7 @@ class SparkContext( if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") } - + //Set the default task scheduler if (System.getProperty("spark.cluster.taskscheduler") == null) { System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.FIFOTaskSetQueuesManager") @@ -119,7 +119,7 @@ class SparkContext( } } executorEnvs ++= environment - + // Create and start the scheduler private var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format @@ -216,14 +216,14 @@ class SparkContext( } private[spark] var checkpointDir: Option[String] = None - + // Thread Local variable that can be used by users to pass information down the stack private val localProperties = new DynamicVariable[Properties](null) - + def initLocalProperties() { localProperties.value = new Properties() } - + def addLocalProperties(key: String, value: String) { if(localProperties.value == null) { localProperties.value = new Properties() @@ -673,7 +673,7 @@ class SparkContext( val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) } - + /** * Run a job that can return approximate results. */ diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 0a64a4f041..abc24c0270 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -333,7 +333,7 @@ class DAGScheduler( submitStage(stage) } } - + /** * Check for waiting or failed stages which are now eligible for resubmission. * Ordinarily run on every iteration of the event loop. @@ -720,7 +720,7 @@ class DAGScheduler( sizeBefore = shuffleToMapStage.size shuffleToMapStage.clearOldValues(cleanupTime) logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - + sizeBefore = pendingTasks.size pendingTasks.clearOldValues(cleanupTime) logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index bc54cd601d..7fc9e13fd9 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -5,7 +5,6 @@ import java.net.URI import spark._ import spark.storage.BlockManagerId - /** * A stage is a set of independent tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run @@ -27,7 +26,7 @@ private[spark] class Stage( val parents: List[Stage], val priority: Int) extends Logging { - + val isShuffleMap = shuffleDep != None val numPartitions = rdd.partitions.size val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) @@ -61,7 +60,7 @@ private[spark] class Stage( numAvailableOutputs -= 1 } } - + def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { -- cgit v1.2.3 From 2b373dd07a7b3f2906607d910c869e3290ca9d05 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Tue, 2 Apr 2013 12:11:14 +0800 Subject: add properties default value null to fix sbt/sbt test errors --- core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 79588891e7..6f4e5cd83e 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -23,7 +23,7 @@ private[spark] case class JobSubmitted( partitions: Array[Int], allowLocal: Boolean, callSite: String, - listener: JobListener, properties: Properties) + listener: JobListener, properties: Properties = null) extends DAGSchedulerEvent private[spark] case class CompletionEvent( -- cgit v1.2.3 From df47b40b764e25cbd10ce49d7152e1d33f51a263 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Wed, 20 Feb 2013 11:51:13 +0800 Subject: Shuffle Performance fix: Use netty embeded OIO file server instead of ConnectionManager Shuffle Performance Optimization: do not send 0-byte block requests to reduce network messages change reference from io.Source to scala.io.Source to avoid looking into io.netty package Signed-off-by: shane-huang --- .../main/java/spark/network/netty/FileClient.java | 89 +++++++ .../netty/FileClientChannelInitializer.java | 29 +++ .../spark/network/netty/FileClientHandler.java | 38 +++ .../main/java/spark/network/netty/FileServer.java | 59 +++++ .../netty/FileServerChannelInitializer.java | 33 +++ .../spark/network/netty/FileServerHandler.java | 68 ++++++ .../java/spark/network/netty/PathResolver.java | 12 + .../scala/spark/network/netty/FileHeader.scala | 57 +++++ .../scala/spark/network/netty/ShuffleCopier.scala | 88 +++++++ .../scala/spark/network/netty/ShuffleSender.scala | 50 ++++ .../main/scala/spark/storage/BlockManager.scala | 272 +++++++++++++++++---- core/src/main/scala/spark/storage/DiskStore.scala | 51 +++- project/SparkBuild.scala | 3 +- .../scala/spark/streaming/util/RawTextSender.scala | 2 +- 14 files changed, 795 insertions(+), 56 deletions(-) create mode 100644 core/src/main/java/spark/network/netty/FileClient.java create mode 100644 core/src/main/java/spark/network/netty/FileClientChannelInitializer.java create mode 100644 core/src/main/java/spark/network/netty/FileClientHandler.java create mode 100644 core/src/main/java/spark/network/netty/FileServer.java create mode 100644 core/src/main/java/spark/network/netty/FileServerChannelInitializer.java create mode 100644 core/src/main/java/spark/network/netty/FileServerHandler.java create mode 100755 core/src/main/java/spark/network/netty/PathResolver.java create mode 100644 core/src/main/scala/spark/network/netty/FileHeader.scala create mode 100644 core/src/main/scala/spark/network/netty/ShuffleCopier.scala create mode 100644 core/src/main/scala/spark/network/netty/ShuffleSender.scala diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..d0c5081dd2 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -0,0 +1,89 @@ +package spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + +import java.util.Arrays; + +public class FileClient { + + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + + public FileClient(FileClientHandler handler){ + this.handler = handler; + } + + public void init(){ + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .handler(new FileClientChannelInitializer(handler)); + } + + public static final class ChannelCloseListener implements ChannelFutureListener { + private FileClient fc = null; + public ChannelCloseListener(FileClient fc){ + this.fc = fc; + } + @Override + public void operationComplete(ChannelFuture future) { + if (fc.bootstrap!=null){ + fc.bootstrap.shutdown(); + fc.bootstrap = null; + } + } + } + + public void connect(String host, int port){ + try { + + // Start the connection attempt. + channel = bootstrap.connect(host, port).sync().channel(); + // ChannelFuture cf = channel.closeFuture(); + //cf.addListener(new ChannelCloseListener(this)); + } catch (InterruptedException e) { + close(); + } + } + + public void waitForClose(){ + try { + channel.closeFuture().sync(); + } catch (InterruptedException e){ + e.printStackTrace(); + } + } + + public void sendRequest(String file){ + //assert(file == null); + //assert(channel == null); + channel.write(file+"\r\n"); + } + + public void close(){ + if(channel != null) { + channel.close(); + channel = null; + } + if ( bootstrap!=null) { + bootstrap.shutdown(); + bootstrap = null; + } + } + + +} + + diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java new file mode 100644 index 0000000000..50e5704619 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java @@ -0,0 +1,29 @@ +package spark.network.netty; + +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.util.CharsetUtil; + +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.logging.LogLevel; + +public class FileClientChannelInitializer extends + ChannelInitializer { + + private FileClientHandler fhandler; + + public FileClientChannelInitializer(FileClientHandler handler) { + fhandler = handler; + } + + @Override + public void initChannel(SocketChannel channel) { + // file no more than 2G + channel.pipeline() + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); + } +} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java new file mode 100644 index 0000000000..911c8b32b5 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -0,0 +1,38 @@ +package spark.network.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; +import io.netty.util.CharsetUtil; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Logger; + +public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. + return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()){ + handle(ctx,in,currentHeader); + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..729e45f0a1 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -0,0 +1,59 @@ +package spark.network.netty; + +import java.io.File; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.Channel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; + +/** + * Server that accept the path of a file an echo back its content. + */ +public class FileServer { + + private ServerBootstrap bootstrap = null; + private Channel channel = null; + private PathResolver pResolver; + + public FileServer(PathResolver pResolver){ + this.pResolver = pResolver; + } + + public void run(int port) { + // Configure the server. + bootstrap = new ServerBootstrap(); + try { + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. + channel = bootstrap.bind(port).sync().channel(); + channel.closeFuture().sync(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } finally{ + bootstrap.shutdown(); + } + } + + public void stop(){ + if (channel!=null){ + channel.close(); + } + if (bootstrap != null){ + bootstrap.shutdown(); + bootstrap = null; + } + } +} + + diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..9d0618ff1c --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,33 @@ +package spark.network.netty; + +import java.io.File; +import io.netty.buffer.BufType; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.string.StringDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.util.CharsetUtil; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.logging.LogLevel; + +public class FileServerChannelInitializer extends + ChannelInitializer { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder( + 8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + + } +} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..e1083e87a2 --- /dev/null +++ b/core/src/main/java/spark/network/netty/FileServerHandler.java @@ -0,0 +1,68 @@ +package spark.network.netty; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; +import io.netty.handler.stream.ChunkedFile; +import java.io.File; +import java.io.FileInputStream; + +public class FileServerHandler extends + ChannelInboundMessageHandlerAdapter { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0 ) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + // TODO Auto-generated catch block + //logger.warning("Exception when sending file : " + //+ file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..5d5eda006e --- /dev/null +++ b/core/src/main/java/spark/network/netty/PathResolver.java @@ -0,0 +1,12 @@ +package spark.network.netty; + +public interface PathResolver { + /** + * Get the absolute path of the file + * + * @param fileId + * @return the absolute path of file + */ + public String getAbsolutePath(String fileId); + +} diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala new file mode 100644 index 0000000000..aed4254234 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/FileHeader.scala @@ -0,0 +1,57 @@ +package spark.network.netty + +import io.netty.buffer._ + +import spark.Logging + +private[spark] class FileHeader ( + val fileLen: Int, + val blockId: String) extends Logging { + + lazy val buffer = { + val buf = Unpooled.buffer() + buf.capacity(FileHeader.HEADER_SIZE) + buf.writeInt(fileLen) + buf.writeInt(blockId.length) + blockId.foreach((x: Char) => buf.writeByte(x)) + //padding the rest of header + if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { + buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) + } else { + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") + } + buf + } + +} + +private[spark] object FileHeader { + + val HEADER_SIZE = 40 + + def getFileLenOffset = 0 + def getFileLenSize = Integer.SIZE/8 + + def create(buf: ByteBuf): FileHeader = { + val length = buf.readInt + val idLength = buf.readInt + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buf.readByte().asInstanceOf[Char] + } + val blockId = idBuilder.toString() + new FileHeader(length, blockId) + } + + + def main (args:Array[String]){ + + val header = new FileHeader(25,"block_0"); + val buf = header.buffer; + val newheader = FileHeader.create(buf); + System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) + + } +} + diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala new file mode 100644 index 0000000000..d8d35bfeec --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -0,0 +1,88 @@ +package spark.network.netty + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInboundByteHandlerAdapter +import io.netty.util.CharsetUtil + +import java.util.concurrent.atomic.AtomicInteger +import java.util.logging.Logger +import spark.Logging +import spark.network.ConnectionManagerId +import java.util.concurrent.Executors + +private[spark] class ShuffleCopier extends Logging { + + def getBlock(cmId: ConnectionManagerId, + blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + + val handler = new ShuffleClientHandler(resultCollectCallback) + val fc = new FileClient(handler) + fc.init() + fc.connect(cmId.host, cmId.port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } + + def getBlocks(cmId: ConnectionManagerId, + blocks: Seq[(String, Long)], + resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + + blocks.map { + case(blockId,size) => { + getBlock(cmId,blockId,resultCollectCallback) + } + } + } +} + +private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit ) extends FileClientHandler with Logging { + + def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } +} + +private[spark] object ShuffleCopier extends Logging { + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = { + logInfo("File: " + blockId + " content is : \" " + + content.toString(CharsetUtil.UTF_8) + "\"") + } + + def runGetBlock(host:String, port:Int, file:String){ + val handler = new ShuffleClientHandler(echoResultCollectCallBack) + val fc = new FileClient(handler) + fc.init(); + fc.connect(host, port) + fc.sendRequest(file) + fc.waitForClose(); + fc.close() + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleCopier ") + System.exit(1) + } + val host = args(0) + val port = args(1).toInt + val file = args(2) + val threads = if (args.length>3) args(3).toInt else 10 + + val copiers = Executors.newFixedThreadPool(80) + for (i <- Range(0,threads)){ + val runnable = new Runnable() { + def run() { + runGetBlock(host,port,file) + } + } + copiers.execute(runnable) + } + copiers.shutdown + } + +} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala new file mode 100644 index 0000000000..c1986812e9 --- /dev/null +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -0,0 +1,50 @@ +package spark.network.netty + +import spark.Logging +import java.io.File + + +private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) extends Logging { + val server = new FileServer(pResolver) + + Runtime.getRuntime().addShutdownHook( + new Thread() { + override def run() { + server.stop() + } + } + ) + + def start() { + server.run(port) + } +} + +private[spark] object ShuffleSender { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: ShuffleSender ") + System.exit(1) + } + val port = args(0).toInt + val subDirsPerLocalDir = args(1).toInt + val localDirs = args.drop(2) map {new File(_)} + val pResovler = new PathResolver { + def getAbsolutePath(blockId:String):String = { + if (!blockId.startsWith("shuffle_")) { + throw new Exception("Block " + blockId + " is not a shuffle block") + } + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) + val file = new File(subDir, blockId) + return file.getAbsolutePath + } + } + val sender = new ShuffleSender(port, pResovler) + + sender.start() + } +} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e972..b8b68d4283 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -23,6 +23,8 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer +import spark.network.netty.ShuffleCopier +import io.netty.buffer.ByteBuf private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) @@ -467,6 +469,21 @@ class BlockManager( getLocal(blockId).orElse(getRemote(blockId)) } + /** + * A request to fetch one or more blocks, complete with their sizes + */ + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + /** + * A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + * the block (since we want all deserializaton to happen in the calling thread); can also + * represent a fetch failure if size == -1. + */ + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } /** * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined @@ -475,7 +492,12 @@ class BlockManager( */ def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) : BlockFetcherIterator = { - return new BlockFetcherIterator(this, blocksByAddress) + + if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){ + return new NettyBlockFetcherIterator(this, blocksByAddress) + } else { + return new BlockFetcherIterator(this, blocksByAddress) + } } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -908,7 +930,7 @@ class BlockFetcherIterator( if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } - val totalBlocks = blocksByAddress.map(_._2.size).sum + var totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis val localBlockIds = new ArrayBuffer[String]() @@ -974,68 +996,83 @@ class BlockFetcherIterator( } } - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - curBlocks += ((blockId, size)) - curRequestSize += size - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest + def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + curBlocks += ((blockId, size)) + curRequestSize += size + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] } } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } } + remoteRequests } - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) + def getLocalBlocks(){ + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlockIds) { + getLocal(id) match { + case Some(iter) => { + results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } } - val numGets = remoteBlockIds.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - startTime = System.currentTimeMillis - for (id <- localBlockIds) { - getLocal(id) match { - case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } + def initialize(){ + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) } + + val numGets = remoteBlockIds.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + initialize() //an iterator that will read fetched blocks off the queue as they arrive. var resultsGotten = 0 @@ -1066,3 +1103,132 @@ class BlockFetcherIterator( def remoteBytesRead = _remoteBytesRead } + +class NettyBlockFetcherIterator( + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] +) extends BlockFetcherIterator(blockManager,blocksByAddress) { + + import blockManager._ + + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, + results : LinkedBlockingQueue[FetchResult]){ + results.put(new FetchResult( + blockId, blockSize, () => dataDeserialize(blockId, blockData) )) + } + + def startCopiers (numCopiers: Int): List [ _ <: Thread]= { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } + } + } + copier.start + copier + }).toList + } + + //keep this to interrupt the threads when necessary + def stopCopiers(copiers : List[_ <: Thread]) { + for (copier <- copiers) { + copier.interrupt() + } + } + + override def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) + val cmId = new ConnectionManagerId(req.address.ip, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cpier = new ShuffleCopier + cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.ip ) + } + + override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val originalTotalBlocks = totalBlocks; + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + if (size > 0) { + curBlocks += ((blockId, size)) + curRequestSize += size + } else if (size == 0){ + //here we changes the totalBlocks + totalBlocks -= 1 + } else { + throw new SparkException("Negative block size "+blockId) + } + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") + remoteRequests + } + + var copiers : List[_ <: Thread] = null + + override def initialize(){ + // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // if all the results has been retrieved + // shutdown the copiers + if (resultsGotten == totalBlocks) { + if( copiers != null ) + stopCopiers(copiers) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + } + diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..d702bb23e0 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -13,24 +13,35 @@ import scala.collection.mutable.ArrayBuffer import spark.executor.ExecutorExitCode import spark.Utils +import spark.Logging +import spark.network.netty.ShuffleSender +import spark.network.netty.PathResolver /** * Stores BlockManager blocks on disk. */ private class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) { + extends BlockStore(blockManager) with Logging { val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + var shuffleSender : Thread = null + val thisInstance = this // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. val localDirs = createLocalDirs() val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + addShutdownHook() + if(useNetty){ + startShuffleBlockSender() + } + override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -180,10 +191,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) logDebug("Shutdown hook called") try { localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + if (useNetty && shuffleSender != null) + shuffleSender.stop } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } } }) } + + private def startShuffleBlockSender (){ + try { + val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt + + val pResolver = new PathResolver { + def getAbsolutePath(blockId:String):String = { + if (!blockId.startsWith("shuffle_")) { + return null + } + thisInstance.getFile(blockId).getAbsolutePath() + } + } + shuffleSender = new Thread { + override def run() = { + val sender = new ShuffleSender(port,pResolver) + logInfo("created ShuffleSender binding to port : "+ port) + sender.start + } + } + shuffleSender.setDaemon(true) + shuffleSender.start + + } catch { + case interrupted: InterruptedException => + logInfo("Runner thread for ShuffleBlockSender interrupted") + + case e: Exception => { + logError("Error running ShuffleBlockSender ", e) + if (shuffleSender != null) { + shuffleSender.stop + shuffleSender = null + } + } + } + } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f378b2398..e3645653ee 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -141,7 +141,8 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", - "org.apache.mesos" % "mesos" % "0.9.0-incubating" + "org.apache.mesos" % "mesos" % "0.9.0-incubating", + "io.netty" % "netty-all" % "4.0.0.Beta2" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala index d8b987ec86..bd0b0e74c1 100644 --- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala @@ -5,7 +5,7 @@ import spark.util.{RateLimitedOutputStream, IntParam} import java.net.ServerSocket import spark.{Logging, KryoSerializer} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import io.Source +import scala.io.Source import java.io.IOException /** -- cgit v1.2.3 From 6798a09df84fb97e196c84d55cf3e21ad676871f Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 7 Apr 2013 17:47:38 +0530 Subject: Add support for building against hadoop2-yarn : adding new maven profile for it --- bagel/pom.xml | 37 +++++++++++ core/pom.xml | 62 +++++++++++++++++++ .../apache/hadoop/mapred/HadoopMapRedUtil.scala | 3 + .../hadoop/mapreduce/HadoopMapReduceUtil.scala | 3 + .../apache/hadoop/mapred/HadoopMapRedUtil.scala | 13 ++++ .../hadoop/mapreduce/HadoopMapReduceUtil.scala | 13 ++++ .../apache/hadoop/mapred/HadoopMapRedUtil.scala | 3 + .../hadoop/mapreduce/HadoopMapReduceUtil.scala | 3 + core/src/main/scala/spark/PairRDDFunctions.scala | 5 +- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2 +- examples/pom.xml | 43 +++++++++++++ pom.xml | 54 ++++++++++++++++ project/SparkBuild.scala | 34 +++++++++-- repl-bin/pom.xml | 50 +++++++++++++++ repl/pom.xml | 71 ++++++++++++++++++++++ streaming/pom.xml | 37 +++++++++++ 16 files changed, 424 insertions(+), 9 deletions(-) create mode 100644 core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala create mode 100644 core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala diff --git a/bagel/pom.xml b/bagel/pom.xml index 510cff4669..89282161ea 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -102,5 +102,42 @@ + + hadoop2-yarn + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/core/pom.xml b/core/pom.xml index fe9c803728..9baa447662 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -279,5 +279,67 @@ + + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-source + generate-sources + + add-source + + + + src/main/scala + src/hadoop2-yarn/scala + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index ca9f7219de..f286f2cf9c 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index de7b0f81e3..264d421d14 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -6,4 +6,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala new file mode 100644 index 0000000000..875c0a220b --- /dev/null +++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -0,0 +1,13 @@ + +package org.apache.hadoop.mapred + +import org.apache.hadoop.mapreduce.TaskType + +trait HadoopMapRedUtil { + def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = + new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId) +} diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala new file mode 100644 index 0000000000..8bc6fb6dea --- /dev/null +++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -0,0 +1,13 @@ +package org.apache.hadoop.mapreduce + +import org.apache.hadoop.conf.Configuration +import task.{TaskAttemptContextImpl, JobContextImpl} + +trait HadoopMapReduceUtil { + def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) + + def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = + new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId) +} diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index 35300cea58..a0652d7fc7 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index 7afdbff320..7fdbe322fd 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -7,4 +7,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07efba9e8d..39469fa3c8 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -545,8 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = new TaskAttemptID(jobtrackerID, - stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -565,7 +564,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * however we're only going to use this local OutputCommitter for * setupJob/commitJob, so we just use a dummy "map" task. */ - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0) + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bdd974590a..901d01ef30 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -57,7 +57,7 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] val conf = confBroadcast.value.value - val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance if (format.isInstanceOf[Configurable]) { diff --git a/examples/pom.xml b/examples/pom.xml index 39cc47c709..9594257ad4 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -118,5 +118,48 @@ + + hadoop2-yarn + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.spark-project + spark-streaming + ${project.version} + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/pom.xml b/pom.xml index 08d1fc12e0..b3134a957d 100644 --- a/pom.xml +++ b/pom.xml @@ -558,5 +558,59 @@ + + + hadoop2-yarn + + 2 + 2.0.3-alpha + + + + + maven-root + Maven root repository + http://repo1.maven.org/maven2/ + + true + + + false + + + + + + + + + org.apache.hadoop + hadoop-client + ${yarn.version} + + + org.apache.hadoop + hadoop-yarn-api + ${yarn.version} + + + org.apache.hadoop + hadoop-yarn-common + ${yarn.version} + + + + org.apache.avro + avro + 1.7.1.cloudera.2 + + + org.apache.avro + avro-ipc + 1.7.1.cloudera.2 + + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5f378b2398..f041930b4e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1,3 +1,4 @@ + import sbt._ import sbt.Classpaths.publishTask import Keys._ @@ -10,12 +11,18 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - val HADOOP_VERSION = "1.0.4" - val HADOOP_MAJOR_VERSION = "1" + //val HADOOP_VERSION = "1.0.4" + //val HADOOP_MAJOR_VERSION = "1" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" //val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1" //val HADOOP_MAJOR_VERSION = "2" + //val HADOOP_YARN = false + + // For Hadoop 2 YARN support + val HADOOP_VERSION = "2.0.3-alpha" + val HADOOP_MAJOR_VERSION = "2" + val HADOOP_YARN = true lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming) @@ -129,7 +136,6 @@ object SparkBuild extends Build { "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "com.ning" % "compress-lzf" % "0.8.4", - "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION, "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", @@ -142,8 +148,26 @@ object SparkBuild extends Build { "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", "org.apache.mesos" % "mesos" % "0.9.0-incubating" - ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, - unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } + ) ++ ( + if (HADOOP_MAJOR_VERSION == "2") { + if (HADOOP_YARN) { + Seq( + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION + ) + } else { + Seq( + "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION + ) + } + } else { + Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION) + }), + unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / + ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") "src/hadoop2-yarn/scala" else "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala" ) + } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings def rootSettings = sharedSettings ++ Seq( diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index dd720e2291..f9d84fd3c4 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -153,6 +153,56 @@ + + hadoop2-yarn + + hadoop2-yarn + + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.spark-project + spark-bagel + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-examples + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-repl + ${project.version} + hadoop2-yarn + runtime + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + deb diff --git a/repl/pom.xml b/repl/pom.xml index a3e4606edc..1f885673f4 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -187,5 +187,76 @@ + + hadoop2-yarn + + hadoop2-yarn + + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.spark-project + spark-bagel + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-examples + ${project.version} + hadoop2-yarn + runtime + + + org.spark-project + spark-streaming + ${project.version} + hadoop2-yarn + runtime + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + diff --git a/streaming/pom.xml b/streaming/pom.xml index ec077e8089..fc2e211a42 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -149,5 +149,42 @@ + + hadoop2-yarn + + + org.spark-project + spark-core + ${project.version} + hadoop2-yarn + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2-yarn + + + + + -- cgit v1.2.3 From 2f883c515fe4577f0105e62dd9f395d7de42bd68 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Tue, 9 Apr 2013 13:02:50 +0800 Subject: Contiue to update codes for scala code style 1.refactor braces for "class" "if" "while" "for" "match" 2.make code lines less than 100 3.refactor class parameter and extends defination --- core/src/main/scala/spark/SparkContext.scala | 12 +- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 3 +- core/src/main/scala/spark/scheduler/TaskSet.scala | 9 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 6 +- .../cluster/FIFOTaskSetQueuesManager.scala | 10 +- .../cluster/FairTaskSetQueuesManager.scala | 139 +++++++++------------ .../main/scala/spark/scheduler/cluster/Pool.scala | 46 ++++--- .../spark/scheduler/cluster/Schedulable.scala | 3 +- .../scheduler/cluster/SchedulingAlgorithm.scala | 30 ++--- .../spark/scheduler/cluster/SchedulingMode.scala | 5 +- .../spark/scheduler/cluster/TaskSetManager.scala | 6 +- 11 files changed, 122 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index ed5f686379..7c96ae637b 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -146,7 +146,8 @@ class SparkContext( case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] + val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")). + newInstance().asInstanceOf[TaskSetQueuesManager] scheduler.initialize(backend, taskSetQueuesManager) scheduler @@ -166,7 +167,8 @@ class SparkContext( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] + val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")). + newInstance().asInstanceOf[TaskSetQueuesManager] scheduler.initialize(backend, taskSetQueuesManager) backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() @@ -186,7 +188,8 @@ class SparkContext( } else { new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) } - val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")).newInstance().asInstanceOf[TaskSetQueuesManager] + val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")). + newInstance().asInstanceOf[TaskSetQueuesManager] scheduler.initialize(backend, taskSetQueuesManager) scheduler } @@ -602,7 +605,8 @@ class SparkContext( val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,localProperties.value) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler + ,localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 6f4e5cd83e..11fec568c6 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -23,7 +23,8 @@ private[spark] case class JobSubmitted( partitions: Array[Int], allowLocal: Boolean, callSite: String, - listener: JobListener, properties: Properties = null) + listener: JobListener, + properties: Properties = null) extends DAGSchedulerEvent private[spark] case class CompletionEvent( diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala index 2498e8a5aa..e4b5fcaedb 100644 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -6,8 +6,13 @@ import java.util.Properties * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. */ -private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + attempt +private[spark] class TaskSet( + val tasks: Array[Task[_]], + val stageId: Int, + val attempt: Int, + val priority: Int, + val properties: Properties) { + val id: String = stageId + "." + attempt override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index be0d480aa0..2ddac0ff30 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -155,13 +155,11 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray - for (i <- 0 until offers.size) - { + for (i <- 0 until offers.size){ var launchedTask = true val execId = offers(i).executorId val host = offers(i).hostname - while (availableCpus(i) > 0 && launchedTask) - { + while (availableCpus(i) > 0 && launchedTask){ launchedTask = false taskSetQueuesManager.receiveOffer(execId,host,availableCpus(i)) match { case Some(task) => diff --git a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala index 5949ee773f..62d3130341 100644 --- a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala @@ -28,13 +28,11 @@ private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) } - override def receiveOffer(execId:String, host:String,avaiableCpus:Double):Option[TaskDescription] = - { - for(manager <- activeTaskSetsQueue.sortWith(tasksetSchedulingAlgorithm.comparator)) - { + override def receiveOffer(execId:String, host:String,avaiableCpus:Double):Option[TaskDescription] = { + + for (manager <- activeTaskSetsQueue.sortWith(tasksetSchedulingAlgorithm.comparator)) { val task = manager.slaveOffer(execId,host,avaiableCpus) - if (task != None) - { + if (task != None) { return task } } diff --git a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala index 0609600f35..89b74fbb47 100644 --- a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala @@ -38,71 +38,17 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with loadPoolProperties() - def loadPoolProperties() { - //first check if the file exists - val file = new File(schedulerAllocFile) - if(file.exists()) - { - val xml = XML.loadFile(file) - for (poolNode <- (xml \\ POOL_POOLS_PROPERTY)) { - - val poolName = (poolNode \ POOL_POOL_NAME_PROPERTY).text - var schedulingMode = POOL_DEFAULT_SCHEDULING_MODE - var minShares = POOL_DEFAULT_MINIMUM_SHARES - var weight = POOL_DEFAULT_WEIGHT - - val xmlSchedulingMode = (poolNode \ POOL_SCHEDULING_MODE_PROPERTY).text - if( xmlSchedulingMode != "") - { - try - { - schedulingMode = SchedulingMode.withName(xmlSchedulingMode) - } - catch{ - case e:Exception => logInfo("Error xml schedulingMode, using default schedulingMode") - } - } - - val xmlMinShares = (poolNode \ POOL_MINIMUM_SHARES_PROPERTY).text - if(xmlMinShares != "") - { - minShares = xmlMinShares.toInt - } - - val xmlWeight = (poolNode \ POOL_WEIGHT_PROPERTY).text - if(xmlWeight != "") - { - weight = xmlWeight.toInt - } - - val pool = new Pool(poolName,schedulingMode,minShares,weight) - pools += pool - poolNameToPool(poolName) = pool - logInfo("Create new pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(poolName,schedulingMode,minShares,weight)) - } - } - - if(!poolNameToPool.contains(POOL_DEFAULT_POOL_NAME)) - { - val pool = new Pool(POOL_DEFAULT_POOL_NAME, POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) - pools += pool - poolNameToPool(POOL_DEFAULT_POOL_NAME) = pool - logInfo("Create default pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(POOL_DEFAULT_POOL_NAME,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) - } - } - override def addTaskSetManager(manager: TaskSetManager) { var poolName = POOL_DEFAULT_POOL_NAME - if(manager.taskSet.properties != null) - { + if (manager.taskSet.properties != null) { poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) - if(!poolNameToPool.contains(poolName)) - { + if (!poolNameToPool.contains(poolName)) { //we will create a new pool that user has configured in app instead of being defined in xml file val pool = new Pool(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) pools += pool poolNameToPool(poolName) = pool - logInfo("Create pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) + logInfo("Create pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format( + poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) } } poolNameToPool(poolName).addTaskSetManager(manager) @@ -110,10 +56,8 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with } override def removeTaskSetManager(manager: TaskSetManager) { - var poolName = POOL_DEFAULT_POOL_NAME - if(manager.taskSet.properties != null) - { + if (manager.taskSet.properties != null) { poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) } logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id,poolName)) @@ -124,8 +68,7 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with override def taskFinished(manager: TaskSetManager) { var poolName = POOL_DEFAULT_POOL_NAME - if(manager.taskSet.properties != null) - { + if (manager.taskSet.properties != null) { poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) } val pool = poolNameToPool(poolName) @@ -139,19 +82,15 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with } } - override def receiveOffer(execId: String,host:String,avaiableCpus:Double):Option[TaskDescription] = - { - + override def receiveOffer(execId: String,host:String,avaiableCpus:Double):Option[TaskDescription] = { val sortedPools = pools.sortWith(poolScheduleAlgorithm.comparator) - for(pool <- sortedPools) - { - logDebug("poolName:%s,tasksetNum:%d,minShares:%d,runningTasks:%d".format(pool.poolName,pool.activeTaskSetsQueue.length,pool.minShare,pool.runningTasks)) + for (pool <- sortedPools) { + logDebug("poolName:%s,tasksetNum:%d,minShares:%d,runningTasks:%d".format( + pool.poolName,pool.activeTaskSetsQueue.length,pool.minShare,pool.runningTasks)) } - for (pool <- sortedPools) - { + for (pool <- sortedPools) { val task = pool.receiveOffer(execId,host,avaiableCpus) - if(task != None) - { + if(task != None) { pool.runningTasks += 1 return task } @@ -159,14 +98,60 @@ private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with return None } - override def checkSpeculatableTasks(): Boolean = - { + override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (pool <- pools) - { + for (pool <- pools) { shouldRevive |= pool.checkSpeculatableTasks() } return shouldRevive } + def loadPoolProperties() { + //first check if the file exists + val file = new File(schedulerAllocFile) + if (file.exists()) { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ POOL_POOLS_PROPERTY)) { + + val poolName = (poolNode \ POOL_POOL_NAME_PROPERTY).text + var schedulingMode = POOL_DEFAULT_SCHEDULING_MODE + var minShares = POOL_DEFAULT_MINIMUM_SHARES + var weight = POOL_DEFAULT_WEIGHT + + val xmlSchedulingMode = (poolNode \ POOL_SCHEDULING_MODE_PROPERTY).text + if (xmlSchedulingMode != "") { + try{ + schedulingMode = SchedulingMode.withName(xmlSchedulingMode) + } + catch{ + case e:Exception => logInfo("Error xml schedulingMode, using default schedulingMode") + } + } + + val xmlMinShares = (poolNode \ POOL_MINIMUM_SHARES_PROPERTY).text + if (xmlMinShares != "") { + minShares = xmlMinShares.toInt + } + + val xmlWeight = (poolNode \ POOL_WEIGHT_PROPERTY).text + if (xmlWeight != "") { + weight = xmlWeight.toInt + } + + val pool = new Pool(poolName,schedulingMode,minShares,weight) + pools += pool + poolNameToPool(poolName) = pool + logInfo("Create new pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format( + poolName,schedulingMode,minShares,weight)) + } + } + + if (!poolNameToPool.contains(POOL_DEFAULT_POOL_NAME)) { + val pool = new Pool(POOL_DEFAULT_POOL_NAME, POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) + pools += pool + poolNameToPool(POOL_DEFAULT_POOL_NAME) = pool + logInfo("Create default pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format( + POOL_DEFAULT_POOL_NAME,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) + } + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala index 8fdca5d2b4..e0917ca1ca 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -7,8 +7,13 @@ import spark.scheduler.cluster.SchedulingMode.SchedulingMode /** * An Schedulable entity that represent collection of TaskSetManager */ -private[spark] class Pool(val poolName: String,val schedulingMode: SchedulingMode, initMinShare:Int, initWeight:Int) extends Schedulable with Logging -{ +private[spark] class Pool( + val poolName: String, + val schedulingMode: SchedulingMode, + initMinShare:Int, + initWeight:Int) + extends Schedulable + with Logging { var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] @@ -19,10 +24,8 @@ private[spark] class Pool(val poolName: String,val schedulingMode: SchedulingMod val priority = 0 val stageId = 0 - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = - { - schedulingMode match - { + var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + schedulingMode match { case SchedulingMode.FAIR => val schedule = new FairSchedulingAlgorithm() schedule @@ -32,43 +35,36 @@ private[spark] class Pool(val poolName: String,val schedulingMode: SchedulingMod } } - def addTaskSetManager(manager:TaskSetManager) - { + def addTaskSetManager(manager:TaskSetManager) { activeTaskSetsQueue += manager } - def removeTaskSetManager(manager:TaskSetManager) - { + def removeTaskSetManager(manager:TaskSetManager) { activeTaskSetsQueue -= manager } - def removeExecutor(executorId: String, host: String) - { + def removeExecutor(executorId: String, host: String) { activeTaskSetsQueue.foreach(_.executorLost(executorId,host)) } - def checkSpeculatableTasks(): Boolean = - { + def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for(ts <- activeTaskSetsQueue) - { + for (ts <- activeTaskSetsQueue) { shouldRevive |= ts.checkSpeculatableTasks() } return shouldRevive } - def receiveOffer(execId:String,host:String,availableCpus:Double):Option[TaskDescription] = - { + def receiveOffer(execId:String,host:String,availableCpus:Double):Option[TaskDescription] = { val sortedActiveTasksSetQueue = activeTaskSetsQueue.sortWith(taskSetSchedulingAlgorithm.comparator) - for(manager <- sortedActiveTasksSetQueue) - { - logDebug("poolname:%s,taskSetId:%s,taskNum:%d,minShares:%d,weight:%d,runningTasks:%d".format(poolName,manager.taskSet.id,manager.numTasks,manager.minShare,manager.weight,manager.runningTasks)) + for (manager <- sortedActiveTasksSetQueue) { + logDebug("poolname:%s,taskSetId:%s,taskNum:%d,minShares:%d,weight:%d,runningTasks:%d".format( + poolName,manager.taskSet.id,manager.numTasks,manager.minShare,manager.weight,manager.runningTasks)) } - for(manager <- sortedActiveTasksSetQueue) - { + + for (manager <- sortedActiveTasksSetQueue) { val task = manager.slaveOffer(execId,host,availableCpus) - if (task != None) - { + if (task != None) { manager.runningTasks += 1 return task } diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala index 6f4f104f42..8dfc369c03 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -1,7 +1,8 @@ package spark.scheduler.cluster /** - * An interface for schedulable entities, there are two type Schedulable entities(Pools and TaskSetManagers) + * An interface for schedulable entities. + * there are two type of Schedulable entities(Pools and TaskSetManagers) */ private[spark] trait Schedulable { def weight:Int diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index 2f8123587f..ac2237a7ef 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -9,34 +9,25 @@ private[spark] trait SchedulingAlgorithm { def comparator(s1: Schedulable,s2: Schedulable): Boolean } -private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm -{ - override def comparator(s1: Schedulable, s2: Schedulable): Boolean = - { +private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { val priority1 = s1.priority val priority2 = s2.priority var res = Math.signum(priority1 - priority2) - if (res == 0) - { + if (res == 0) { val stageId1 = s1.stageId val stageId2 = s2.stageId res = Math.signum(stageId1 - stageId2) } if (res < 0) - { return true - } else - { return false - } } } -private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm -{ - def comparator(s1: Schedulable, s2:Schedulable): Boolean = - { +private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { + def comparator(s1: Schedulable, s2:Schedulable): Boolean = { val minShare1 = s1.minShare val minShare2 = s2.minShare val runningTasks1 = s1.runningTasks @@ -49,22 +40,15 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true - if(s1Needy && !s2Needy) - { + if (s1Needy && !s2Needy) res = true - } else if(!s1Needy && s2Needy) - { res = false - } else if (s1Needy && s2Needy) - { res = minShareRatio1 <= minShareRatio2 - } else - { res = taskToWeightRatio1 <= taskToWeightRatio2 - } + return res } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala index 480af2c1a3..6e0c6793e0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala @@ -1,8 +1,7 @@ package spark.scheduler.cluster -object SchedulingMode extends Enumeration("FAIR","FIFO") -{ - type SchedulingMode = Value +object SchedulingMode extends Enumeration("FAIR","FIFO"){ + type SchedulingMode = Value val FAIR,FIFO = Value } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index ddc4fa6642..7ec2f69da5 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -17,7 +17,11 @@ import java.nio.ByteBuffer /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Schedulable with Logging { +private[spark] class TaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet) + extends Schedulable + with Logging { // Maximum time to wait to run a task in a preferred location (in ms) val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong -- cgit v1.2.3 From d90d2af1036e909f81cf77c85bfe589993c4f9f3 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 15 Apr 2013 18:12:11 +0530 Subject: Checkpoint commit - compiles and passes a lot of tests - not all though, looking into FileSuite issues --- .../scala/spark/deploy/SparkHadoopUtil.scala | 18 + .../scala/spark/deploy/SparkHadoopUtil.scala | 59 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 342 +++++++++++++ .../deploy/yarn/ApplicationMasterArguments.scala | 78 +++ .../scala/spark/deploy/yarn/Client.scala | 326 ++++++++++++ .../scala/spark/deploy/yarn/ClientArguments.scala | 104 ++++ .../scala/spark/deploy/yarn/WorkerRunnable.scala | 171 +++++++ .../spark/deploy/yarn/YarnAllocationHandler.scala | 547 +++++++++++++++++++++ .../scheduler/cluster/YarnClusterScheduler.scala | 42 ++ .../scala/spark/deploy/SparkHadoopUtil.scala | 18 + core/src/main/scala/spark/ClosureCleaner.scala | 12 +- .../main/scala/spark/FetchFailedException.scala | 25 +- core/src/main/scala/spark/Logging.scala | 4 + core/src/main/scala/spark/MapOutputTracker.scala | 100 ++-- core/src/main/scala/spark/SparkContext.scala | 46 +- core/src/main/scala/spark/SparkEnv.scala | 15 +- core/src/main/scala/spark/Utils.scala | 138 +++++- .../main/scala/spark/api/python/PythonRDD.scala | 2 + .../main/scala/spark/deploy/DeployMessage.scala | 19 +- .../src/main/scala/spark/deploy/JsonProtocol.scala | 1 + .../scala/spark/deploy/LocalSparkCluster.scala | 8 +- .../main/scala/spark/deploy/client/Client.scala | 6 +- .../scala/spark/deploy/client/ClientListener.scala | 2 +- .../scala/spark/deploy/client/TestClient.scala | 2 +- .../main/scala/spark/deploy/master/Master.scala | 16 +- .../spark/deploy/master/MasterArguments.scala | 17 +- .../scala/spark/deploy/master/WorkerInfo.scala | 9 + .../scala/spark/deploy/worker/ExecutorRunner.scala | 6 +- .../main/scala/spark/deploy/worker/Worker.scala | 19 +- .../spark/deploy/worker/WorkerArguments.scala | 13 +- core/src/main/scala/spark/executor/Executor.scala | 7 +- .../spark/executor/StandaloneExecutorBackend.scala | 31 +- core/src/main/scala/spark/network/Connection.scala | 192 ++++++-- .../scala/spark/network/ConnectionManager.scala | 390 ++++++++++----- core/src/main/scala/spark/network/Message.scala | 3 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 18 +- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 4 + .../scala/spark/scheduler/InputFormatInfo.scala | 156 ++++++ .../main/scala/spark/scheduler/ResultTask.scala | 10 +- .../scala/spark/scheduler/ShuffleMapTask.scala | 12 +- .../src/main/scala/spark/scheduler/SplitInfo.scala | 61 +++ .../main/scala/spark/scheduler/TaskScheduler.scala | 4 + .../spark/scheduler/TaskSchedulerListener.scala | 3 + .../spark/scheduler/cluster/ClusterScheduler.scala | 276 ++++++++++- .../cluster/SparkDeploySchedulerBackend.scala | 8 +- .../cluster/StandaloneClusterMessage.scala | 7 +- .../cluster/StandaloneSchedulerBackend.scala | 33 +- .../scala/spark/scheduler/cluster/TaskInfo.scala | 9 +- .../spark/scheduler/cluster/TaskSetManager.scala | 309 +++++++++--- .../spark/scheduler/cluster/WorkerOffer.scala | 2 +- .../spark/scheduler/local/LocalScheduler.scala | 4 +- .../main/scala/spark/storage/BlockManager.scala | 123 +++-- .../main/scala/spark/storage/BlockManagerId.scala | 40 +- .../spark/storage/BlockManagerMasterActor.scala | 20 +- .../scala/spark/storage/BlockMessageArray.scala | 1 + core/src/main/scala/spark/storage/DiskStore.scala | 33 +- .../src/main/scala/spark/storage/MemoryStore.scala | 4 +- .../main/scala/spark/storage/StorageLevel.scala | 8 +- core/src/main/scala/spark/util/AkkaUtils.scala | 13 +- .../main/scala/spark/util/TimeStampedHashMap.scala | 8 + .../twirl/spark/deploy/master/index.scala.html | 2 +- .../twirl/spark/deploy/worker/index.scala.html | 2 +- .../twirl/spark/storage/worker_table.scala.html | 2 +- core/src/test/scala/spark/DistributedSuite.scala | 2 +- core/src/test/scala/spark/FileSuite.scala | 1 + .../scala/spark/scheduler/DAGSchedulerSuite.scala | 2 +- 66 files changed, 3488 insertions(+), 477 deletions(-) create mode 100644 core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala create mode 100644 core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala create mode 100644 core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala create mode 100644 core/src/main/scala/spark/scheduler/InputFormatInfo.scala create mode 100644 core/src/main/scala/spark/scheduler/SplitInfo.scala diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..d4badbc5c4 --- /dev/null +++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,18 @@ +package spark.deploy + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..66e5ad8491 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,59 @@ +package spark.deploy + +import collection.mutable.HashMap +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import java.security.PrivilegedExceptionAction + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + val yarnConf = new YarnConfiguration(new Configuration()) + + def getUserNameFromEnvironment(): String = { + // defaulting to env if -D is not present ... + val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name)) + + // If nothing found, default to user we are running as + if (retval == null) System.getProperty("user.name") else retval + } + + def runAsUser(func: (Product) => Unit, args: Product) { + runAsUser(func, args, getUserNameFromEnvironment()) + } + + def runAsUser(func: (Product) => Unit, args: Product, user: String) { + + // println("running as user " + jobUserName) + + UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + func(args) + // no return value ... + null + } + }) + } + + // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true. + def isYarnMode(): Boolean = { + val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")) + java.lang.Boolean.valueOf(yarnMode) + } + + // Set an env variable indicating we are running in YARN mode. + // Note that anything with SPARK prefix gets propagated to all (remote) processes + def setYarnMode() { + System.setProperty("SPARK_YARN_MODE", "true") + } + + def setYarnMode(env: HashMap[String, String]) { + env("SPARK_YARN_MODE") = "true" + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala new file mode 100644 index 0000000000..65361e0ed9 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala @@ -0,0 +1,342 @@ +package spark.deploy.yarn + +import java.net.Socket +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import scala.collection.JavaConversions._ +import spark.{SparkContext, Logging, Utils} +import org.apache.hadoop.security.UserGroupInformation +import java.security.PrivilegedExceptionAction + +class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private var rpc: YarnRPC = YarnRPC.create(conf) + private var resourceManager: AMRMProtocol = null + private var appAttemptId: ApplicationAttemptId = null + private var userThread: Thread = null + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = null + + def run() { + + // Initialization + val jobUserName = Utils.getUserNameFromEnvironment() + logInfo("running as user " + jobUserName) + + // run as user ... + UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + runImpl() + return null + } + }) + } + + private def runImpl() { + + appAttemptId = getApplicationAttemptId() + resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + + // Workaround until hadoop moves to something which has + // https://issues.apache.org/jira/browse/HADOOP-8406 + // ignore result + // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times + // Hence args.workerCores = numCore disabled above. Any better option ? + // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + + ApplicationMaster.register(this) + // Start the user's JAR + userThread = startUserClass() + + // This a bit hacky, but we need to wait until the spark.master.port property has + // been set by the Thread executing the user class. + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Wait for the user class to Finish + userThread.join() + + // Finish the ApplicationMaster + finishApplicationMaster() + // TODO: Exit based on success/failure + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + return appAttemptId + } + + private def registerWithResourceManager(): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(yarnConf.get( + YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(appAttemptId) + // Setting this to master host,port - so that the ApplicationReport at client has some sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // What do we provide here ? Might make sense to expose something sensible later ? + appMasterRequest.setTrackingUrl("") + return resourceManager.registerApplicationMaster(appMasterRequest) + } + + private def waitForSparkMaster() { + logInfo("Waiting for spark master to be reachable.") + var masterUp = false + while(!masterUp) { + val masterHost = System.getProperty("spark.master.host") + val masterPort = System.getProperty("spark.master.port") + try { + val socket = new Socket(masterHost, masterPort.toInt) + socket.close() + logInfo("Master now available: " + masterHost + ":" + masterPort) + masterUp = true + } catch { + case e: Exception => + logError("Failed to connect to master at " + masterHost + ":" + masterPort) + Thread.sleep(100) + } + } + } + + private def startUserClass(): Thread = { + logInfo("Starting the user JAR in a separate Thread") + val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader) + .getMethod("main", classOf[Array[String]]) + val t = new Thread { + override def run() { + var mainArgs: Array[String] = null + var startIndex = 0 + + // I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now ! + if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") { + // ensure that first param is ALWAYS "yarn-standalone" + mainArgs = new Array[String](args.userArgs.size() + 1) + mainArgs.update(0, "yarn-standalone") + startIndex = 1 + } + else { + mainArgs = new Array[String](args.userArgs.size()) + } + + args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size()) + + mainMethod.invoke(null, mainArgs) + } + } + t.start() + return t + } + + private def allocateWorkers() { + logInfo("Waiting for spark context initialization") + + try { + var sparkContext: SparkContext = null + ApplicationMaster.sparkContextRef.synchronized { + var count = 0 + while (ApplicationMaster.sparkContextRef.get() == null) { + logInfo("Waiting for spark context initialization ... " + count) + count = count + 1 + ApplicationMaster.sparkContextRef.wait(10000L) + } + sparkContext = ApplicationMaster.sparkContextRef.get() + assert(sparkContext != null) + this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData) + } + + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + while(yarnAllocator.getNumWorkersRunning < args.numWorkers && + // If user thread exists, then quit ! + userThread.isAlive) { + + this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + ApplicationMaster.incrementAllocatorLoop(1) + Thread.sleep(100) + } + } finally { + // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT : + // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks + ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + } + logInfo("All workers have launched.") + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + if (userThread.isAlive){ + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. + // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + launchReporterThread(interval) + } + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (userThread.isAlive){ + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + if (missingWorkerCount > 0) { + logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") + yarnAllocator.allocateContainers(missingWorkerCount) + } + else sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + return t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateContainers(0) + } + + /* + def printContainers(containers: List[Container]) = { + for (container <- containers) { + logInfo("Launching shell command on a new container." + + ", containerId=" + container.getId() + + ", containerNode=" + container.getNodeId().getHost() + + ":" + container.getNodeId().getPort() + + ", containerNodeURI=" + container.getNodeHttpAddress() + + ", containerState" + container.getState() + + ", containerResourceMemory" + + container.getResource().getMemory()) + } + } + */ + + def finishApplicationMaster() { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + // TODO: Check if the application has failed or succeeded + finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED) + resourceManager.finishApplicationMaster(finishReq) + } + +} + +object ApplicationMaster { + // number of times to wait for the allocator loop to complete. + // each loop iteration waits for 100ms, so maximum of 3 seconds. + // This is to ensure that we have reasonable number of containers before we start + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more + // containers are available. Might need to handle this better. + private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + def incrementAllocatorLoop(by: Int) { + val count = yarnAllocatorLoop.getAndAdd(by) + if (count >= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.synchronized { + // to wake threads off wait ... + yarnAllocatorLoop.notifyAll() + } + } + } + + private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() + + def register(master: ApplicationMaster) { + applicationMasters.add(master) + } + + val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) + val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + + def sparkContextInitialized(sc: SparkContext): Boolean = { + var modified = false + sparkContextRef.synchronized { + modified = sparkContextRef.compareAndSet(null, sc) + sparkContextRef.notifyAll() + } + + // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit + // Should not really have to do this, but it helps yarn to evict resources earlier. + // not to mention, prevent Client declaring failure even though we exit'ed properly. + if (modified) { + Runtime.getRuntime().addShutdownHook(new Thread with Logging { + // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run' + logInfo("Adding shutdown hook for context " + sc) + override def run() { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + // best case ... + for (master <- applicationMasters) master.finishApplicationMaster + } + } ) + } + + // Wait for initialization to complete and atleast 'some' nodes can get allocated + yarnAllocatorLoop.synchronized { + while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.wait(1000L) + } + } + modified + } + + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new ApplicationMaster(args).run() + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala new file mode 100644 index 0000000000..dc89125d81 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -0,0 +1,78 @@ +package spark.deploy.yarn + +import spark.util.IntParam +import collection.mutable.ArrayBuffer + +class ApplicationMasterArguments(val args: Array[String]) { + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 + var workerCores = 1 + var numWorkers = 2 + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer = new ArrayBuffer[String]() + + var args = inputArgs + + while (! args.isEmpty) { + + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: IntParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + } + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n") + System.exit(exitCode) + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala new file mode 100644 index 0000000000..7fa6740579 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala @@ -0,0 +1,326 @@ +package spark.deploy.yarn + +import java.net.{InetSocketAddress, URI} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ +import spark.{Logging, Utils} +import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import spark.deploy.SparkHadoopUtil + +class Client(conf: Configuration, args: ClientArguments) extends Logging { + + def this(args: ClientArguments) = this(new Configuration(), args) + + var applicationsManager: ClientRMProtocol = null + var rpc: YarnRPC = YarnRPC.create(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run() { + connectToASM() + logClusterResourceDetails() + + val newApp = getNewApplication() + val appId = newApp.getApplicationId() + + verifyClusterResources(newApp) + val appContext = createApplicationSubmissionContext(appId) + val localResources = prepareLocalResources(appId, "spark") + val env = setupLaunchEnv(localResources) + val amContainer = createContainerLaunchContext(newApp, localResources, env) + + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(amContainer) + appContext.setUser(args.amUser) + + submitApp(appContext) + + monitorApplication(appId) + System.exit(0) + } + + + def connectToASM() { + val rmAddress: InetSocketAddress = NetUtils.createSocketAddr( + yarnConf.get(YarnConfiguration.RM_ADDRESS, YarnConfiguration.DEFAULT_RM_ADDRESS) + ) + logInfo("Connecting to ResourceManager at" + rmAddress) + applicationsManager = rpc.getProxy(classOf[ClientRMProtocol], rmAddress, conf) + .asInstanceOf[ClientRMProtocol] + } + + def logClusterResourceDetails() { + val clusterMetrics: YarnClusterMetrics = getYarnClusterMetrics + logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers) + +/* + val clusterNodeReports: List[NodeReport] = getNodeReports + logDebug("Got Cluster node info from ASM") + for (node <- clusterNodeReports) { + logDebug("Got node report from ASM for, nodeId=" + node.getNodeId + ", nodeAddress=" + node.getHttpAddress + + ", nodeRackName=" + node.getRackName + ", nodeNumContainers=" + node.getNumContainers + ", nodeHealthStatus=" + node.getNodeHealthStatus) + } +*/ + + val queueInfo: QueueInfo = getQueueInfo(args.amQueue) + logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity + + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size + + ", queueChildQueueCount=" + queueInfo.getChildQueues.size) + } + + def getYarnClusterMetrics: YarnClusterMetrics = { + val request: GetClusterMetricsRequest = Records.newRecord(classOf[GetClusterMetricsRequest]) + val response: GetClusterMetricsResponse = applicationsManager.getClusterMetrics(request) + return response.getClusterMetrics + } + + def getNodeReports: List[NodeReport] = { + val request: GetClusterNodesRequest = Records.newRecord(classOf[GetClusterNodesRequest]) + val response: GetClusterNodesResponse = applicationsManager.getClusterNodes(request) + return response.getNodeReports.toList + } + + def getQueueInfo(queueName: String): QueueInfo = { + val request: GetQueueInfoRequest = Records.newRecord(classOf[GetQueueInfoRequest]) + request.setQueueName(queueName) + request.setIncludeApplications(true) + request.setIncludeChildQueues(false) + request.setRecursive(false) + Records.newRecord(classOf[GetQueueInfoRequest]) + return applicationsManager.getQueueInfo(request).getQueueInfo + } + + def getNewApplication(): GetNewApplicationResponse = { + logInfo("Requesting new Application") + val request = Records.newRecord(classOf[GetNewApplicationRequest]) + val response = applicationsManager.getNewApplication(request) + logInfo("Got new ApplicationId: " + response.getApplicationId()) + return response + } + + def verifyClusterResources(app: GetNewApplicationResponse) = { + val maxMem = app.getMaximumResourceCapability().getMemory() + logInfo("Max mem capabililty of resources in this cluster " + maxMem) + + // If the cluster does not have enough memory resources, exit. + val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory + if (requestedMem > maxMem) { + logError("Cluster cannot satisfy memory resource request of " + requestedMem) + System.exit(1) + } + } + + def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { + logInfo("Setting up application submission context for ASM") + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + appContext.setApplicationId(appId) + appContext.setApplicationName("Spark") + return appContext + } + + def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + // Upload Spark and the application JAR to the remote file system + // Add them as local resources to the AM + val fs = FileSystem.get(conf) + Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF")) + .foreach { case(destName, _localPath) => + val localPath: String = if (_localPath != null) _localPath.trim() else "" + if (! localPath.isEmpty()) { + val src = new Path(localPath) + val pathSuffix = appName + "/" + appId.getId() + destName + val dst = new Path(fs.getHomeDirectory(), pathSuffix) + logInfo("Uploading " + src + " to " + dst) + fs.copyFromLocalFile(false, true, src, dst) + val destStatus = fs.getFileStatus(dst) + + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(LocalResourceType.FILE) + amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + locaResources(destName) = amJarRsrc + } + } + return locaResources + } + + def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = { + logInfo("Setting up the launch environment") + val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null) + + val env = new HashMap[String, String]() + Apps.addToEnvironment(env, Environment.USER.name, args.amUser) + + // If log4j present, ensure ours overrides all others + if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Client.populateHadoopClasspath(yarnConf, env) + SparkHadoopUtil.setYarnMode(env) + env("SPARK_YARN_JAR_PATH") = + localResources("spark.jar").getResource().getScheme.toString() + "://" + + localResources("spark.jar").getResource().getFile().toString() + env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString() + env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString() + + env("SPARK_YARN_USERJAR_PATH") = + localResources("app.jar").getResource().getScheme.toString() + "://" + + localResources("app.jar").getResource().getFile().toString() + env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString() + env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString() + + if (log4jConfLocalRes != null) { + env("SPARK_YARN_LOG4J_PATH") = + log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString() + env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString() + env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString() + } + + // Add each SPARK-* key to the environment + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + return env + } + + def userArgsToString(clientArgs: ClientArguments): String = { + val prefix = " --args " + val args = clientArgs.userArgs + val retval = new StringBuilder() + for (arg <- args){ + retval.append(prefix).append(" '").append(arg).append("' ") + } + + retval.toString + } + + def createContainerLaunchContext(newApp: GetNewApplicationResponse, + localResources: HashMap[String, LocalResource], + env: HashMap[String, String]): ContainerLaunchContext = { + logInfo("Setting up container launch context") + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(env) + + val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() + + var amMemory = ((args.amMemory / minResMemory) * minResMemory) + + (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD + + // Extra options for the JVM + var JAVA_OPTS = "" + + // Add Xmx for am memory + JAVA_OPTS += "-Xmx" + amMemory + "m " + + // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. + if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) { + // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + + // Command for the ApplicationMaster + val commands = List[String]("java " + + " -server " + + JAVA_OPTS + + " spark.deploy.yarn.ApplicationMaster" + + " --class " + args.userClass + + " --jar " + args.userJar + + userArgsToString(args) + + " --worker-memory " + args.workerMemory + + " --worker-cores " + args.workerCores + + " --num-workers " + args.numWorkers + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Command for the ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) + + val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + // Memory for the ApplicationMaster + capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + amContainer.setResource(capability) + + return amContainer + } + + def submitApp(appContext: ApplicationSubmissionContext) = { + // Create the request to send to the applications manager + val appRequest = Records.newRecord(classOf[SubmitApplicationRequest]) + .asInstanceOf[SubmitApplicationRequest] + appRequest.setApplicationSubmissionContext(appContext) + // Submit the application to the applications manager + logInfo("Submitting application to ASM") + applicationsManager.submitApplication(appRequest) + } + + def monitorApplication(appId: ApplicationId): Boolean = { + while(true) { + Thread.sleep(1000) + val reportRequest = Records.newRecord(classOf[GetApplicationReportRequest]) + .asInstanceOf[GetApplicationReportRequest] + reportRequest.setApplicationId(appId) + val reportResponse = applicationsManager.getApplicationReport(reportRequest) + val report = reportResponse.getApplicationReport() + + logInfo("Application report from ASM: \n" + + "\t application identifier: " + appId.toString() + "\n" + + "\t appId: " + appId.getId() + "\n" + + "\t clientToken: " + report.getClientToken() + "\n" + + "\t appDiagnostics: " + report.getDiagnostics() + "\n" + + "\t appMasterHost: " + report.getHost() + "\n" + + "\t appQueue: " + report.getQueue() + "\n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + + "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + + "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + + "\t appUser: " + report.getUser() + ) + + val state = report.getYarnApplicationState() + val dsStatus = report.getFinalApplicationStatus() + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return true + } + } + return true + } +} + +object Client { + def main(argStrings: Array[String]) { + val args = new ClientArguments(argStrings) + SparkHadoopUtil.setYarnMode() + new Client(args).run + } + + // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { + for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala new file mode 100644 index 0000000000..53b305f7df --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala @@ -0,0 +1,104 @@ +package spark.deploy.yarn + +import spark.util.MemoryParam +import spark.util.IntParam +import collection.mutable.{ArrayBuffer, HashMap} +import spark.scheduler.{InputFormatInfo, SplitInfo} + +// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware ! +class ClientArguments(val args: Array[String]) { + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 + var workerCores = 1 + var numWorkers = 2 + var amUser = System.getProperty("user.name") + var amQueue = System.getProperty("QUEUE", "default") + var amMemory: Int = 512 + // TODO + var inputFormatInfo: List[InputFormatInfo] = null + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() + val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]() + + var args = inputArgs + + while (! args.isEmpty) { + + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--master-memory") :: MemoryParam(value) :: tail => + amMemory = value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: MemoryParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case ("--user") :: value :: tail => + amUser = value + args = tail + + case ("--queue") :: value :: tail => + amQueue = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + inputFormatInfo = inputFormatMap.values.toList + } + + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: spark.deploy.yarn.Client [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + + " --user USERNAME Run the ApplicationMaster as a different user\n" + ) + System.exit(exitCode) + } + +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala new file mode 100644 index 0000000000..5688f1ab66 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala @@ -0,0 +1,171 @@ +package spark.deploy.yarn + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap + +import spark.{Logging, Utils} + +class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String, + slaveId: String, hostname: String, workerMemory: Int, workerCores: Int) + extends Runnable with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + var cm: ContainerManager = null + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run = { + logInfo("Starting Worker Container") + cm = connectToCM + startContainer + } + + def startContainer = { + logInfo("Setting up ContainerLaunchContext") + + val ctx = Records.newRecord(classOf[ContainerLaunchContext]) + .asInstanceOf[ContainerLaunchContext] + + ctx.setContainerId(container.getId()) + ctx.setResource(container.getResource()) + val localResources = prepareLocalResources + ctx.setLocalResources(localResources) + + val env = prepareEnvironment + ctx.setEnvironment(env) + + // Extra options for the JVM + var JAVA_OPTS = "" + // Set the JVM memory + val workerMemoryString = workerMemory + "m" + JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " " + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. +/* + else { + // If no java_opts specified, default to using -XX:+CMSIncrementalMode + // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it. + // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines + // The options are based on + // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } +*/ + + ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) + val commands = List[String]("java " + + " -server " + + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. + // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state. + // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ? + " -XX:OnOutOfMemoryError='kill %p' " + + JAVA_OPTS + + " spark.executor.StandaloneExecutorBackend " + + masterAddress + " " + + slaveId + " " + + hostname + " " + + workerCores + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Setting up worker with commands: " + commands) + ctx.setCommands(commands) + + // Send the start request to the ContainerManager + val startReq = Records.newRecord(classOf[StartContainerRequest]) + .asInstanceOf[StartContainerRequest] + startReq.setContainerLaunchContext(ctx) + cm.startContainer(startReq) + } + + + def prepareLocalResources: HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + + // Spark JAR + val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + sparkJarResource.setType(LocalResourceType.FILE) + sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION) + sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_JAR_PATH")))) + sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong) + sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong) + locaResources("spark.jar") = sparkJarResource + // User JAR + val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + userJarResource.setType(LocalResourceType.FILE) + userJarResource.setVisibility(LocalResourceVisibility.APPLICATION) + userJarResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_USERJAR_PATH")))) + userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong) + userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong) + locaResources("app.jar") = userJarResource + + // Log4j conf - if available + if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { + val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + log4jConfResource.setType(LocalResourceType.FILE) + log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION) + log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_LOG4J_PATH")))) + log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong) + log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong) + locaResources("log4j.properties") = log4jConfResource + } + + + logInfo("Prepared Local resources " + locaResources) + return locaResources + } + + def prepareEnvironment: HashMap[String, String] = { + val env = new HashMap[String, String]() + // should we add this ? + Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment()) + + // If log4j present, ensure ours overrides all others + if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { + // Which is correct ? + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + } + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Client.populateHadoopClasspath(yarnConf, env) + + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + return env + } + + def connectToCM: ContainerManager = { + val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort() + val cmAddress = NetUtils.createSocketAddr(cmHostPortStr) + logInfo("Connecting to ContainerManager at " + cmHostPortStr) + return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager] + } + +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala new file mode 100644 index 0000000000..cac9dab401 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala @@ -0,0 +1,547 @@ +package spark.deploy.yarn + +import spark.{Logging, Utils} +import spark.scheduler.SplitInfo +import scala.collection +import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container} +import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.util.{RackResolver, Records} +import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.atomic.AtomicInteger +import org.apache.hadoop.yarn.api.AMRMProtocol +import collection.JavaConversions._ +import collection.mutable.{ArrayBuffer, HashMap, HashSet} +import org.apache.hadoop.conf.Configuration +import java.util.{Collections, Set => JSet} +import java.lang.{Boolean => JBoolean} + +object AllocationType extends Enumeration ("HOST", "RACK", "ANY") { + type AllocationType = Value + val HOST, RACK, ANY = Value +} + +// too many params ? refactor it 'somehow' ? +// needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it +// more proactive and decoupled. +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info +// on how we are requesting for containers. +private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol, + val appAttemptId: ApplicationAttemptId, + val maxWorkers: Int, val workerMemory: Int, val workerCores: Int, + val preferredHostToCount: Map[String, Int], + val preferredRackToCount: Map[String, Int]) + extends Logging { + + + // These three are locked on allocatedHostToContainersMap. Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set + // allocatedContainerToHostMap: container to host mapping + private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]() + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // containers which have been released. + private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() + // containers to be released in next request to RM + private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + + private val numWorkersRunning = new AtomicInteger() + // Used to generate a unique id per worker + private val workerIdCounter = new AtomicInteger() + private val lastResponseId = new AtomicInteger() + + def getNumWorkersRunning: Int = numWorkersRunning.intValue + + + def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + } + + def allocateContainers(workersToRequest: Int) { + // We need to send the request only once from what I understand ... but for now, not modifying this much. + + // Keep polling the Resource Manager for containers + val amResp = allocateWorkerResources(workersToRequest).getAMResponse + + val _allocatedContainers = amResp.getAllocatedContainers() + if (_allocatedContainers.size > 0) { + + + logDebug("Allocated " + _allocatedContainers.size + " containers, current count " + + numWorkersRunning.get() + ", to-be-released " + releasedContainerList + + ", pendingReleaseContainers : " + pendingReleaseContainers) + logDebug("Cluster Resources: " + amResp.getAvailableResources) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + // ignore if not satisfying constraints { + for (container <- _allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // allocatedContainers += container + + val host = container.getNodeId.getHost + val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]()) + + containers += container + } + // Add all ignored containers to released list + else releasedContainerList.add(container.getId()) + } + + // Find the appropriate containers to use + // Slightly non trivial groupBy I guess ... + val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) + { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null) + assert(remainingContainers != null) + + if (requiredHostCount >= remainingContainers.size){ + // Since we got <= required containers, add all to dataLocalContainers + dataLocalContainers.put(candidateHost, remainingContainers) + // all consumed + remainingContainers = null + } + else if (requiredHostCount > 0) { + // container list has more containers than we need for data locality. + // Split into two : data local container count of (remainingContainers.size - requiredHostCount) + // and rest as remainingContainer + val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + // remainingContainers = remaining + + // yarn has nasty habit of allocating a tonne of containers on a host - discourage this : + // add remaining to release list. If we have insufficient containers, next allocation cycle + // will reallocate (but wont treat it as data local) + for (container <- remaining) releasedContainerList.add(container.getId()) + remainingContainers = null + } + + // now rack local + if (remainingContainers != null){ + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + + if (rack != null){ + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.get(rack).getOrElse(List()).size + + + if (requiredRackCount >= remainingContainers.size){ + // Add all to dataLocalContainers + dataLocalContainers.put(rack, remainingContainers) + // all consumed + remainingContainers = null + } + else if (requiredRackCount > 0) { + // container list has more containers than we need for data locality. + // Split into two : data local container count of (remainingContainers.size - requiredRackCount) + // and rest as remainingContainer + val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + remainingContainers = remaining + } + } + } + + // If still not consumed, then it is off rack host - add to that list. + if (remainingContainers != null){ + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order : + // first host local, then rack local and then off rack (everything else). + // Note that the list we create below tries to ensure that not all containers end up within a host + // if there are sufficiently large number of hosts/containers. + + val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers + for (container <- allocatedContainers) { + val numWorkersRunningNow = numWorkersRunning.incrementAndGet() + val workerHostname = container.getNodeId.getHost + val containerId = container.getId + + assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + + if (numWorkersRunningNow > maxWorkers) { + logInfo("Ignoring container " + containerId + " at host " + workerHostname + + " .. we already have required number of containers") + releasedContainerList.add(containerId) + // reset counter back to old value. + numWorkersRunning.decrementAndGet() + } + else { + // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter) + val workerId = workerIdCounter.incrementAndGet().toString + val masterUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + StandaloneSchedulerBackend.ACTOR_NAME) + + logInfo("launching container on " + containerId + " host " + workerHostname) + // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but .. + pendingReleaseContainers.remove(containerId) + + val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, workerHostname) + if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + + new Thread( + new WorkerRunnable(container, conf, masterUrl, workerId, + workerHostname, workerMemory, workerCores) + ).start() + } + } + logDebug("After allocated " + allocatedContainers.size + " containers (orig : " + + _allocatedContainers.size + "), current count " + numWorkersRunning.get() + + ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers) + } + + + val completedContainers = amResp.getCompletedContainersStatuses() + if (completedContainers.size > 0){ + logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() + + ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers) + + for (completedContainer <- completedContainers){ + val containerId = completedContainer.getContainerId + + // Was this released by us ? If yes, then simply remove from containerSet and move on. + if (pendingReleaseContainers.containsKey(containerId)) { + pendingReleaseContainers.remove(containerId) + } + else { + // simply decrement count - next iteration of ReporterThread will take care of allocating ! + numWorkersRunning.decrementAndGet() + logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState + + " httpaddress: " + completedContainer.getDiagnostics) + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).getOrElse(null) + assert (host != null) + + val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null) + assert (containerSet != null) + + containerSet -= containerId + if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host) + else allocatedHostToContainersMap.update(host, containerSet) + + allocatedContainerToHostMap -= containerId + + // doing this within locked context, sigh ... move to outside ? + val rack = YarnAllocationHandler.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) allocatedRackCount.put(rack, rackCount) + else allocatedRackCount.remove(rack) + } + } + } + } + logDebug("After completed " + completedContainers.size + " containers, current count " + + numWorkersRunning.get() + ", to-be-released " + releasedContainerList + + ", pendingReleaseContainers : " + pendingReleaseContainers) + } + } + + def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = { + // First generate modified racks and new set of hosts under it : then issue requests + val rackToCounts = new HashMap[String, Int]() + + // Within this lock - used to read/write to the rack related maps too. + for (container <- hostContainers) { + val candidateHost = container.getHostName + val candidateNumContainers = container.getNumContainers + assert(YarnAllocationHandler.ANY_HOST != candidateHost) + + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += candidateNumContainers + rackToCounts.put(rack, count) + } + } + + val requestedContainers: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts){ + requestedContainers += + createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY) + } + + requestedContainers.toList + } + + def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + private def allocateWorkerResources(numWorkers: Int): AllocateResponse = { + + var resourceRequests: List[ResourceRequest] = null + + // default. + if (numWorkers <= 0 || preferredHostToCount.isEmpty) { + logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty) + resourceRequests = List( + createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)) + } + else { + // request for all hosts in preferred nodes and for numWorkers - + // candidates.size, request by default allocation policy. + val hostContainerRequests: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](preferredHostToCount.size) + for ((candidateHost, candidateCount) <- preferredHostToCount) { + val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) + + if (requiredCount > 0) { + hostContainerRequests += + createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY) + } + } + val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList) + + val anyContainerRequests: ResourceRequest = + createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY) + + val containerRequests: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1) + + containerRequests ++= hostContainerRequests + containerRequests ++= rackContainerRequests + containerRequests += anyContainerRequests + + resourceRequests = containerRequests.toList + } + + val req = Records.newRecord(classOf[AllocateRequest]) + req.setResponseId(lastResponseId.incrementAndGet) + req.setApplicationAttemptId(appAttemptId) + + req.addAllAsks(resourceRequests) + + val releasedContainerList = createReleasedContainerList() + req.addAllReleases(releasedContainerList) + + + + if (numWorkers > 0) { + logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.") + } + else { + logDebug("Empty allocation req .. release : " + releasedContainerList) + } + + for (req <- resourceRequests) { + logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers + + ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability) + } + resourceManager.allocate(req) + } + + + private def createResourceRequest(requestType: AllocationType.AllocationType, + resource:String, numWorkers: Int, priority: Int): ResourceRequest = { + + // If hostname specified, we need atleast two requests - node local and rack local. + // There must be a third request - which is ANY : that will be specially handled. + requestType match { + case AllocationType.HOST => { + assert (YarnAllocationHandler.ANY_HOST != resource) + + val hostname = resource + val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority) + + // add to host->rack mapping + YarnAllocationHandler.populateRackInfo(conf, hostname) + + nodeLocal + } + + case AllocationType.RACK => { + val rack = resource + createResourceRequestImpl(rack, numWorkers, priority) + } + + case AllocationType.ANY => { + createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority) + } + + case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType) + } + } + + private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = { + + val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) + val memCapability = Records.newRecord(classOf[Resource]) + // There probably is some overhead here, let's reserve a bit more memory. + memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + rsrcRequest.setCapability(memCapability) + + val pri = Records.newRecord(classOf[Priority]) + pri.setPriority(priority) + rsrcRequest.setPriority(pri) + + rsrcRequest.setHostName(hostname) + + rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0)) + rsrcRequest + } + + def createReleasedContainerList(): ArrayBuffer[ContainerId] = { + + val retval = new ArrayBuffer[ContainerId](1) + // iterator on COW list ... + for (container <- releasedContainerList.iterator()){ + retval += container + } + // remove from the original list. + if (! retval.isEmpty) { + releasedContainerList.removeAll(retval) + for (v <- retval) pendingReleaseContainers.put(v, true) + logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + + pendingReleaseContainers) + } + + retval + } +} + +object YarnAllocationHandler { + + val ANY_HOST = "*" + // all requests are issued with same priority : we do not (yet) have any distinction between + // request types (like map/reduce in hadoop for example) + val PRIORITY = 1 + + // Additional memory overhead - in mb + val MEMORY_OVERHEAD = 384 + + // host to rack map - saved from allocation requests + // We are expecting this not to change. + // Note that it is possible for this to change : and RM will indicate that to us via update + // response to allocate. But we are punting on handling that for now. + private val hostToRack = new ConcurrentHashMap[String, String]() + private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + + def newAllocator(conf: Configuration, + resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { + + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + + + new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers, + args.workerMemory, args.workerCores, hostToCount, rackToCount) + } + + def newAllocator(conf: Configuration, + resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, + maxWorkers: Int, workerMemory: Int, workerCores: Int, + map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { + + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + + new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers, + workerMemory, workerCores, hostToCount, rackToCount) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) : + // host to count, rack to count + (Map[String, Int], Map[String, Int]) = { + + if (input == null) return (Map[String, Int](), Map[String, Int]()) + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = lookupRack(conf, host) + if (rack != null){ + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + def lookupRack(conf: Configuration, host: String): String = { + if (! hostToRack.contains(host)) populateRackInfo(conf, host) + hostToRack.get(host) + } + + def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { + val set = rackToHostSet.get(rack) + if (set == null) return None + + // No better way to get a Set[String] from JSet ? + val convertedSet: collection.mutable.Set[String] = set + Some(convertedSet.toSet) + } + + def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list ? + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... + hostToRack.put(hostname, null) + } */ + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala new file mode 100644 index 0000000000..ed732d36bf --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -0,0 +1,42 @@ +package spark.scheduler.cluster + +import spark._ +import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.hadoop.conf.Configuration + +/** + * + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + */ +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate + // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) + // Subsequent creations are ignored - since nodes are already allocated by then. + + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + // By default, if rack is unknown, return nothing + override def getCachedHostsForRack(rack: String): Option[Set[String]] = { + if (rack == None || rack == null) return None + + YarnAllocationHandler.fetchCachedHostsForRack(rack) + } + + override def postStartHook() { + val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + if (sparkContextInitialized){ + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(3000L) + } + logInfo("YarnClusterScheduler.postStartHook done") + } +} diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000..d4badbc5c4 --- /dev/null +++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,18 @@ +package spark.deploy + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } +} diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 98525b99c8..50d6a1c5c9 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -8,12 +8,20 @@ import scala.collection.mutable.Set import org.objectweb.asm.{ClassReader, MethodVisitor, Type} import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ +import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private def getClassReader(cls: Class[_]): ClassReader = { - new ClassReader(cls.getResourceAsStream( - cls.getName.replaceFirst("^.*\\.", "") + ".class")) + // Copy data over, before delegating to ClassReader - else we can run out of open file handles. + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + // todo: Fixme - continuing with earlier behavior ... + if (resourceStream == null) return new ClassReader(resourceStream) + + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } // Check whether a class represents a Scala closure diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index a953081d24..40b0193f19 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -3,18 +3,25 @@ package spark import spark.storage.BlockManagerId private[spark] class FetchFailedException( - val bmAddress: BlockManagerId, - val shuffleId: Int, - val mapId: Int, - val reduceId: Int, + taskEndReason: TaskEndReason, + message: String, cause: Throwable) extends Exception { - - override def getMessage(): String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + + def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), + cause) + + def this (shuffleId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(null, shuffleId, -1, reduceId), + "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) + + override def getMessage(): String = message + override def getCause(): Throwable = cause - def toTaskEndReason: TaskEndReason = - FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = taskEndReason + } diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 7c1c1bb144..0fc8c31463 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -68,6 +68,10 @@ trait Logging { if (log.isErrorEnabled) log.error(msg, throwable) } + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + // Method for ensuring that logging is initialized, to avoid having multiple // threads do it concurrently (as SLF4J initialization is not thread safe). protected def initLogging() { log } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 866d630a6d..6e9da02893 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,7 +1,6 @@ package spark import java.io._ -import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap @@ -12,8 +11,7 @@ import akka.dispatch._ import akka.pattern.ask import akka.remote._ import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ + import spark.scheduler.MapStatus import spark.storage.BlockManagerId @@ -40,10 +38,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac private[spark] class MapOutputTracker extends Logging { + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +52,7 @@ private[spark] class MapOutputTracker extends Logging { // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) @@ -60,7 +60,6 @@ private[spark] class MapOutputTracker extends Logging { // throw a SparkException if this fails. def askTracker(message: Any): Any = { try { - val timeout = 10.seconds val future = trackerActor.ask(message)(timeout) return Await.result(future, timeout) } catch { @@ -77,10 +76,9 @@ private[spark] class MapOutputTracker extends Logging { } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != None) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -101,8 +99,9 @@ private[spark] class MapOutputTracker extends Logging { } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses(shuffleId) - if (array != null) { + var arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + var array = arrayOpt.get array.synchronized { if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null @@ -115,13 +114,14 @@ private[spark] class MapOutputTracker extends Logging { } // Remembers which map output locations are currently being fetched on a worker - val fetching = new HashSet[Int] + private val fetching = new HashSet[Int] // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -132,31 +132,49 @@ private[spark] class MapOutputTracker extends Logging { case e: InterruptedException => } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId)) - } else { + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. fetching += shuffleId } } - // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) - val host = System.getProperty("spark.hostname", Utils.localHostName) - // This try-finally prevents hangs due to timeouts: - var fetchedStatuses: Array[MapStatus] = null - try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + + if (fetchedStatuses == null) { + // We won the race to fetch the output locs; do so + logInfo("Doing the fetch; tracker actor = " + trackerActor) + val hostPort = Utils.localHostPort() + // This try-finally prevents hangs due to timeouts: + var fetchedStatuses: Array[MapStatus] = null + try { + val fetchedBytes = + askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + fetchedStatuses = deserializeStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + if (fetchedStatuses != null) { + fetchedStatuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) + else{ + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } } else { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } } } @@ -194,7 +212,8 @@ private[spark] class MapOutputTracker extends Logging { generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + mapStatuses.clear() generation = newGen } } @@ -232,10 +251,13 @@ private[spark] class MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - objOut.writeObject(statuses) + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } objOut.close() out.toByteArray } @@ -243,7 +265,10 @@ private[spark] class MapOutputTracker extends Logging { // Opposite of serializeStatuses. def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject().asInstanceOf[Array[MapStatus]] + objIn.readObject(). + // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present + // comment this out - nulls could be due to missing location ? + asInstanceOf[Array[MapStatus]] // .filter( _ != null ) } } @@ -253,14 +278,11 @@ private[spark] object MapOutputTracker { // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If // any of the statuses is null (indicating a missing location due to a failed mapper), // throw a FetchFailedException. - def convertMapStatuses( + private def convertMapStatuses( shuffleId: Int, reduceId: Int, statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { - if (statuses == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) - } + assert (statuses != null) statuses.map { status => if (status == null) { diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4957a54c1b..e853bce2c4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -37,7 +37,7 @@ import spark.partial.PartialResult import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} import spark.scheduler._ import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} +import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.BlockManagerUI import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -59,7 +59,10 @@ class SparkContext( val appName: String, val sparkHome: String = null, val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map()) + val environment: Map[String, String] = Map(), + // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. + // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) extends Logging { // Ensure logging is initialized before we spawn any threads @@ -67,7 +70,7 @@ class SparkContext( // Set Spark driver host and port system properties if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localIpAddress) + System.setProperty("spark.driver.host", Utils.localHostName()) } if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") @@ -99,7 +102,7 @@ class SparkContext( // Add each JAR given through the constructor - jars.foreach { addJar(_) } + if (jars != null) jars.foreach { addJar(_) } // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() @@ -111,7 +114,7 @@ class SparkContext( executorEnvs(key) = value } } - executorEnvs ++= environment + if (environment != null) executorEnvs ++= environment // Create and start the scheduler private var taskScheduler: TaskScheduler = { @@ -164,6 +167,22 @@ class SparkContext( } scheduler + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + scheduler.initialize(backend) + scheduler + case _ => if (MESOS_REGEX.findFirstIn(master).isEmpty) { logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) @@ -183,7 +202,7 @@ class SparkContext( } taskScheduler.start() - private var dagScheduler = new DAGScheduler(taskScheduler) + @volatile private var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -207,6 +226,9 @@ class SparkContext( private[spark] var checkpointDir: Option[String] = None + // Post init + taskScheduler.postStartHook() + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -471,7 +493,7 @@ class SparkContext( */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.ip + ":" + blockManagerId.port, mem) + (blockManagerId.host + ":" + blockManagerId.port, mem) } } @@ -527,10 +549,13 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - if (dagScheduler != null) { + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { metadataCleaner.cancel() - dagScheduler.stop() - dagScheduler = null + dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -546,6 +571,7 @@ class SparkContext( } } + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 7157fd2688..ffb40bab3a 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -72,6 +72,16 @@ object SparkEnv extends Logging { System.setProperty("spark.driver.port", boundPort.toString) } + // set only if unset until now. + if (System.getProperty("spark.hostPort", null) == null) { + if (!isDriver){ + // unexpected + Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") + } + Utils.checkHost(hostname) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + } + val classLoader = Thread.currentThread.getContextClassLoader // Create an instance of the class named by the given Java system property, or by @@ -88,9 +98,10 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverHost: String = System.getProperty("spark.driver.host", "localhost") val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt - val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name) + Utils.checkHost(driverHost, "Expected hostname") + val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) actorSystem.actorFor(url) } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 81daacf958..14bb153d54 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,18 +1,18 @@ package spark import java.io._ -import java.net._ +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} -import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder -import scala.Some import spark.serializer.SerializerInstance +import spark.deploy.SparkHadoopUtil /** * Various utility methods used by Spark. @@ -68,6 +68,41 @@ private object Utils extends Logging { return buf } + + private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; else false + // This is to ensure that two shutdown hooks do not try to delete each others paths - resulting in IOException + // and incomplete cleanup + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + + val absolutePath = file.getAbsolutePath() + + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined + } + + if (retval) logInfo("path = " + file + ", already present as root for deletion.") + + retval + } + /** Create a temporary directory inside the given parent directory */ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { var attempts = 0 @@ -86,10 +121,14 @@ private object Utils extends Logging { } } catch { case e: IOException => ; } } + + registerShutdownDeleteDir(dir) + // Add a shutdown hook to delete the temp dir when the JVM exits Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { override def run() { - Utils.deleteRecursively(dir) + // Attempt to delete if some patch which is parent of this is not already registered. + if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) } }) return dir @@ -227,8 +266,10 @@ private object Utils extends Logging { /** * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). + * Note, this is typically not used from within core spark. */ lazy val localIpAddress: String = findLocalIpAddress() + lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) private def findLocalIpAddress(): String = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") @@ -266,6 +307,8 @@ private object Utils extends Logging { * hostname it reports to the master. */ def setCustomHostname(hostname: String) { + // DEBUG code + Utils.checkHost(hostname) customHostname = Some(hostname) } @@ -273,7 +316,90 @@ private object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + // customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + customHostname.getOrElse(localIpAddressHostname) + } + + def getAddressHostName(address: String): String = { + InetAddress.getByName(address).getHostName + } + + + + def localHostPort(): String = { + val retval = System.getProperty("spark.hostPort", null) + if (retval == null) { + logErrorWithStack("spark.hostPort not set but invoking localHostPort") + return localHostName() + } + + retval + } + + // Used by DEBUG code : remove when all testing done + def checkHost(host: String, message: String = "") { + // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous ! + if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { + Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message) + } + if (Utils.parseHostPort(host)._2 != 0){ + Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message) + } + } + + // Used by DEBUG code : remove when all testing done + def checkHostPort(hostPort: String, message: String = "") { + val (host, port) = Utils.parseHostPort(hostPort) + checkHost(host) + if (port <= 0){ + Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) + } + } + + def getUserNameFromEnvironment(): String = { + SparkHadoopUtil.getUserNameFromEnvironment + } + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + // temp code for debug + System.exit(-1) + } + + // Typically, this will be of order of number of nodes in cluster + // If not, we should change it to LRUCache or something. + private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() + def parseHostPort(hostPort: String): (String, Int) = { + { + // Check cache first. + var cached = hostPortParseResults.get(hostPort) + if (cached != null) return cached + } + + val indx: Int = hostPort.lastIndexOf(':') + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + if (-1 == indx) { + val retval = (hostPort, 0) + hostPortParseResults.put(hostPort, retval) + return retval + } + + val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + hostPortParseResults.get(hostPort) + } + + def addIfNoPort(hostPort: String, port: Int): String = { + if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port) + + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + val indx: Int = hostPort.lastIndexOf(':') + if (-1 != indx) return hostPort + + hostPort + ":" + port } private[spark] val daemonThreadFactory: ThreadFactory = diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 9b4d54ab4e..807119ca8c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -277,6 +277,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte] */ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { + + Utils.checkHost(serverHost, "Expected hostname") override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 8a3e64e4c2..51274acb1e 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, ApplicationInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List +import spark.Utils private[spark] sealed trait DeployMessage extends Serializable @@ -19,7 +20,10 @@ case class RegisterWorker( memory: Int, webUiPort: Int, publicAddress: String) - extends DeployMessage + extends DeployMessage { + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} private[spark] case class ExecutorStateChanged( @@ -58,7 +62,9 @@ private[spark] case class RegisteredApplication(appId: String) extends DeployMessage private[spark] -case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) +case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { + Utils.checkHostPort(hostPort, "Required hostport") +} private[spark] case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -81,6 +87,9 @@ private[spark] case class MasterState(host: String, port: Int, workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + Utils.checkHost(host, "Required hostname") + assert (port > 0) + def uri = "spark://" + host + ":" + port } @@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState private[spark] case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, - coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) + coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { + + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index 38a6ebfc24..71a641a9ef 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), "host" -> JsString(obj.host), + "port" -> JsNumber(obj.port), "webuiaddress" -> JsString(obj.webUiAddress), "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 22319a96ca..55bb61b0cc 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - private val localIpAddress = Utils.localIpAddress + private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() @@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0) + val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0) masterActorSystems += masterSystem - val masterUrl = "spark://" + localIpAddress + ":" + masterPort + val masterUrl = "spark://" + localHostname + ":" + masterPort /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker, + val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masterUrl, null, Some(workerNum)) workerActorSystems += workerSystem } diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index 2fc5e657f9..072232e33a 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -59,10 +59,10 @@ private[spark] class Client( markDisconnected() context.stop(self) - case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) => + case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores)) - listener.executorAdded(fullId, workerId, host, cores, memory) + logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) + listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => val fullId = appId + "/" + id diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index b7008321df..e8c4083f9d 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit + def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc004b59ca..ad92532b58 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -16,7 +16,7 @@ private[spark] object TestClient { System.exit(0) } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {} + def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 71b9d0801d..160afe5239 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils} import spark.util.AkkaUtils -private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 @@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor var firstApp: Option[ApplicationInfo] = None + Utils.checkHost(host, "Expected hostname") + val masterPublicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } // As a temporary workaround before better ways of configuring memory, we allow users to set @@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean override def preStart() { - logInfo("Starting Spark master at spark://" + ip + ":" + port) + logInfo("Starting Spark master at spark://" + host + ":" + port) // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) startWebUi() @@ -145,7 +147,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } case RequestMasterState => { - sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray) + sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray) } } @@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) - exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, publicAddress: String): WorkerInfo = { // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. - workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) + workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _) val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker @@ -307,7 +309,7 @@ private[spark] object Master { def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort) + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index 4ceab3fc03..3d28ecabb4 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -7,13 +7,13 @@ import spark.Utils * Command-line parser for the master. */ private[spark] class MasterArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 // Check for settings in environment variables - if (System.getenv("SPARK_MASTER_IP") != null) { - ip = System.getenv("SPARK_MASTER_IP") + if (System.getenv("SPARK_MASTER_HOST") != null) { + host = System.getenv("SPARK_MASTER_HOST") } if (System.getenv("SPARK_MASTER_PORT") != null) { port = System.getenv("SPARK_MASTER_PORT").toInt @@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) { "Usage: Master [options]\n" + "\n" + "Options:\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port PORT Port for web UI (default: 8080)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index 23df1bb463..0c08c5f417 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -2,6 +2,7 @@ package spark.deploy.master import akka.actor.ActorRef import scala.collection.mutable +import spark.Utils private[spark] class WorkerInfo( val id: String, @@ -13,6 +14,9 @@ private[spark] class WorkerInfo( val webUiPort: Int, val publicAddress: String) { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info var state: WorkerState.Value = WorkerState.ALIVE var coresUsed = 0 @@ -23,6 +27,11 @@ private[spark] class WorkerInfo( def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed + def hostPort: String = { + assert (port > 0) + host + ":" + port + } + def addExecutor(exec: ExecutorInfo) { executors(exec.fullId) = exec coresUsed += exec.cores diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index de11771c8e..dfcb9f0d05 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -21,11 +21,13 @@ private[spark] class ExecutorRunner( val memory: Int, val worker: ActorRef, val workerId: String, - val hostname: String, + val hostPort: String, val sparkHome: File, val workDir: File) extends Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null @@ -68,7 +70,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => hostname + case "{{HOSTPORT}}" => hostPort case "{{CORES}}" => cores.toString case other => other } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8919d1261c..cf4babc892 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -16,7 +16,7 @@ import spark.deploy.master.Master import java.io.File private[spark] class Worker( - ip: String, + host: String, port: Int, webUiPort: Int, cores: Int, @@ -25,6 +25,9 @@ private[spark] class Worker( workDirPath: String = null) extends Actor with Logging { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds @@ -39,7 +42,7 @@ private[spark] class Worker( val finishedExecutors = new HashMap[String, ExecutorRunner] val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } var coresUsed = 0 @@ -64,7 +67,7 @@ private[spark] class Worker( override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - ip, port, cores, Utils.memoryMegabytesToString(memory))) + host, port, cores, Utils.memoryMegabytesToString(memory))) sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) logInfo("Spark home: " + sparkHome) createWorkDir() @@ -75,7 +78,7 @@ private[spark] class Worker( def connectToMaster() { logInfo("Connecting to master " + masterUrl) master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) + master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } @@ -106,7 +109,7 @@ private[spark] class Worker( case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) + appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -141,7 +144,7 @@ private[spark] class Worker( masterDisconnected() case RequestWorkerState => { - sender ! WorkerState(ip, port, workerId, executors.values.toList, + sender ! WorkerState(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, masterUrl, cores, memory, coresUsed, memoryUsed, masterWebUiUrl) } @@ -156,7 +159,7 @@ private[spark] class Worker( } def generateWorkerId(): String = { - "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port) + "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) } override def postStop() { @@ -167,7 +170,7 @@ private[spark] class Worker( private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores, + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.master, args.workDir) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 08f02bad80..2b96611ee3 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory * Command-line parser for the master. */ private[spark] class WorkerArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 0 var webUiPort = 8081 var cores = inferDefaultCores() @@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) { " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 3e7407b58d..344face5e6 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer * The Mesos executor for Spark. */ private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging { - + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert initLogging() + // No ip or host:port - just hostname + Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + // must not have port specified. + assert (0 == Utils.parseHostPort(slaveHostname)._2) + // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 1047f71c6a..49e1f3f07a 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.RegisterExecutorFailed import spark.scheduler.cluster.RegisterExecutor +import spark.Utils +import spark.deploy.SparkHadoopUtil private[spark] class StandaloneExecutorBackend( driverUrl: String, executorId: String, - hostname: String, + hostPort: String, cores: Int) extends Actor with ExecutorBackend with Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + var executor: Executor = null var driver: ActorRef = null override def preStart() { logInfo("Connecting to driver: " + driverUrl) driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostname, cores) + driver ! RegisterExecutor(executorId, hostPort, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(driver) // Doesn't work with remote actors, but useful for testing } @@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") - executor = new Executor(executorId, hostname, sparkProperties) + // Make this host instead of hostPort ? + executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -63,11 +68,29 @@ private[spark] class StandaloneExecutorBackend( private[spark] object StandaloneExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores)) + } + + // This will be run 'as' the user + def run0(args: Product) { + assert(4 == args.productArity) + runImpl(args.productElement(0).asInstanceOf[String], + args.productElement(0).asInstanceOf[String], + args.productElement(0).asInstanceOf[String], + args.productElement(0).asInstanceOf[Int]) + } + + private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + // Debug code + Utils.checkHost(hostname) + // set it + val sparkHostPort = hostname + ":" + boundPort + System.setProperty("spark.hostPort", sparkHostPort) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)), + Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), name = "Executor") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index d1451bc212..00a0433a44 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,7 +13,7 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging { def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( @@ -32,16 +32,43 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() + + // Read channels typically do not register for write and write does not for read + // Now, we do have write registering for read too (temporarily), but this is to detect + // channel close NOT to actually read/consume data on it ! + // How does this work if/when we move to SSL ? + + // What is the interest to register with selector for when we want this connection to be selected + def registerInterest() + // What is the interest to register with selector for when we want this connection to be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be + // SelectionKey.OP_READ (until we fix it properly) + def unregisterInterest() + + // On receiving a read event, should we change the interest for this channel or not ? + // Will be true for ReceivingConnection, false for SendingConnection. + def changeInterestForRead(): Boolean + + // On receiving a write event, should we change the interest for this channel or not ? + // Will be false for ReceivingConnection, true for SendingConnection. + // Actually, for now, should not get triggered for ReceivingConnection + def changeInterestForWrite(): Boolean + + def getRemoteConnectionManagerId(): ConnectionManagerId = { + socketRemoteConnectionManagerId + } def key() = channel.keyFor(selector) def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - def read() { + // Returns whether we have to register for further reads or not. + def read(): Boolean = { throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) } - - def write() { + + // Returns whether we have to register for further writes or not. + def write(): Boolean = { throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) } @@ -64,7 +91,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (onExceptionCallback != null) { onExceptionCallback(this, e) } else { - logError("Error in connection to " + remoteConnectionManagerId + + logError("Error in connection to " + getRemoteConnectionManagerId() + " and OnExceptionCallback not registered", e) } } @@ -73,7 +100,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (onCloseCallback != null) { onCloseCallback(this) } else { - logWarning("Connection to " + remoteConnectionManagerId + + logWarning("Connection to " + getRemoteConnectionManagerId() + " closed and OnExceptionCallback not registered") } @@ -81,7 +108,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, def changeConnectionKeyInterest(ops: Int) { if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) + onKeyInterestChangeCallback(this, ops) } else { throw new Exception("OnKeyInterestChangeCallback not registered") } @@ -122,7 +149,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]") } } @@ -149,9 +176,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } return chunk } else { - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -170,15 +197,15 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]") + logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") + logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") return chunk } else { message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -187,26 +214,39 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - val outbox = new Outbox(1) + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ - override def getRemoteAddress() = address + override def getRemoteAddress() = address + val DEFAULT_INTEREST = SelectionKey.OP_READ + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(DEFAULT_INTEREST) + } + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) if (channel.isConnected) { - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + registerInterest() } } } + // MUST be called within the selector loop def connect() { try{ - channel.connect(address) channel.register(selector, SelectionKey.OP_CONNECT) + channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { case e: Exception => { @@ -216,20 +256,33 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - def finishConnect() { + def finishConnect(force: Boolean): Boolean = { try { - channel.finishConnect - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + // Typically, this should finish immediately since it was triggered by a connect + // selection - though need not necessarily always complete successfully. + val connected = channel.finishConnect + if (!force && !connected) { + logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + return false + } + + // Fallback to previous behavior - assume finishConnect completed + // This will happen only when finishConnect failed for some repeated number of times (10 or so) + // Is highly unlikely unless there was an unclean close of socket, etc + registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") + return true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) callOnExceptionCallback(e) + // ignore + return true } } } - override def write() { + override def write(): Boolean = { try{ while(true) { if (currentBuffers.size == 0) { @@ -239,8 +292,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers ++= chunk.buffers } case None => { - changeConnectionKeyInterest(SelectionKey.OP_READ) - return + // changeConnectionKeyInterest(0) + /*key.interestOps(0)*/ + return false } } } @@ -254,38 +308,53 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers -= buffer } if (writtenBytes < remainingBytes) { - return + // re-register for write. + return true } } } } catch { case e: Exception => { - logWarning("Error writing in connection to " + remoteConnectionManagerId, e) + logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } - override def read() { + // This is a hack to determine if remote socket was closed or not. + // SendingConnection DOES NOT expect to receive any data - if it does, it is an error + // For a bunch of cases, read will return -1 in case remote socket is closed : hence we + // register for reads to determine that. + override def read(): Boolean = { // We don't expect the other side to send anything; so, we just read to detect an error or EOF. try { val length = channel.read(ByteBuffer.allocate(1)) if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId) + logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => - logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e) + logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() } + + false } + + override def changeInterestForRead(): Boolean = false + + override def changeInterestForWrite(): Boolean = true } +// Must be created within selector loop - else deadlock private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) extends Connection(channel_, selector_) { @@ -298,13 +367,13 @@ extends Connection(channel_, selector_) { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") + logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]") + logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } @@ -316,7 +385,27 @@ extends Connection(channel_, selector_) { messages -= message.id } } - + + @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + override def getRemoteConnectionManagerId(): ConnectionManagerId = { + val currId = inferredRemoteManagerId + if (currId != null) currId else super.getRemoteConnectionManagerId() + } + + // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver. + // We infer that from the messages we receive on the receiver socket. + private def processConnectionManagerId(header: MessageChunkHeader) { + val currId = inferredRemoteManagerId + if (header.address == null || currId != null) return + + val managerId = ConnectionManagerId.fromSocketAddress(header.address) + + if (managerId != null) { + inferredRemoteManagerId = managerId + } + } + + val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) var onReceiveCallback: (Connection , Message) => Unit = null @@ -324,17 +413,18 @@ extends Connection(channel_, selector_) { channel.register(selector, SelectionKey.OP_READ) - override def read() { + override def read(): Boolean = { try { while (true) { if (currentChunk == null) { val headerBytesRead = channel.read(headerBuffer) if (headerBytesRead == -1) { close() - return + return false } if (headerBuffer.remaining > 0) { - return + // re-register for read event ... + return true } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { @@ -342,6 +432,9 @@ extends Connection(channel_, selector_) { } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() + + processConnectionManagerId(header) + header.typ match { case Message.BUFFER_MESSAGE => { if (header.totalSize == 0) { @@ -349,7 +442,8 @@ extends Connection(channel_, selector_) { onReceiveCallback(this, Message.create(header)) } currentChunk = null - return + // re-register for read event ... + return true } else { currentChunk = inbox.getChunk(header).orNull } @@ -362,10 +456,11 @@ extends Connection(channel_, selector_) { val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { - return + // re-register for read event ... + return true } else if (bytesRead == -1) { close() - return + return false } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ @@ -376,7 +471,7 @@ extends Connection(channel_, selector_) { if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -387,12 +482,31 @@ extends Connection(channel_, selector_) { } } catch { case e: Exception => { - logWarning("Error reading from connection to " + remoteConnectionManagerId, e) + logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} + + override def changeInterestForRead(): Boolean = true + + override def changeInterestForWrite(): Boolean = { + throw new IllegalStateException("Unexpected invocation right now") + } + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_READ) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(0) + } } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index b6ec664d7e..0c6bdb1559 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -6,12 +6,12 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ -import java.util.concurrent.Executors +import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} +import scala.collection.mutable.HashSet import scala.collection.mutable.HashMap import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} @@ -19,6 +19,10 @@ import akka.util.Duration import akka.util.duration._ private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + def toSocketAddress() = new InetSocketAddress(host, port) } @@ -42,19 +46,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) - val serverChannel = ServerSocketChannel.open() - val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - val sendMessageRequests = new Queue[(Message, SendingConnection)] + private val selector = SelectorProvider.provider.openSelector() + + private val handleMessageExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.handler.threads.min","20").toInt, + System.getProperty("spark.core.connection.handler.threads.max","60").toInt, + System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val handleReadWriteExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.io.threads.min","4").toInt, + System.getProperty("spark.core.connection.io.threads.max","32").toInt, + System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap + private val handleConnectExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.connect.threads.min","1").toInt, + System.getProperty("spark.core.connection.connect.threads.max","8").toInt, + System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val serverChannel = ServerSocketChannel.open() + private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] + private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + private val messageStatuses = new HashMap[Int, MessageStatus] + private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) - var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) @@ -65,46 +87,139 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - val selectorThread = new Thread("connection-manager-thread") { + + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) selectorThread.start() - private def run() { - try { - while(!selectorThread.isInterrupted) { - for ((connectionManagerId, sendingConnection) <- connectionRequests) { - sendingConnection.connect() - addConnection(sendingConnection) - connectionRequests -= connectionManagerId + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerWrite(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + writeRunnableStarted.synchronized { + // So that we do not trigger more write events while processing this one. + // The write method will re-register when done. + if (conn.changeInterestForWrite()) conn.unregisterInterest() + if (writeRunnableStarted.contains(key)) { + // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) + return + } + + writeRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + if (register && conn.changeInterestForWrite()) { + conn.registerInterest() + } + } } - sendMessageRequests.synchronized { - while (!sendMessageRequests.isEmpty) { - val (message, connection) = sendMessageRequests.dequeue - connection.send(message) + } + } ) + } + + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerRead(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + readRunnableStarted.synchronized { + // So that we do not trigger more read events while processing this one. + // The read method will re-register when done. + if (conn.changeInterestForRead())conn.unregisterInterest() + if (readRunnableStarted.contains(key)) { + return + } + + readRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } + } ) + } + + private def triggerConnect(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] + if (conn == null) return + + // prevent other events from being triggered + // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite + conn.changeConnectionKeyInterest(0) + + handleConnectExecutor.execute(new Runnable { + override def run() { + + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - while (!keyInterestChangeRequests.isEmpty) { + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need not + // succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } + } ) + } + + def run() { + try { + while(!selectorThread.isInterrupted) { + while (! registerRequests.isEmpty) { + val conn: SendingConnection = registerRequests.dequeue + addListeners(conn) + conn.connect() + addConnection(conn) + } + + while(!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue - val connection = connectionsByKey(key) - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + val connection = connectionsByKey.getOrElse(key, null) + if (connection != null) { + val lastOps = key.interestOps() + key.interestOps(ops) + + // hot loop - prevent materialization of string if trace not enabled. + if (isTraceEnabled()) { + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } } - - logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } val selectedKeysCount = selector.select() @@ -123,12 +238,15 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (key.isValid) { if (key.isAcceptable) { acceptConnection(key) - } else if (key.isConnectable) { - connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() - } else if (key.isReadable) { - connectionsByKey(key).read() - } else if (key.isWritable) { - connectionsByKey(key).write() + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) } } } @@ -138,94 +256,116 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - private def acceptConnection(key: SelectionKey) { + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val newChannel = serverChannel.accept() - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - newConnection.onClose(removeConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } - private def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection)) + var newChannel = serverChannel.accept() + + // accept them all in a tight loop. non blocking accept with no processing, should be fine + while (newChannel != null) { + try { + val newConnection = new ReceivingConnection(newChannel, selector) + newConnection.onReceive(receiveMessage) + addListeners(newConnection) + addConnection(newConnection) + logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + } catch { + // might happen in case of issues with registering with selector + case e: Exception => logError("Error in accept loop", e) + } + + newChannel = serverChannel.accept() } + } + + private def addListeners(connection: Connection) { connection.onKeyInterestChange(changeConnectionKeyInterest) connection.onException(handleConnectionError) connection.onClose(removeConnection) } - private def removeConnection(connection: Connection) { + def addConnection(connection: Connection) { + connectionsByKey += ((connection.key, connection)) + } + + def removeConnection(connection: Connection) { connectionsByKey -= connection.key - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + + try { + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId }) + } + } else if (connection.isInstanceOf[ReceivingConnection]) { + val receivingConnection = connection.asInstanceOf[ReceivingConnection] + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) + + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (! sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull - if (sendingConnectionManagerId == null) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId) - - val sendingConnection = connectionsById(sendingConnectionManagerId) - sendingConnection.close() - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() + val sendingConnection = sendingConnectionOpt.get + connectionsById -= remoteConnectionManagerId + sendingConnection.close() + + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + + assert (sendingConnectionManagerId == remoteConnectionManagerId) + + messageStatuses.synchronized { + for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } } - } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } } + } finally { + // So that the selection keys can be removed. + wakeupSelector() } } - private def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) + def handleConnectionError(connection: Connection, e: Exception) { + logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) } - private def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) + def changeConnectionKeyInterest(connection: Connection, ops: Int) { + keyInterestChangeRequests += ((connection.key, ops)) + // so that registerations happen ! + wakeupSelector() } - private def receiveMessage(connection: Connection, message: Message) { + def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { @@ -293,18 +433,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, - new SendingConnection(inetSocketAddress, selector, connectionManagerId)) - newConnection + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + registerRequests.enqueue(newConnection) + + newConnection } - val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) - val connection = connectionsById.getOrElse(lookupKey, startNewConnection()) + // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? + val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - /*connection.send(message)*/ - sendMessageRequests.synchronized { - sendMessageRequests += ((message, connection)) - } + connection.send(message) + + wakeupSelector() + } + + private def wakeupSelector() { selector.wakeup() } @@ -337,6 +481,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logWarning("All connections not cleaned up") } handleMessageExecutor.shutdown() + handleReadWriteExecutor.shutdown() + handleConnectExecutor.shutdown() logInfo("ConnectionManager stopped") } } diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 525751b5bf..34fac9e776 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -17,7 +17,8 @@ private[spark] class MessageChunkHeader( val other: Int, val address: InetSocketAddress) { lazy val buffer = { - val ip = address.getAddress.getAddress() + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection + val ip = address.getAddress.getAddress() val port = address.getPort() ByteBuffer. allocate(MessageChunkHeader.HEADER_SIZE). diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index c54dce51d7..1440b93e65 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -50,6 +50,11 @@ class DAGScheduler( eventQueue.put(ExecutorLost(execId)) } + // Called by TaskScheduler when a host is added + override def executorGained(execId: String, hostPort: String) { + eventQueue.put(ExecutorGained(execId, hostPort)) + } + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) @@ -113,7 +118,7 @@ class DAGScheduler( if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { - locations => locations.map(_.ip).toList + locations => locations.map(_.hostPort).toList }.toArray } cacheLocs(rdd.id) @@ -293,6 +298,9 @@ class DAGScheduler( submitStage(finalStage) } + case ExecutorGained(execId, hostPort) => + handleExecutorGained(execId, hostPort) + case ExecutorLost(execId) => handleExecutorLost(execId) @@ -630,6 +638,14 @@ class DAGScheduler( "(generation " + currentGeneration + ")") } } + + private def handleExecutorGained(execId: String, hostPort: String) { + // remove from failedGeneration(execId) ? + if (failedGeneration.contains(execId)) { + logInfo("Host gained which was in lost list earlier: " + hostPort) + failedGeneration -= execId + } + } /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index ed0b9bf178..b46bb863f0 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -32,6 +32,10 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent +private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent { + Utils.checkHostPort(hostPort, "Required hostport") +} + private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala new file mode 100644 index 0000000000..287f731787 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala @@ -0,0 +1,156 @@ +package spark.scheduler + +import spark.Logging +import scala.collection.immutable.Set +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.conf.Configuration +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ + + +/** + * Parses and holds information about inputFormat (and files) specified as a parameter. + */ +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], + val path: String) extends Logging { + + var mapreduceInputFormat: Boolean = false + var mapredInputFormat: Boolean = false + + validate() + + override def toString(): String = { + "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + path.hashCode + hashCode + } + + // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path + // .. which is fine, this is best case effort to remove duplicates - right ? + override def equals(other: Any): Boolean = other match { + case that: InputFormatInfo => { + // not checking config - that should be fine, right ? + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path + } + case _ => false + } + + private def validate() { + logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) + + try { + if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapreduce package") + mapreduceInputFormat = true + } + else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapred package") + mapredInputFormat = true + } + else { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + + " is NOT a supported input format ? does not implement either of the supported hadoop api's") + } + } + catch { + case e: ClassNotFoundException => { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) + } + } + } + + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { + val conf = new JobConf(configuration) + FileInputFormat.setInputPaths(conf, path) + + val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ + org.apache.hadoop.mapreduce.InputFormat[_, _]] + val job = new Job(conf) + + val retval = new ArrayBuffer[SplitInfo]() + val list = instance.getSplits(job) + for (split <- list) { + retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) + } + + return retval.toSet + } + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { + val jobConf = new JobConf(configuration) + FileInputFormat.setInputPaths(jobConf, path) + + val instance: org.apache.hadoop.mapred.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ + org.apache.hadoop.mapred.InputFormat[_, _]] + + val retval = new ArrayBuffer[SplitInfo]() + instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( + elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) + ) + + return retval.toSet + } + + private def findPreferredLocations(): Set[SplitInfo] = { + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + ", inputFormatClazz : " + inputFormatClazz) + if (mapreduceInputFormat) { + return prefLocsFromMapreduceInputFormat() + } + else { + assert(mapredInputFormat) + return prefLocsFromMapredInputFormat() + } + } +} + + + + +object InputFormatInfo { + /** + Computes the preferred locations based on input(s) and returned a location to block map. + Typical use of this method for allocation would follow some algo like this + (which is what we currently do in YARN branch) : + a) For each host, count number of splits hosted on that host. + b) Decrement the currently allocated containers on that host. + c) Compute rack info for each host and update rack -> count map based on (b). + d) Allocate nodes based on (c) + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single + (or small set of) hosts go down. + + go to (a) until required nodes are allocated. + + If a node 'dies', follow same procedure. + + PS: I know the wording here is weird, hopefully it makes some sense ! + */ + def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { + + val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] + for (inputSplit <- formats) { + val splits = inputSplit.findPreferredLocations() + + for (split <- splits){ + val location = split.hostLocation + val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) + set += split + } + } + + nodeToSplit + } +} diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index beb21a76fe..89dc6640b2 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -70,6 +70,14 @@ private[spark] class ResultTask[T, U]( rdd.partitions(partition) } + // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. + val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + } + override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) metrics = Some(context.taskMetrics) @@ -80,7 +88,7 @@ private[spark] class ResultTask[T, U]( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partition + ")" diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 36d087a4d0..7dc6da4573 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -77,13 +77,21 @@ private[spark] class ShuffleMapTask( var rdd: RDD[_], var dep: ShuffleDependency[_,_], var partition: Int, - @transient var locs: Seq[String]) + @transient private var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { protected def this() = this(0, null, null, 0, null) + // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. + private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + } + var split = if (rdd == null) { null } else { @@ -154,7 +162,7 @@ private[spark] class ShuffleMapTask( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) } diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala new file mode 100644 index 0000000000..6abfb7a1f7 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala @@ -0,0 +1,61 @@ +package spark.scheduler + +import collection.mutable.ArrayBuffer + +// information about a specific split instance : handles both split instances. +// So that we do not need to worry about the differences. +class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, + val length: Long, val underlyingSplit: Any) { + override def toString(): String = { + "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + + ", hostLocation : " + hostLocation + ", path : " + path + + ", length : " + length + ", underlyingSplit " + underlyingSplit + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + hostLocation.hashCode + hashCode = hashCode * 31 + path.hashCode + // ignore overflow ? It is hashcode anyway ! + hashCode = hashCode * 31 + (length & 0x7fffffff).toInt + hashCode + } + + // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // So unless there is identity equality between underlyingSplits, it will always fail even if it + // is pointing to same block. + override def equals(other: Any): Boolean = other match { + case that: SplitInfo => { + this.hostLocation == that.hostLocation && + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path && + this.length == that.length && + // other split specific checks (like start for FileSplit) + this.underlyingSplit == that.underlyingSplit + } + case _ => false + } +} + +object SplitInfo { + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapredSplit.getLength + for (host <- mapredSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) + } + retval + } + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapreduceSplit.getLength + for (host <- mapreduceSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) + } + retval + } +} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala index d549b184b0..7787b54762 100644 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -10,6 +10,10 @@ package spark.scheduler private[spark] trait TaskScheduler { def start(): Unit + // Invoked after system has successfully initialized (typically in spark context). + // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + def postStartHook() { } + // Disconnect from the cluster. def stop(): Unit diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index 771518dddf..b75d3736cf 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit + // A node was added to the cluster. + def executorGained(execId: String, hostPort: String): Unit + // A node was lost from the cluster. def executorLost(execId: String): Unit diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 26fdef101b..2e18d46edc 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -1,6 +1,6 @@ package spark.scheduler.cluster -import java.io.{File, FileInputStream, FileOutputStream} +import java.lang.{Boolean => JBoolean} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -25,6 +25,30 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong + // How often to revive offers in case there are pending tasks - that is how often to try to get + // tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior + // Note that this is required due to delayed scheduling due to data locality waits, etc. + // TODO: rename property ? + val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong + + /* + This property controls how aggressive we should be to modulate waiting for host local task scheduling. + To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for host locality of tasks before + scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order : + host-local, rack-local and then others + But once all available host local (and no pref) tasks are scheduled, instead of waiting for 3 sec before + scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can + modulate that : to also allow rack local nodes or any node. The default is still set to HOST - so that previous behavior is + maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap. + + TODO: rename property ? The value is one of + - HOST_LOCAL (default, no change w.r.t current behavior), + - RACK_LOCAL and + - ANY + + Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective. + */ + val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL")) val activeTaskSets = new HashMap[String, TaskSetManager] var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] @@ -33,9 +57,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] - var hasReceivedTask = false - var hasLaunchedTask = false - val starvationTimer = new Timer(true) + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) @@ -43,11 +67,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Which executor IDs we have executors on val activeExecutorIds = new HashSet[String] + // TODO: We might want to remove this and merge it with execId datastructures - but later. + // Which hosts in the cluster are alive (contains hostPort's) + private val hostPortsAlive = new HashSet[String] + private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] + // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - val executorsByHost = new HashMap[String, HashSet[String]] + val executorsByHostPort = new HashMap[String, HashSet[String]] - val executorIdToHost = new HashMap[String, String] + val executorIdToHostPort = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -75,11 +104,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def start() { backend.start() - if (System.getProperty("spark.speculation", "false") == "true") { + if (JBoolean.getBoolean("spark.speculation")) { new Thread("ClusterScheduler speculation check") { setDaemon(true) override def run() { + logInfo("Starting speculative execution thread") while (true) { try { Thread.sleep(SPECULATION_INTERVAL) @@ -91,6 +121,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } }.start() } + + + // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ? + if (TASK_REVIVAL_INTERVAL > 0) { + new Thread("ClusterScheduler task offer revival check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative task offer revival thread") + while (true) { + try { + Thread.sleep(TASK_REVIVAL_INTERVAL) + } catch { + case e: InterruptedException => {} + } + + if (hasPendingTasks()) backend.reviveOffers() + } + } + }.start() + } } override def submitTasks(taskSet: TaskSet) { @@ -139,22 +190,92 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - executorIdToHost(o.executorId) = o.hostname - if (!executorsByHost.contains(o.hostname)) { - executorsByHost(o.hostname) = new HashSet() + // DEBUG Code + Utils.checkHostPort(o.hostPort) + + executorIdToHostPort(o.executorId) = o.hostPort + if (! executorsByHostPort.contains(o.hostPort)) { + executorsByHostPort(o.hostPort) = new HashSet[String]() } + + hostPortsAlive += o.hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort) + executorGained(o.executorId, o.hostPort) } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray var launchedTask = false + + for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + + // Split offers based on host local, rack local and off-rack tasks. + val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val otherOffers = new HashMap[String, ArrayBuffer[Int]]() + + for (i <- 0 until offers.size) { + val hostPort = offers(i).hostPort + // DEBUG code + Utils.checkHostPort(hostPort) + val host = Utils.parseHostPort(hostPort)._1 + val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i))) + if (numHostLocalTasks > 0){ + val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numHostLocalTasks) list += i + } + + val numRackLocalTasks = math.max(0, + // Remove host local tasks (which are also rack local btw !) from this + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i))) + if (numRackLocalTasks > 0){ + val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numRackLocalTasks) list += i + } + if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){ + // add to others list - spread even this across cluster. + val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + list += i + } + } + + val offersPriorityList = new ArrayBuffer[Int]( + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) + // First host local, then rack, then others + val numHostLocalOffers = { + val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers) + offersPriorityList ++= hostLocalPriorityList + hostLocalPriorityList.size + } + val numRackLocalOffers = { + val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) + offersPriorityList ++= rackLocalPriorityList + rackLocalPriorityList.size + } + offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers) + + var lastLoop = false + val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match { + case TaskLocality.HOST_LOCAL => numHostLocalOffers + case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers + case TaskLocality.ANY => offersPriorityList.size + } + do { launchedTask = false - for (i <- 0 until offers.size) { + var loopCount = 0 + for (i <- offersPriorityList) { val execId = offers(i).executorId - val host = offers(i).hostname - manager.slaveOffer(execId, host, availableCpus(i)) match { + val hostPort = offers(i).hostPort + + // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing) + val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null + + // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ... + loopCount += 1 + + manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match { case Some(task) => tasks(i) += task val tid = task.taskId @@ -162,15 +283,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) taskSetTaskIds(manager.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId - executorsByHost(host) += execId + executorsByHostPort(hostPort) += execId availableCpus(i) -= 1 launchedTask = true - + case None => {} } } + // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of + // data locality (we still go in order of priority : but that would not change anything since + // if data local tasks had been available, we would have scheduled them already) + if (lastLoop) { + // prevent more looping + launchedTask = false + } else if (!lastLoop && !launchedTask) { + // Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL + if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) { + // fudge launchedTask to ensure we loop once more + launchedTask = true + // dont loop anymore + lastLoop = true + } + } } while (launchedTask) } + if (tasks.size > 0) { hasLaunchedTask = true } @@ -256,10 +393,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (jarServer != null) { jarServer.stop() } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) } override def defaultParallelism() = backend.defaultParallelism() + // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false @@ -273,12 +415,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + // Check for pending tasks in all our active jobs. + def hasPendingTasks(): Boolean = { + synchronized { + activeTaskSetsQueue.exists( _.hasPendingTasks() ) + } + } + def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None + synchronized { if (activeExecutorIds.contains(executorId)) { - val host = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + val hostPort = executorIdToHostPort(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) removeExecutor(executorId) failedExecutor = Some(executorId) } else { @@ -296,19 +446,95 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - /** Get a list of hosts that currently have executors */ - def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet - /** Remove an executor from all our data structures and mark it as lost */ private def removeExecutor(executorId: String) { activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) + val hostPort = executorIdToHostPort(executorId) + if (hostPortsAlive.contains(hostPort)) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + hostPortsAlive -= hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort) + } + + val execs = executorsByHostPort.getOrElse(hostPort, new HashSet) execs -= executorId if (execs.isEmpty) { - executorsByHost -= host + executorsByHostPort -= hostPort } - executorIdToHost -= executorId - activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + executorIdToHostPort -= executorId + activeTaskSetsQueue.foreach(_.executorLost(executorId, hostPort)) + } + + def executorGained(execId: String, hostPort: String) { + listener.executorGained(execId, hostPort) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = { + val retval = hostToAliveHostPorts.get(host) + if (retval.isDefined) { + return Some(retval.get.toSet) + } + + None + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None + + // By default, (cached) hosts for rack is unknown + def getCachedHostsForRack(rack: String): Option[Set[String]] = None +} + + +object ClusterScheduler { + + // Used to 'spray' available containers across the available set to ensure too many containers on same host + // are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available + // to execute a task) + // For example: yarn can returns more containers than we would have requested under ANY, this method + // prioritizes how to use the allocated containers. + // flatten the map such that the array buffer entries are spread out across the returned value. + // given == , , , , , i + // the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5 + // We then 'use' the containers in this order (consuming only the top K from this list where + // K = number to be user). This is to ensure that if we have multiple eligible allocations, + // they dont end up allocating all containers on a small number of hosts - increasing probability of + // multiple container failure when a host goes down. + // Note, there is bias for keys with higher number of entries in value to be picked first (by design) + // Also note that invocation of this method is expected to have containers of same 'type' + // (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from + // the available list - everything else being same. + // That is, we we first consume data local, then rack local and finally off rack nodes. So the + // prioritization from this method applies to within each category + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys + + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size + ) + + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true + + while (found){ + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } + + retval.toList } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bb289c9cf3..6b61152ed0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -27,7 +27,7 @@ private[spark] class SparkDeploySchedulerBackend( val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTPORT}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) @@ -57,9 +57,9 @@ private[spark] class SparkDeploySchedulerBackend( } } - override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - executorId, host, cores, Utils.memoryMegabytesToString(memory))) + override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { + logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + executorId, hostPort, cores, Utils.memoryMegabytesToString(memory))) } override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index d766067824..3335294844 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.TaskState.TaskState import java.nio.ByteBuffer import spark.util.SerializableBuffer +import spark.Utils private[spark] sealed trait StandaloneClusterMessage extends Serializable @@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess // Executors to driver private[spark] -case class RegisterExecutor(executorId: String, host: String, cores: Int) - extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + extends StandaloneClusterMessage { + Utils.checkHostPort(hostPort, "Expected host port") +} private[spark] case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7a428e3361..c20276a605 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -5,8 +5,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import akka.actor._ import akka.util.duration._ import akka.pattern.ask +import akka.util.Duration -import spark.{SparkException, Logging, TaskState} +import spark.{Utils, SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} @@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { - val executorActor = new HashMap[String, ActorRef] - val executorAddress = new HashMap[String, Address] - val executorHost = new HashMap[String, String] - val freeCores = new HashMap[String, Int] - val actorToExecutorId = new HashMap[ActorRef, String] - val addressToExecutorId = new HashMap[Address, String] + private val executorActor = new HashMap[String, ActorRef] + private val executorAddress = new HashMap[String, Address] + private val executorHostPort = new HashMap[String, String] + private val freeCores = new HashMap[String, Int] + private val actorToExecutorId = new HashMap[ActorRef, String] + private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterExecutor(executorId, host, cores) => + case RegisterExecutor(executorId, hostPort, cores) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { @@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! RegisteredExecutor(sparkProperties) context.watch(sender) executorActor(executorId) = sender - executorHost(executorId) = host + executorHostPort(executorId) = hostPort freeCores(executorId) = cores executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId @@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))})) } // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers @@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor actorToExecutorId -= executorActor(executorId) addressToExecutorId -= executorAddress(executorId) executorActor -= executorId - executorHost -= executorId + executorHostPort -= executorId freeCores -= executorId - executorHost -= executorId + executorHostPort -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } @@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor while (iterator.hasNext) { val entry = iterator.next val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { + if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { properties += ((key, value)) } } @@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + override def stop() { try { if (driverActor != null) { - val timeout = 5.seconds val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index dfe3c5a85b..718f26bfbd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -1,5 +1,7 @@ package spark.scheduler.cluster +import spark.Utils + /** * Information about a running task attempt inside a TaskSet. */ @@ -9,8 +11,11 @@ class TaskInfo( val index: Int, val launchTime: Long, val executorId: String, - val host: String, - val preferred: Boolean) { + val hostPort: String, + val taskLocality: TaskLocality.TaskLocality) { + + Utils.checkHostPort(hostPort, "Expected hostport") + var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index c9f2c48804..27e713e2c4 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,7 +1,6 @@ package spark.scheduler.cluster -import java.util.Arrays -import java.util.{HashMap => JHashMap} +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -14,6 +13,36 @@ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer +private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + val HOST_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + constraint match { + case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + TaskLocality.withName(str) + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL"); + // default to preserve earlier behavior + HOST_LOCAL + } + } + } +} + /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ @@ -47,14 +76,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis - // List of pending tasks for each node. These collections are actually + // List of pending tasks for each node (hyper local to container). These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put // back at the head of the stack. They are also only cleaned up lazily; // when a task is launched, it remains in all the pending lists except // the one that it was launched from, but gets removed from them later. - val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. + // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] // List containing pending tasks with no locality preferences val pendingTasksWithNoPrefs = new ArrayBuffer[Int] @@ -96,26 +133,117 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe addPendingTask(i) } + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = { + // DEBUG code + _taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations)) + + val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else { + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new ArrayBuffer[String] + scheduler.synchronized { + for (prefLocation <- taskPreferredLocations) { + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation) + if (aliveLocationsOpt.isDefined) { + retval ++= aliveLocationsOpt.get + } + } + } + + retval + } + // Add a task to all the pending-task lists that it should be on. private def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (locations.size == 0) { + // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate + // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + + if (rackLocalLocations.size == 0) { + // Current impl ensures this. + assert (hostLocalLocations.size == 0) pendingTasksWithNoPrefs += index } else { - for (host <- locations) { - val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + + // host locality + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) + hostPortList += index + + val host = Utils.parseHostPort(hostPort)._1 + val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + hostList += index + } + + // rack locality + for (rackLocalHostPort <- rackLocalLocations) { + // DEBUG Code + Utils.checkHostPort(rackLocalHostPort) + + val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 + val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) list += index } } + allPendingTasks += index } + // Return the pending tasks list for a given host port (hyper local), or an empty list if + // there is no map entry for that host + private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { + // DEBUG Code + Utils.checkHostPort(hostPort) + pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) + } + // Return the pending tasks list for a given host, or an empty list if // there is no map entry for that host - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 pendingTasksForHost.getOrElse(host, ArrayBuffer()) } + // Return the pending tasks (rack level) list for a given host, or an empty list if + // there is no map entry for that host + private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Number of pending tasks for a given host (which would be data local) + def numPendingTasksForHost(hostPort: String): Int = { + getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending rack local tasks for a given host + def numRackLocalPendingTasksForHost(hostPort: String): Int = { + getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Dequeue a pending task from the given list and return its index. // Return None if the list is empty. // This method also cleans up any tasks in the list that have already @@ -132,26 +260,49 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the - // task must have a preference for this host (or no preferred locations at all). - private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { - val hostsAlive = sched.hostsAlive + // attempt running on this host, in case the host is slow. In addition, if locality is set, the + // task must have a preference for this host/rack/no preferred locations at all. + private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + + assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL)) speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - val localTask = speculatableTasks.find { - index => - val locations = tasks(index).preferredLocations.toSet & hostsAlive - val attemptLocs = taskAttempts(index).map(_.host) - (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) + + if (speculatableTasks.size > 0) { + val localTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched) + val attemptLocs = taskAttempts(index).map(_.hostPort) + (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) + } + + if (localTask != None) { + speculatableTasks -= localTask.get + return localTask } - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - if (!localOnly && speculatableTasks.size > 0) { - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask + + // check for rack locality + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + val attemptLocs = taskAttempts(index).map(_.hostPort) + locations.contains(hostPort) && !attemptLocs.contains(hostPort) + } + + if (rackTask != None) { + speculatableTasks -= rackTask.get + return rackTask + } + } + + // Any task ... + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + // Check for attemptLocs also ? + val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) + if (nonLocalTask != None) { + speculatableTasks -= nonLocalTask.get + return nonLocalTask + } } } return None @@ -159,59 +310,103 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. - private def findTask(host: String, localOnly: Boolean): Option[Int] = { - val localTask = findTaskFromList(getPendingTasksForHost(host)) + private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) if (localTask != None) { return localTask } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) + if (rackLocalTask != None) { + return rackLocalTask + } + } + + // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. + // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) if (noPrefTask != None) { return noPrefTask } - if (!localOnly) { + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { val nonLocalTask = findTaskFromList(allPendingTasks) if (nonLocalTask != None) { return nonLocalTask } } + // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(host, localOnly) + return findSpeculativeTask(hostPort, locality) } // Does a host count as a preferred location for a task? This is true if // either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). - private def isPreferredLocation(task: Task[_], host: String): Boolean = { + private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = { val locs = task.preferredLocations - return (locs.contains(host) || locs.isEmpty) + // DEBUG code + locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) + + if (locs.contains(hostPort) || locs.isEmpty) return true + + val host = Utils.parseHostPort(hostPort)._1 + locs.contains(host) + } + + // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). + // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + // DEBUG code + locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val time = System.currentTimeMillis - val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... + val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY + } - findTask(host, localOnly) match { + findTask(hostPort, locality) match { case Some(index) => { // Found a task; do some bookkeeping and return a Mesos task for it val task = tasks(index) val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch - val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) { - "preferred" - } else { - "non-preferred, not one of " + task.preferredLocations.mkString(", ") - } - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, prefStr)) + val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, execId, host, preferred) + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) - if (preferred) { + if (TaskLocality.HOST_LOCAL == taskLocality) { lastPreferredLaunchTime = time } // Serialize and return the task @@ -355,17 +550,15 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe sched.taskSetFinished(this) } - def executorLost(execId: String, hostname: String) { + def executorLost(execId: String, hostPort: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - val newHostsAlive = sched.hostsAlive // If some task has preferred locations only on hostname, and there are no more executors there, // put it in the no-prefs list to avoid the wait from delay scheduling - if (!newHostsAlive.contains(hostname)) { - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } + for (index <- getPendingTasksForHostPort(hostPort)) { + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true) + if (newLocs.isEmpty) { + assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty) + pendingTasksWithNoPrefs += index } } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage @@ -419,7 +612,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe !speculatableTasks.contains(index)) { logInfo( "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) + taskSet.id, index, info.hostPort, threshold)) speculatableTasks += index foundTasks = true } @@ -427,4 +620,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } return foundTasks } + + def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 3c3afcbb14..c47824315c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -4,5 +4,5 @@ package spark.scheduler.cluster * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9e1bde3fbe..f060a940a9 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.HashMap import spark._ import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.TaskInfo +import spark.scheduler.cluster.{TaskLocality, TaskInfo} /** * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon def runTask(task: Task[_], idInJob: Int, attemptId: Int) { logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true) + val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e972..10e70723db 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -37,17 +37,27 @@ class BlockManager( maxMemory: Long) extends Logging { - class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - var pending: Boolean = true - var size: Long = -1L - var failed: Boolean = false + private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + @volatile var pending: Boolean = true + @volatile var size: Long = -1L + @volatile var initThread: Thread = null + @volatile var failed = false + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + this.initThread = Thread.currentThread() + } /** * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). * Return true if the block is available, false otherwise. */ def waitForReady(): Boolean = { - if (pending) { + if (initThread != Thread.currentThread() && pending) { synchronized { while (pending) this.wait() } @@ -57,19 +67,26 @@ class BlockManager( /** Mark this BlockInfo as ready (i.e. block is finished writing) */ def markReady(sizeInBytes: Long) { + assert (pending) + size = sizeInBytes + initThread = null + failed = false + initThread = null + pending = false synchronized { - pending = false - failed = false - size = sizeInBytes this.notifyAll() } } /** Mark this BlockInfo as ready but failed */ def markFailure() { + assert (pending) + size = 0 + initThread = null + failed = true + initThread = null + pending = false synchronized { - failed = true - pending = false this.notifyAll() } } @@ -101,7 +118,7 @@ class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties - val host = System.getProperty("spark.hostname", Utils.localHostName()) + val hostPort = Utils.localHostPort() val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -212,9 +229,12 @@ class BlockManager( * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. + * + * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). + * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo) { - val needReregister = !tryToReportBlockStatus(blockId, info) + def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. @@ -228,7 +248,7 @@ class BlockManager( * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -237,7 +257,7 @@ class BlockManager( val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) - val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) } @@ -257,7 +277,7 @@ class BlockManager( def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis var managers = master.getLocations(blockId) - val locations = managers.map(_.ip) + val locations = managers.map(_.hostPort) logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -267,7 +287,7 @@ class BlockManager( */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -339,6 +359,8 @@ class BlockManager( case Some(bytes) => // Put a copy of the block back in memory before returning it. Note that we can't // put the ByteBuffer returned by the disk store as that's a memory-mapped file. + // The use of rewind assumes this. + assert (0 == bytes.position()) val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) @@ -411,6 +433,7 @@ class BlockManager( // Read it as a byte buffer into memory first, then return it diskStore.getBytes(blockId) match { case Some(bytes) => + assert (0 == bytes.position()) if (level.useMemory) { if (level.deserialized) { memoryStore.putBytes(blockId, bytes, level) @@ -450,7 +473,7 @@ class BlockManager( for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port)) + GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { return Some(dataDeserialize(blockId, data)) } @@ -501,17 +524,17 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId).orNull - if (oldBlock != null && oldBlock.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlock.size - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, myInfo) + + if (oldBlockOpt.isDefined && oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return oldBlockOpt.get.size + } val startTimeMs = System.currentTimeMillis @@ -531,6 +554,7 @@ class BlockManager( logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Save it just to memory first, even if it also has useDisk set to true; we will later @@ -555,20 +579,20 @@ class BlockManager( // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(size) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -611,16 +635,17 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.contains(blockId)) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + // Do atomically ! + val prevInfo = blockInfo.putIfAbsent(blockId, myInfo) + if (prevInfo != null) { + // Should we check for prevInfo.waitForReady() here ? + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } val startTimeMs = System.currentTimeMillis @@ -639,6 +664,7 @@ class BlockManager( logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Store it only in memory at first, even if useDisk is also set to true @@ -649,22 +675,24 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } + // assert (0 == bytes.position(), "" + bytes) + // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(bytes.limit) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -698,7 +726,7 @@ class BlockManager( logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.ip, peer.port))) { + new ConnectionManagerId(peer.host, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + @@ -730,6 +758,14 @@ class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { + // required ? As of now, this will be invoked only for blocks which are ready + // But in case this changes in future, adding for consistency sake. + if (! info.waitForReady() ) { + // If we get here, the block write failed. + logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + return + } + val level = info.level if (level.useDisk && !diskStore.contains(blockId)) { logInfo("Writing block " + blockId + " to disk") @@ -740,12 +776,13 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } } + val droppedMemorySize = memoryStore.getSize(blockId) val blockWasRemoved = memoryStore.remove(blockId) if (!blockWasRemoved) { logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") } if (info.tellMaster) { - reportBlockStatus(blockId, info) + reportBlockStatus(blockId, info, droppedMemorySize) } if (!level.useDisk) { // The block is completely gone from this node; forget it so we can put() it again later. @@ -938,8 +975,8 @@ class BlockFetcherIterator( def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) - val cmId = new ConnectionManagerId(req.address.ip, req.address.port) + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) val blockMessageArray = new BlockMessageArray(req.blocks.map { case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) }) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index f2f1e77d41..f4a2181490 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -2,6 +2,7 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +import spark.Utils /** * This class represent an unique identifier for a BlockManager. @@ -13,7 +14,7 @@ import java.util.concurrent.ConcurrentHashMap */ private[spark] class BlockManagerId private ( private var executorId_ : String, - private var ip_ : String, + private var host_ : String, private var port_ : Int ) extends Externalizable { @@ -21,32 +22,45 @@ private[spark] class BlockManagerId private ( def executorId: String = executorId_ - def ip: String = ip_ + if (null != host_){ + Utils.checkHost(host_, "Expected hostname") + assert (port_ > 0) + } + + def hostPort: String = { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + host + ":" + port + } + + def host: String = host_ def port: Int = port_ override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) - out.writeUTF(ip_) + out.writeUTF(host_) out.writeInt(port_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() - ip_ = in.readUTF() + host_ = in.readUTF() port_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port) - override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && ip == id.ip + executorId == id.executorId && port == id.port && host == id.host case _ => false } @@ -55,8 +69,8 @@ private[spark] class BlockManagerId private ( private[spark] object BlockManagerId { - def apply(execId: String, ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) + def apply(execId: String, host: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() @@ -67,11 +81,7 @@ private[spark] object BlockManagerId { val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - if (blockManagerIdCache.containsKey(id)) { - blockManagerIdCache.get(id) - } else { - blockManagerIdCache.put(id, id) - id - } + blockManagerIdCache.putIfAbsent(id, id) + blockManagerIdCache.get(id) } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 2830bc6297..3ce1e6e257 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -332,8 +332,8 @@ object BlockManagerMasterActor { // Mapping from block id to its status. private val _blocks = new JHashMap[String, BlockStatus] - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + logInfo("Registering block manager %s with %s RAM".format( + blockManagerId.hostPort, Utils.memoryBytesToString(maxMem))) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() @@ -358,13 +358,13 @@ object BlockManagerMasterActor { _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. @@ -372,13 +372,13 @@ object BlockManagerMasterActor { _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } } diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala index a25decb123..ee0c5ff9a2 100644 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -115,6 +115,7 @@ private[spark] object BlockMessageArray { val newBuffer = ByteBuffer.allocate(totalSize) newBuffer.clear() bufferMessage.buffers.foreach(buffer => { + assert (0 == buffer.position()) newBuffer.put(buffer) buffer.rewind() }) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..c9553d2e0f 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -20,6 +20,9 @@ import spark.Utils private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { + private val mapMode = MapMode.READ_ONLY + private var mapOpenMode = "r" + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt @@ -35,7 +38,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getFile(blockId).length() } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // So that we do not modify the input offsets ! + // duplicate does not copy buffer, so inexpensive + val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis val file = createFile(blockId) @@ -49,6 +55,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime))) } + private def getFileBytes(file: File): ByteBuffer = { + val length = file.length() + val channel = new RandomAccessFile(file, mapOpenMode).getChannel() + val buffer = try { + channel.map(mapMode, 0, length) + } finally { + channel.close() + } + + buffer + } + override def putValues( blockId: String, values: ArrayBuffer[Any], @@ -70,9 +88,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) if (returnValues) { // Return a byte buffer for the contents of the file - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val buffer = getFileBytes(file) PutResult(length, Right(buffer)) } else { PutResult(length, null) @@ -81,10 +97,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def getBytes(blockId: String): Option[ByteBuffer] = { val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val bytes = getFileBytes(file) Some(bytes) } @@ -96,7 +109,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) if (file.exists()) { file.delete() - true } else { false } @@ -175,11 +187,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir) ) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") try { - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + localDirs.foreach(localDir => if (! Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)) } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 949588476c..eba5ee507f 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // Work on a duplicate - since the original input might be used elsewhere. + val bytes = _bytes.duplicate() bytes.rewind() if (level.deserialized) { val values = blockManager.dataDeserialize(blockId, bytes) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 3b5a77ab22..cc0c354e7e 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -123,11 +123,7 @@ object StorageLevel { val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { - if (storageLevelCache.containsKey(level)) { - storageLevelCache.get(level) - } else { - storageLevelCache.put(level, level) - level - } + storageLevelCache.putIfAbsent(level, level) + storageLevelCache.get(level) } } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 3e805b7831..9fb7e001ba 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -11,7 +11,7 @@ import cc.spray.{SprayCanRootService, HttpService} import cc.spray.can.server.HttpServer import cc.spray.io.pipelines.MessageHandlerDispatch.SingletonHandler import akka.dispatch.Await -import spark.SparkException +import spark.{Utils, SparkException} import java.util.concurrent.TimeoutException /** @@ -31,7 +31,10 @@ private[spark] object AkkaUtils { val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. + val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] @@ -45,8 +48,9 @@ private[spark] object AkkaUtils { akka.remote.netty.execution-pool-size = %d akka.actor.default-dispatcher.throughput = %d akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - if (lifecycleEvents) "on" else "off")) + lifecycleEvents, akkaWriteTimeout)) val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) @@ -60,8 +64,9 @@ private[spark] object AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Returns the bound port or throws a SparkException on failure. + * TODO: Not changing ip to host here - is it required ? */ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, name: String = "HttpServer"): ActorRef = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 188f8910da..92dfaa6e6f 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -3,6 +3,7 @@ package spark.util import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions import scala.collection.mutable.Map +import spark.scheduler.MapStatus /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion @@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { this } + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + override def -= (key: A): this.type = { internalMap.remove(key) this diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index ac51a39a51..b9b9f08810 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -2,7 +2,7 @@ @import spark.deploy.master._ @import spark.Utils -@spark.common.html.layout(title = "Spark Master on " + state.host) { +@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) {
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index c39f769a73..0e66af9284 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,7 +1,7 @@ @(worker: spark.deploy.WorkerState) @import spark.Utils -@spark.common.html.layout(title = "Spark Worker on " + worker.host) { +@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) {
diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html index d54b8de4cc..cd72a688c1 100644 --- a/core/src/main/twirl/spark/storage/worker_table.scala.html +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -12,7 +12,7 @@ @for(status <- workersStatusList) { - @(status.blockManagerId.ip + ":" + status.blockManagerId.port) + @(status.blockManagerId.host + ":" + status.blockManagerId.port) @(Utils.memoryBytesToString(status.memUsed(prefix))) (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 4104b33c8b..c9b4707def 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -153,7 +153,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager blockManager.master.getLocations(blockId).foreach(id => { val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.ip, id.port)) + GetBlock(blockId), ConnectionManagerId(id.host, id.port)) val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) }) diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 91b48c7456..a3840905f4 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -18,6 +18,7 @@ class FileSuite extends FunSuite with LocalSparkContext { val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 4) nums.saveAsTextFile(outputDir) + println("outputDir = " + outputDir) // Read the plain text file and check it's OK val outputFile = new File(outputDir, "part-00000") val content = Source.fromFile(outputFile).mkString diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 6da58a0f6e..c0f8986de8 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } -- cgit v1.2.3 From 54b3d45b816f26a9d3509c1f8bea70c6d99d3de0 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 15 Apr 2013 18:26:50 +0530 Subject: Checkpoint commit - compiles and passes a lot of tests - not all though, looking into FileSuite issues --- core/src/main/scala/spark/HadoopWriter.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala index afcf9f6db4..80421b6328 100644 --- a/core/src/main/scala/spark/HadoopWriter.scala +++ b/core/src/main/scala/spark/HadoopWriter.scala @@ -24,6 +24,8 @@ import spark.SerializableWritable * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable { + + println("Created HadoopWriter") private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -41,6 +43,7 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe @transient private var taskContext: TaskAttemptContext = null def preSetup() { + println("preSetup") setIDs(0, 0, 0) setConfParams() @@ -50,17 +53,20 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe def setup(jobid: Int, splitid: Int, attemptid: Int) { + println("setup") setIDs(jobid, splitid, attemptid) setConfParams() } def open() { + println("open") val numfmt = NumberFormat.getInstance() numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) val outputName = "part-" + numfmt.format(splitID) val path = FileOutputFormat.getOutputPath(conf.value) + println("open outputName = " + outputName + ", fs for " + conf.value) val fs: FileSystem = { if (path != null) { path.getFileSystem(conf.value) @@ -75,6 +81,7 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } def write(key: AnyRef, value: AnyRef) { + println("write " + key + " = " + value) if (writer!=null) { //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") writer.write(key, value) @@ -84,16 +91,19 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } def close() { + println("close") writer.close(Reporter.NULL) } def commit() { + println("commit") val taCtxt = getTaskContext() val cmtr = getOutputCommitter() if (cmtr.needsTaskCommit(taCtxt)) { try { cmtr.commitTask(taCtxt) logInfo (taID + ": Committed") + println("Committed = " + taID) } catch { case e: IOException => { logError("Error committing the output of task: " + taID.value, e) @@ -102,11 +112,13 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } } } else { + println("No need to commit") logWarning ("No need to commit output of task: " + taID.value) } } def cleanup() { + println("cleanup") getOutputCommitter().cleanupJob(getJobContext()) } -- cgit v1.2.3 From 19652a44be81f3b8fbbb9ecc4987dcd933d2eca9 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 15 Apr 2013 19:16:36 +0530 Subject: Fix issue with FileSuite failing --- core/src/main/scala/spark/HadoopWriter.scala | 22 ++++++---------------- core/src/main/scala/spark/PairRDDFunctions.scala | 1 + core/src/test/scala/spark/FileSuite.scala | 1 - 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala index 80421b6328..5e8396edb9 100644 --- a/core/src/main/scala/spark/HadoopWriter.scala +++ b/core/src/main/scala/spark/HadoopWriter.scala @@ -2,14 +2,10 @@ package org.apache.hadoop.mapred import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text import java.text.SimpleDateFormat import java.text.NumberFormat import java.io.IOException -import java.net.URI import java.util.Date import spark.Logging @@ -25,8 +21,6 @@ import spark.SerializableWritable */ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable { - println("Created HadoopWriter") - private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -43,7 +37,6 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe @transient private var taskContext: TaskAttemptContext = null def preSetup() { - println("preSetup") setIDs(0, 0, 0) setConfParams() @@ -53,20 +46,17 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe def setup(jobid: Int, splitid: Int, attemptid: Int) { - println("setup") setIDs(jobid, splitid, attemptid) setConfParams() } def open() { - println("open") val numfmt = NumberFormat.getInstance() numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) val outputName = "part-" + numfmt.format(splitID) val path = FileOutputFormat.getOutputPath(conf.value) - println("open outputName = " + outputName + ", fs for " + conf.value) val fs: FileSystem = { if (path != null) { path.getFileSystem(conf.value) @@ -81,7 +71,6 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } def write(key: AnyRef, value: AnyRef) { - println("write " + key + " = " + value) if (writer!=null) { //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") writer.write(key, value) @@ -91,19 +80,16 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } def close() { - println("close") writer.close(Reporter.NULL) } def commit() { - println("commit") val taCtxt = getTaskContext() val cmtr = getOutputCommitter() if (cmtr.needsTaskCommit(taCtxt)) { try { cmtr.commitTask(taCtxt) logInfo (taID + ": Committed") - println("Committed = " + taID) } catch { case e: IOException => { logError("Error committing the output of task: " + taID.value, e) @@ -112,13 +98,17 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } } } else { - println("No need to commit") logWarning ("No need to commit output of task: " + taID.value) } } + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? + val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + def cleanup() { - println("cleanup") getOutputCommitter().cleanupJob(getJobContext()) } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 39469fa3c8..9a6966b3f1 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -636,6 +636,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } self.context.runJob(self, writeToFile _) + writer.commitJob() writer.cleanup() } diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index a3840905f4..91b48c7456 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -18,7 +18,6 @@ class FileSuite extends FunSuite with LocalSparkContext { val outputDir = new File(tempDir, "output").getAbsolutePath val nums = sc.makeRDD(1 to 4) nums.saveAsTextFile(outputDir) - println("outputDir = " + outputDir) // Read the plain text file and check it's OK val outputFile = new File(outputDir, "part-00000") val content = Source.fromFile(outputFile).mkString -- cgit v1.2.3 From b42d68c8ce9f63513969297b65f4b5a2b06e6078 Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 15 Apr 2013 12:54:55 -0600 Subject: fixing Spark Streaming count() so that 0 will be emitted when there is nothing to count --- streaming/src/main/scala/spark/streaming/DStream.scala | 5 ++++- streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e1be5ef51c..e3a9247924 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,7 +441,10 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) + def count(): DStream[Long] = { + val zero = new ConstantInputDStream(context, context.sparkContext.makeRDD(Seq((null, 0L)), 1)) + this.map(_ => (null, 1L)).union(zero).reduceByKey(_ + _).map(_._2) + } /** * Return a new DStream in which each RDD contains the counts of each distinct value in diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index 8fce91853c..168e1b7a55 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -90,9 +90,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("count") { testOperation( - Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4), + Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4), (s: DStream[Int]) => s.count(), - Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L)) + Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L)) ) } -- cgit v1.2.3 From eb7e95e833376904bea4a9e6d1cc67c00fcfb06c Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 02:56:36 +0530 Subject: Commit job to persist files --- core/src/main/scala/spark/PairRDDFunctions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 9a6966b3f1..67fd1c1a8f 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -569,6 +569,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) val count = self.context.runJob(self, writeShard _).sum + jobCommitter.commitJob(jobTaskContext) jobCommitter.cleanupJob(jobTaskContext) } -- cgit v1.2.3 From 5540ab8243a8488e30a21e1d4bb1720f1a9a555f Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 02:57:43 +0530 Subject: Use hostname instead of hostport for executor, fix creation of workdir --- core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- core/src/main/scala/spark/deploy/worker/Worker.scala | 3 ++- .../main/scala/spark/executor/StandaloneExecutorBackend.scala | 11 ++++++----- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index dfcb9f0d05..04a774658e 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -70,7 +70,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTPORT}}" => hostPort + case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1 case "{{CORES}}" => cores.toString case other => other } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index cf4babc892..1a7da0f7bf 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -54,10 +54,11 @@ private[spark] class Worker( def createWorkDir() { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { - if (!workDir.exists() && !workDir.mkdirs()) { + if ( (workDir.exists() && !workDir.isDirectory) || (!workDir.exists() && !workDir.mkdirs()) ) { logError("Failed to create work directory " + workDir) System.exit(1) } + assert (workDir.isDirectory) } catch { case e: Exception => logError("Failed to create work directory " + workDir, e) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 49e1f3f07a..ebe2ac68d8 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -75,17 +75,18 @@ private[spark] object StandaloneExecutorBackend { def run0(args: Product) { assert(4 == args.productArity) runImpl(args.productElement(0).asInstanceOf[String], - args.productElement(0).asInstanceOf[String], - args.productElement(0).asInstanceOf[String], - args.productElement(0).asInstanceOf[Int]) + args.productElement(1).asInstanceOf[String], + args.productElement(2).asInstanceOf[String], + args.productElement(3).asInstanceOf[Int]) } private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) { + // Debug code + Utils.checkHost(hostname) + // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) - // Debug code - Utils.checkHost(hostname) // set it val sparkHostPort = hostname + ":" + boundPort System.setProperty("spark.hostPort", sparkHostPort) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 6b61152ed0..0b8922d139 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -27,7 +27,7 @@ private[spark] class SparkDeploySchedulerBackend( val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTPORT}}", "{{CORES}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) -- cgit v1.2.3 From dd2b64ec97ad241b6f171cac0dbb1841b185675a Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 03:19:24 +0530 Subject: Fix bug with atomic update --- .../main/scala/spark/storage/BlockManager.scala | 44 +++++++++++++++------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 10e70723db..483b6de34b 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -527,13 +527,22 @@ class BlockManager( // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, myInfo) + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return oldBlockOpt.get.size + } - if (oldBlockOpt.isDefined && oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlockOpt.get.size + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } } val startTimeMs = System.currentTimeMillis @@ -638,13 +647,22 @@ class BlockManager( // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val prevInfo = blockInfo.putIfAbsent(blockId, myInfo) - if (prevInfo != null) { - // Should we check for prevInfo.waitForReady() here ? - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } } val startTimeMs = System.currentTimeMillis -- cgit v1.2.3 From 59c380d69a3831f0239b434a0fa1cf26a481d222 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 03:29:38 +0530 Subject: Fix npe --- core/src/main/scala/spark/storage/BlockManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 483b6de34b..c98ee5a0e7 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -794,7 +794,7 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } } - val droppedMemorySize = memoryStore.getSize(blockId) + val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val blockWasRemoved = memoryStore.remove(blockId) if (!blockWasRemoved) { logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") -- cgit v1.2.3 From b493f55a4fe43c83061a361eef029edbac50c006 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Tue, 16 Apr 2013 10:00:33 +0800 Subject: fix a bug in netty Block Fetcher Signed-off-by: shane-huang --- .../main/java/spark/network/netty/FileServer.java | 1 - .../main/scala/spark/storage/BlockManager.scala | 69 +++++++++++----------- core/src/main/scala/spark/storage/DiskStore.scala | 4 +- 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java index 729e45f0a1..38af305096 100644 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -51,7 +51,6 @@ public class FileServer { } if (bootstrap != null){ bootstrap.shutdown(); - bootstrap = null; } } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index b8b68d4283..5a00180922 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -469,21 +469,6 @@ class BlockManager( getLocal(blockId).orElse(getRemote(blockId)) } - /** - * A request to fetch one or more blocks, complete with their sizes - */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { - val size = blocks.map(_._2).sum - } - - /** - * A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - * the block (since we want all deserializaton to happen in the calling thread); can also - * represent a fetch failure if size == -1. - */ - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } /** * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined @@ -494,9 +479,9 @@ class BlockManager( : BlockFetcherIterator = { if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){ - return new NettyBlockFetcherIterator(this, blocksByAddress) + return BlockFetcherIterator("netty",this, blocksByAddress) } else { - return new BlockFetcherIterator(this, blocksByAddress) + return BlockFetcherIterator("", this, blocksByAddress) } } @@ -916,10 +901,29 @@ object BlockManager extends Logging { } } -class BlockFetcherIterator( + +trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { + def initialize +} + +object BlockFetcherIterator { + + // A request to fetch one or more blocks, complete with their sizes + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } + +class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] -) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { +) extends BlockFetcherIterator { import blockManager._ @@ -936,21 +940,9 @@ class BlockFetcherIterator( val localBlockIds = new ArrayBuffer[String]() val remoteBlockIds = new HashSet[String]() - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - // A queue to hold our results. val results = new LinkedBlockingQueue[FetchResult] - // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { - val size = blocks.map(_._2).sum - } - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight val fetchRequests = new Queue[FetchRequest] @@ -1072,7 +1064,6 @@ class BlockFetcherIterator( } - initialize() //an iterator that will read fetched blocks off the queue as they arrive. var resultsGotten = 0 @@ -1107,7 +1098,7 @@ class BlockFetcherIterator( class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] -) extends BlockFetcherIterator(blockManager,blocksByAddress) { +) extends BasicBlockFetcherIterator(blockManager,blocksByAddress) { import blockManager._ @@ -1129,7 +1120,7 @@ class NettyBlockFetcherIterator( } } catch { case x: InterruptedException => logInfo("Copier Interrupted") - case _ => throw new SparkException("Exception Throw in Shuffle Copier") + //case _ => throw new SparkException("Exception Throw in Shuffle Copier") } } } @@ -1232,3 +1223,13 @@ class NettyBlockFetcherIterator( } } + def apply(t: String, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]): BlockFetcherIterator = { + val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress) } + else { new BasicBlockFetcherIterator(blockManager,blocksByAddress) } + iter.initialize + iter + } + +} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index d702bb23e0..cc5bf29a32 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -39,7 +39,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() if(useNetty){ - startShuffleBlockSender() + startShuffleBlockSender() } override def getSize(blockId: String): Long = { @@ -229,7 +229,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) case e: Exception => { logError("Error running ShuffleBlockSender ", e) if (shuffleSender != null) { - shuffleSender.stop + shuffleSender.stop shuffleSender = null } } -- cgit v1.2.3 From 323ab8ff3b822af28276e1460db0f9c73d6d6409 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 17:05:10 +0530 Subject: Scala does not prevent variable shadowing ! Sick error due to it ... --- core/src/main/scala/spark/MapOutputTracker.scala | 1 - core/src/main/scala/spark/storage/BlockManager.scala | 4 ++-- pom.xml | 3 ++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 6e9da02893..fde597ffd1 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -148,7 +148,6 @@ private[spark] class MapOutputTracker extends Logging { logInfo("Doing the fetch; tracker actor = " + trackerActor) val hostPort = Utils.localHostPort() // This try-finally prevents hangs due to timeouts: - var fetchedStatuses: Array[MapStatus] = null try { val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index c98ee5a0e7..6e861ac734 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -1092,7 +1092,7 @@ class BlockFetcherIterator( logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") //an iterator that will read fetched blocks off the queue as they arrive. - var resultsGotten = 0 + @volatile private var resultsGotten = 0 def hasNext: Boolean = resultsGotten < totalBlocks @@ -1102,7 +1102,7 @@ class BlockFetcherIterator( val result = results.take() val stopFetchWait = System.currentTimeMillis() _fetchWaitTime += (stopFetchWait - startFetchWait) - bytesInFlight -= result.size + if (! result.failed) bytesInFlight -= result.size while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) diff --git a/pom.xml b/pom.xml index b3134a957d..c0ba6b9b0e 100644 --- a/pom.xml +++ b/pom.xml @@ -417,8 +417,9 @@ ${project.build.directory}/surefire-reports . - WDF TestSuite.txt + ${project.build.directory}/SparkTestSuite.txt -Xms64m -Xmx1024m + -- cgit v1.2.3 From f7969f72eeaba0dc127efd13e983791f6f7930c5 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 21:51:38 +0530 Subject: Fix exception when checkpoint path does not exist (no data in rdd which is being checkpointed for example) --- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 26 +++++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 9e37bdf659..1026dc54e0 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -21,13 +21,20 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) override def getPartitions: Array[Partition] = { - val dirContents = fs.listStatus(new Path(checkpointPath)) - val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted - val numPartitions = partitionFiles.size - if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } + val cpath = new Path(checkpointPath) + val numPartitions = + // listStatus can throw exception if path does not exist. + if (fs.exists(cpath)) { + val dirContents = fs.listStatus(cpath) + val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numPart = partitionFiles.size + if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + numPart + } else 0 + Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } @@ -64,6 +71,8 @@ private[spark] object CheckpointRDD extends Logging { val finalOutputPath = new Path(outputDir, finalOutputName) val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) + println("writeToFile. path = " + path + ", tempOutputPath = " + tempOutputPath + ", finalOutputPath = " + finalOutputPath) + if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + tempOutputPath + " already exists") @@ -81,8 +90,11 @@ private[spark] object CheckpointRDD extends Logging { serializeStream.writeAll(iterator) serializeStream.close() + println("writeToFile. serializeStream.close ... renaming from " + tempOutputPath + " to " + finalOutputPath) + if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { + logInfo("Deleting tempOutputPath " + tempOutputPath) fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " + ctx.attemptId + " and final output path does not exist") -- cgit v1.2.3 From ad80f68eb5d153d7f666447966755efce186d021 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 16 Apr 2013 22:15:34 +0530 Subject: remove spurious debug statements --- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 1026dc54e0..24d527f38f 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -71,8 +71,6 @@ private[spark] object CheckpointRDD extends Logging { val finalOutputPath = new Path(outputDir, finalOutputName) val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId) - println("writeToFile. path = " + path + ", tempOutputPath = " + tempOutputPath + ", finalOutputPath = " + finalOutputPath) - if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + tempOutputPath + " already exists") @@ -90,8 +88,6 @@ private[spark] object CheckpointRDD extends Logging { serializeStream.writeAll(iterator) serializeStream.close() - println("writeToFile. serializeStream.close ... renaming from " + tempOutputPath + " to " + finalOutputPath) - if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { logInfo("Deleting tempOutputPath " + tempOutputPath) -- cgit v1.2.3 From ab0f834dbb509d323577572691293b74368a9d86 Mon Sep 17 00:00:00 2001 From: seanm Date: Tue, 16 Apr 2013 11:57:05 -0600 Subject: adding spark.streaming.blockInterval property --- docs/configuration.md | 7 +++++++ .../main/scala/spark/streaming/dstream/NetworkInputDStream.scala | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index 04eb6daaa5..55f1962b18 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -253,6 +253,13 @@ Apart from these, the following properties are also available, and may be useful applications). Note that any RDD that persists in memory for more than this duration will be cleared as well. + + spark.streaming.blockInterval + 200 + + Duration (milliseconds) of how long to batch new objects coming from network receivers. + + diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 7385474963..26805e9621 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -198,7 +198,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log case class Block(id: String, iterator: Iterator[T], metadata: Any = null) val clock = new SystemClock() - val blockInterval = 200L + val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) -- cgit v1.2.3 From bcdde331c3ed68af27bc5d6067c78f68dbd6b032 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 04:12:18 +0530 Subject: Move from master to driver --- .../scala/spark/deploy/yarn/ApplicationMaster.scala | 20 ++++++++++---------- .../spark/deploy/yarn/YarnAllocationHandler.scala | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala index 65361e0ed9..ae719267e8 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala @@ -76,7 +76,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Start the user's JAR userThread = startUserClass() - // This a bit hacky, but we need to wait until the spark.master.port property has + // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. waitForSparkMaster() @@ -124,19 +124,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } private def waitForSparkMaster() { - logInfo("Waiting for spark master to be reachable.") - var masterUp = false - while(!masterUp) { - val masterHost = System.getProperty("spark.master.host") - val masterPort = System.getProperty("spark.master.port") + logInfo("Waiting for spark driver to be reachable.") + var driverUp = false + while(!driverUp) { + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") try { - val socket = new Socket(masterHost, masterPort.toInt) + val socket = new Socket(driverHost, driverPort.toInt) socket.close() - logInfo("Master now available: " + masterHost + ":" + masterPort) - masterUp = true + logInfo("Master now available: " + driverHost + ":" + driverPort) + driverUp = true } catch { case e: Exception => - logError("Failed to connect to master at " + masterHost + ":" + masterPort) + logError("Failed to connect to driver at " + driverHost + ":" + driverPort) Thread.sleep(100) } } diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala index cac9dab401..61dd72a651 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala @@ -191,8 +191,8 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM else { // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter) val workerId = workerIdCounter.incrementAndGet().toString - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) logInfo("launching container on " + containerId + " host " + workerHostname) @@ -209,7 +209,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM } new Thread( - new WorkerRunnable(container, conf, masterUrl, workerId, + new WorkerRunnable(container, conf, driverUrl, workerId, workerHostname, workerMemory, workerCores) ).start() } -- cgit v1.2.3 From 7e56e99573b4cf161293e648aeb159375c9c0fcb Mon Sep 17 00:00:00 2001 From: seanm Date: Sun, 24 Mar 2013 13:40:19 -0600 Subject: Surfacing decoders on KafkaInputDStream --- .../spark/streaming/examples/KafkaWordCount.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 11 ++++---- .../streaming/api/java/JavaStreamingContext.scala | 33 ++++++++++++++++------ .../streaming/dstream/KafkaInputDStream.scala | 17 ++++++----- .../test/java/spark/streaming/JavaAPISuite.java | 6 ++-- 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index 9b135a5c54..e0c3555f21 100644 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -37,7 +37,7 @@ object KafkaWordCount { ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap) + val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) wordCounts.print() diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index bb7f216ca7..2c6326943d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import java.util.UUID import twitter4j.Status + /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic * information (such as, cluster URL and job name) to internally create a SparkContext, it provides @@ -207,14 +208,14 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ - def kafkaStream[T: ClassManifest]( + def kafkaStream( zkQuorum: String, groupId: String, topics: Map[String, Int], storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 - ): DStream[T] = { + ): DStream[String] = { val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); - kafkaStream[T](kafkaParams, topics, storageLevel) + kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) } /** @@ -224,12 +225,12 @@ class StreamingContext private ( * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ - def kafkaStream[T: ClassManifest]( + def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, kafkaParams, topics, storageLevel) + val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 7a8864614c..13427873ff 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -68,33 +68,50 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. */ - def kafkaStream[T]( + def kafkaStream( zkQuorum: String, groupId: String, topics: JMap[String, JInt]) - : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2) } /** * Create an input stream that pulls messages form a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * @param storageLevel RDD storage level. Defaults to memory-only + * in its own thread. + */ + def kafkaStream( + zkQuorum: String, + groupId: String, + topics: JMap[String, JInt], + storageLevel: StorageLevel) + : JavaDStream[String] = { + implicit val cmt: ClassManifest[String] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + } + + /** + * Create an input stream that pulls messages form a Kafka Broker. + * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only */ - def kafkaStream[T]( + def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T]( + ssc.kafkaStream[T, D]( kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 17a5be3420..7bd53fb6dd 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -9,7 +9,7 @@ import java.util.concurrent.Executors import kafka.consumer._ import kafka.message.{Message, MessageSet, MessageAndMetadata} -import kafka.serializer.StringDecoder +import kafka.serializer.Decoder import kafka.utils.{Utils, ZKGroupTopicDirs} import kafka.utils.ZkUtils._ import kafka.utils.ZKStringSerializer @@ -28,7 +28,7 @@ import scala.collection.JavaConversions._ * @param storageLevel RDD storage level. */ private[streaming] -class KafkaInputDStream[T: ClassManifest]( +class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest]( @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], @@ -37,15 +37,17 @@ class KafkaInputDStream[T: ClassManifest]( def getReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(kafkaParams, topics, storageLevel) + new KafkaReceiver[T, D](kafkaParams, topics, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(kafkaParams: Map[String, String], +class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest]( + kafkaParams: Map[String, String], topics: Map[String, Int], - storageLevel: StorageLevel) extends NetworkReceiver[Any] { + storageLevel: StorageLevel + ) extends NetworkReceiver[Any] { // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) @@ -82,7 +84,8 @@ class KafkaReceiver(kafkaParams: Map[String, String], } // Create Threads for each Topic/Message Stream we are listening - val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]] + val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder) // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -91,7 +94,7 @@ class KafkaReceiver(kafkaParams: Map[String, String], } // Handles Kafka Messages - private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable { def run() { logInfo("Starting MessageHandler.") for (msgAndMetadata <- stream) { diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 3bed500f73..61e4c0a207 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -23,7 +23,6 @@ import spark.streaming.api.java.JavaPairDStream; import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; import spark.streaming.JavaCheckpointTestUtils; -import spark.streaming.dstream.KafkaPartitionKey; import spark.streaming.InputStreamsSuite; import java.io.*; @@ -1203,10 +1202,9 @@ public class JavaAPISuite implements Serializable { @Test public void testKafkaStream() { HashMap topics = Maps.newHashMap(); - HashMap offsets = Maps.newHashMap(); JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets, + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics); + JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); } -- cgit v1.2.3 From a402b23bcd9a9470c5fa38bf46f150b51d43eac9 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 05:52:00 +0530 Subject: Fudge order of classpath - so that our jars take precedence over what is in CLASSPATH variable. Sounds logical, hope there is no issue cos of it --- core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala | 2 +- core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala index 7fa6740579..c007dae98c 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala @@ -165,8 +165,8 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging { // If log4j present, ensure ours overrides all others if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") - Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") Client.populateHadoopClasspath(yarnConf, env) SparkHadoopUtil.setYarnMode(env) env("SPARK_YARN_JAR_PATH") = diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala index 5688f1ab66..a2bf0af762 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala @@ -153,8 +153,8 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") } - Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") Client.populateHadoopClasspath(yarnConf, env) System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } -- cgit v1.2.3 From 02dffd2eb0f5961a0e0ad93a136a086c36670b76 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 05:52:57 +0530 Subject: Ensure all ask/await block for spark.akka.askTimeout - so that it is controllable : instead of arbitrary timeouts spread across codebase. In our tests, we use 30 seconds, though default of 10 is maintained --- core/src/main/scala/spark/deploy/client/Client.scala | 3 ++- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 4 ++-- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 4 ++-- .../scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 1 - core/src/main/scala/spark/storage/BlockManagerMaster.scala | 2 +- core/src/main/scala/spark/storage/BlockManagerUI.scala | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index 072232e33a..4af44f9c16 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -3,6 +3,7 @@ package spark.deploy.client import spark.deploy._ import akka.actor._ import akka.pattern.ask +import akka.util.Duration import akka.util.duration._ import akka.pattern.AskTimeoutException import spark.{SparkException, Logging} @@ -112,7 +113,7 @@ private[spark] class Client( def stop() { if (actor != null) { try { - val timeout = 5.seconds + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") val future = actor.ask(StopClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 54faa375fb..a4e21c8130 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -3,7 +3,7 @@ package spark.deploy.master import akka.actor.{ActorRef, ActorSystem} import akka.dispatch.Await import akka.pattern.ask -import akka.util.Timeout +import akka.util.{Duration, Timeout} import akka.util.duration._ import cc.spray.Directives import cc.spray.directives._ @@ -22,7 +22,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) val handler = { get { diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index c834f87d50..3235c50d1b 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -3,7 +3,7 @@ package spark.deploy.worker import akka.actor.{ActorRef, ActorSystem} import akka.dispatch.Await import akka.pattern.ask -import akka.util.Timeout +import akka.util.{Duration, Timeout} import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ @@ -22,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef, workDir: File) val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) val handler = { get { diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index c20276a605..004592a540 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -162,7 +162,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val timeout = 5.seconds val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 036fdc3480..6fae62d373 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -22,7 +22,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - val timeout = 10.seconds + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") /** Remove a dead executor from the driver actor. This is only called on the driver side. */ def removeExecutor(execId: String) { diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 9e6721ec17..07da572044 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,7 +1,7 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.util.Timeout +import akka.util.Duration import akka.util.duration._ import cc.spray.typeconversion.TwirlSupport._ import cc.spray.Directives @@ -19,7 +19,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") /** Start a HTTP server to run the Web interface */ def start() { -- cgit v1.2.3 From 46779b4745dcd9cbfa6f48cd906d4a9c32fa83e2 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 05:53:28 +0530 Subject: Move back to 2.0.2-alpha, since 2.0.3-alpha is not available in cloudera yet --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c0ba6b9b0e..ecbfaf9b47 100644 --- a/pom.xml +++ b/pom.xml @@ -564,7 +564,7 @@ hadoop2-yarn 2 - 2.0.3-alpha + 2.0.2-alpha -- cgit v1.2.3 From 5d891534fd5ca268f6ba7c9a47680846eb3a15ae Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 05:54:43 +0530 Subject: Move back to 2.0.2-alpha, since 2.0.3-alpha is not available in cloudera yet. Also, add netty dependency explicitly to prevent resolving to older 2.3x version. Additionally, comment out retrievePattern to ensure correct netty is picked up --- project/SparkBuild.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f041930b4e..91e3123bc5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -20,7 +20,7 @@ object SparkBuild extends Build { //val HADOOP_YARN = false // For Hadoop 2 YARN support - val HADOOP_VERSION = "2.0.3-alpha" + val HADOOP_VERSION = "2.0.2-alpha" val HADOOP_MAJOR_VERSION = "2" val HADOOP_YARN = true @@ -47,9 +47,10 @@ object SparkBuild extends Build { scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"), unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + // retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), - testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), + // For some reason this fails on some nodes and works on others - not yet debugged why + // testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), // shared between both core and streaming. resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), @@ -99,6 +100,7 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( + "io.netty" % "netty" % "3.5.3.Final", "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", "org.scalatest" %% "scalatest" % "1.8" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", @@ -131,11 +133,13 @@ object SparkBuild extends Build { ), libraryDependencies ++= Seq( + "io.netty" % "netty" % "3.5.3.Final", "com.google.guava" % "guava" % "11.0.1", "log4j" % "log4j" % "1.2.16", "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "com.ning" % "compress-lzf" % "0.8.4", + "commons-daemon" % "commons-daemon" % "1.0.10", "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", -- cgit v1.2.3 From f07961060d8d9dd85ab2a581adc45f886bb0e629 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 23:13:02 +0530 Subject: Add a small note on spark.tasks.schedule.aggression --- core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 2e18d46edc..a9d9c5e44c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -47,6 +47,11 @@ private[spark] class ClusterScheduler(val sc: SparkContext) - ANY Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective. + + Additional Note: For non trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments) based on whether + it is left at default HOST_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY. + If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact. + Also, it brings down the variance in running time drastically. */ val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL")) @@ -68,7 +73,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val activeExecutorIds = new HashSet[String] // TODO: We might want to remove this and merge it with execId datastructures - but later. - // Which hosts in the cluster are alive (contains hostPort's) + // Which hosts in the cluster are alive (contains hostPort's) - used for hyper local and local task locality. private val hostPortsAlive = new HashSet[String] private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] -- cgit v1.2.3 From 5ee2f5c4837f0098282d93c85e606e1a3af40dd6 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 17 Apr 2013 23:13:34 +0530 Subject: Cache pattern, add (commented out) alternatives for check* apis --- core/src/main/scala/spark/Utils.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 14bb153d54..3e54fa7a7e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -13,6 +13,7 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import spark.serializer.SerializerInstance import spark.deploy.SparkHadoopUtil +import java.util.regex.Pattern /** * Various utility methods used by Spark. @@ -337,9 +338,11 @@ private object Utils extends Logging { } // Used by DEBUG code : remove when all testing done + private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") def checkHost(host: String, message: String = "") { // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous ! - if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { + // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { + if (ipPattern.matcher(host).matches()) { Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message) } if (Utils.parseHostPort(host)._2 != 0){ @@ -356,6 +359,12 @@ private object Utils extends Logging { } } + // Once testing is complete in various modes, replace with this ? + /* + def checkHost(host: String, message: String = "") {} + def checkHostPort(hostPort: String, message: String = "") {} + */ + def getUserNameFromEnvironment(): String = { SparkHadoopUtil.getUserNameFromEnvironment } -- cgit v1.2.3 From e0603d7e8bfa991dfd5dc43b303c23a47aa70bca Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Thu, 18 Apr 2013 13:13:54 +0800 Subject: refactor the Schedulable interface and add unit test for SchedulingAlgorithm --- core/src/main/scala/spark/SparkContext.scala | 12 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 77 ++++---- .../cluster/FIFOTaskSetQueuesManager.scala | 49 ----- .../cluster/FairTaskSetQueuesManager.scala | 157 ---------------- .../main/scala/spark/scheduler/cluster/Pool.scala | 96 ++++++---- .../spark/scheduler/cluster/Schedulable.scala | 23 ++- .../scheduler/cluster/SchedulableBuilder.scala | 115 ++++++++++++ .../scheduler/cluster/SchedulingAlgorithm.scala | 33 ++-- .../spark/scheduler/cluster/TaskSetManager.scala | 58 ++++-- core/src/test/resources/fairscheduler.xml | 14 ++ .../spark/scheduler/ClusterSchedulerSuite.scala | 207 +++++++++++++++++++++ 11 files changed, 525 insertions(+), 316 deletions(-) delete mode 100644 core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala create mode 100644 core/src/test/resources/fairscheduler.xml create mode 100644 core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7c96ae637b..5d9a0357ad 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -146,9 +146,7 @@ class SparkContext( case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")). - newInstance().asInstanceOf[TaskSetQueuesManager] - scheduler.initialize(backend, taskSetQueuesManager) + scheduler.initialize(backend) scheduler case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => @@ -167,9 +165,7 @@ class SparkContext( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")). - newInstance().asInstanceOf[TaskSetQueuesManager] - scheduler.initialize(backend, taskSetQueuesManager) + scheduler.initialize(backend) backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() } @@ -188,9 +184,7 @@ class SparkContext( } else { new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) } - val taskSetQueuesManager = Class.forName(System.getProperty("spark.cluster.taskscheduler")). - newInstance().asInstanceOf[TaskSetQueuesManager] - scheduler.initialize(backend, taskSetQueuesManager) + scheduler.initialize(backend) scheduler } } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 2ddac0ff30..1a300c9e8c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -61,17 +61,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val mapOutputTracker = SparkEnv.get.mapOutputTracker - var taskSetQueuesManager: TaskSetQueuesManager = null + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null override def setListener(listener: TaskSchedulerListener) { this.listener = listener } - def initialize(context: SchedulerBackend, taskSetQueuesManager: TaskSetQueuesManager) { + def initialize(context: SchedulerBackend) { backend = context - this.taskSetQueuesManager = taskSetQueuesManager + //default scheduler is FIFO + val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO") + //temporarily set rootPool name to empty + rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0) + schedulableBuilder = { + schedulingMode match { + case "FIFO" => + new FIFOSchedulableBuilder(rootPool) + case "FAIR" => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() } + def newTaskId(): Long = nextTaskId.getAndIncrement() override def start() { @@ -101,7 +115,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) this.synchronized { val manager = new TaskSetManager(this, taskSet) activeTaskSets(taskSet.id) = manager - taskSetQueuesManager.addTaskSetManager(manager) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() if (hasReceivedTask == false) { @@ -124,26 +138,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def taskSetFinished(manager: TaskSetManager) { this.synchronized { activeTaskSets -= manager.taskSet.id - taskSetQueuesManager.removeTaskSetManager(manager) + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id) } } - def taskFinished(manager: TaskSetManager) { - this.synchronized { - taskSetQueuesManager.taskFinished(manager) - } - } - /** * Called by cluster manager to offer resources on slaves. We respond by asking our active task * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so * that tasks are balanced across the cluster. */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { - synchronized { + synchronized { SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { @@ -155,27 +164,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray - for (i <- 0 until offers.size){ - var launchedTask = true - val execId = offers(i).executorId - val host = offers(i).hostname - while (availableCpus(i) > 0 && launchedTask){ - launchedTask = false - taskSetQueuesManager.receiveOffer(execId,host,availableCpus(i)) match { - case Some(task) => - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = task.taskSetId - taskSetTaskIds(task.taskSetId) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - - case None => {} - } + for (i <- 0 until offers.size) { + var launchedTask = true + val execId = offers(i).executorId + val host = offers(i).hostname + while (availableCpus(i) > 0 && launchedTask) { + launchedTask = false + rootPool.receiveOffer(execId,host,availableCpus(i)) match { + case Some(task) => + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = task.taskSetId + taskSetTaskIds(task.taskSetId) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + + case None => {} } + } } if (tasks.size > 0) { hasLaunchedTask = true @@ -271,7 +280,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def checkSpeculatableTasks() { var shouldRevive = false synchronized { - shouldRevive = taskSetQueuesManager.checkSpeculatableTasks() + shouldRevive = rootPool.checkSpeculatableTasks() } if (shouldRevive) { backend.reviveOffers() @@ -314,6 +323,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) executorsByHost -= host } executorIdToHost -= executorId - taskSetQueuesManager.removeExecutor(executorId, host) + rootPool.executorLost(executorId, host) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala deleted file mode 100644 index 62d3130341..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/FIFOTaskSetQueuesManager.scala +++ /dev/null @@ -1,49 +0,0 @@ -package spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer - -import spark.Logging - -/** - * A FIFO Implementation of the TaskSetQueuesManager - */ -private[spark] class FIFOTaskSetQueuesManager extends TaskSetQueuesManager with Logging { - - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] - val tasksetSchedulingAlgorithm = new FIFOSchedulingAlgorithm() - - override def addTaskSetManager(manager: TaskSetManager) { - activeTaskSetsQueue += manager - } - - override def removeTaskSetManager(manager: TaskSetManager) { - activeTaskSetsQueue -= manager - } - - override def taskFinished(manager: TaskSetManager) { - //do nothing - } - - override def removeExecutor(executorId: String, host: String) { - activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) - } - - override def receiveOffer(execId:String, host:String,avaiableCpus:Double):Option[TaskDescription] = { - - for (manager <- activeTaskSetsQueue.sortWith(tasksetSchedulingAlgorithm.comparator)) { - val task = manager.slaveOffer(execId,host,avaiableCpus) - if (task != None) { - return task - } - } - return None - } - - override def checkSpeculatableTasks(): Boolean = { - var shouldRevive = false - for (ts <- activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } - return shouldRevive - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala deleted file mode 100644 index 89b74fbb47..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/FairTaskSetQueuesManager.scala +++ /dev/null @@ -1,157 +0,0 @@ -package spark.scheduler.cluster - -import java.io.{File, FileInputStream, FileOutputStream} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.util.control.Breaks._ -import scala.xml._ - -import spark.Logging -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** - * A Fair Implementation of the TaskSetQueuesManager - * - * Currently we support minShare,weight for fair scheduler between pools - * Within a pool, it supports FIFO or FS - * Also, currently we could allocate pools dynamically - */ -private[spark] class FairTaskSetQueuesManager extends TaskSetQueuesManager with Logging { - - val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") - val poolNameToPool= new HashMap[String, Pool] - var pools = new ArrayBuffer[Pool] - val poolScheduleAlgorithm = new FairSchedulingAlgorithm() - val POOL_FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" - val POOL_DEFAULT_POOL_NAME = "default" - val POOL_MINIMUM_SHARES_PROPERTY = "minShares" - val POOL_SCHEDULING_MODE_PROPERTY = "schedulingMode" - val POOL_WEIGHT_PROPERTY = "weight" - val POOL_POOL_NAME_PROPERTY = "@name" - val POOL_POOLS_PROPERTY = "pool" - val POOL_DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO - val POOL_DEFAULT_MINIMUM_SHARES = 2 - val POOL_DEFAULT_WEIGHT = 1 - - loadPoolProperties() - - override def addTaskSetManager(manager: TaskSetManager) { - var poolName = POOL_DEFAULT_POOL_NAME - if (manager.taskSet.properties != null) { - poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) - if (!poolNameToPool.contains(poolName)) { - //we will create a new pool that user has configured in app instead of being defined in xml file - val pool = new Pool(poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) - pools += pool - poolNameToPool(poolName) = pool - logInfo("Create pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format( - poolName,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) - } - } - poolNameToPool(poolName).addTaskSetManager(manager) - logInfo("Added task set " + manager.taskSet.id + " tasks to pool "+poolName) - } - - override def removeTaskSetManager(manager: TaskSetManager) { - var poolName = POOL_DEFAULT_POOL_NAME - if (manager.taskSet.properties != null) { - poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) - } - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id,poolName)) - val pool = poolNameToPool(poolName) - pool.removeTaskSetManager(manager) - pool.runningTasks -= manager.runningTasks - } - - override def taskFinished(manager: TaskSetManager) { - var poolName = POOL_DEFAULT_POOL_NAME - if (manager.taskSet.properties != null) { - poolName = manager.taskSet.properties.getProperty(POOL_FAIR_SCHEDULER_PROPERTIES,POOL_DEFAULT_POOL_NAME) - } - val pool = poolNameToPool(poolName) - pool.runningTasks -= 1 - manager.runningTasks -=1 - } - - override def removeExecutor(executorId: String, host: String) { - for (pool <- pools) { - pool.removeExecutor(executorId,host) - } - } - - override def receiveOffer(execId: String,host:String,avaiableCpus:Double):Option[TaskDescription] = { - val sortedPools = pools.sortWith(poolScheduleAlgorithm.comparator) - for (pool <- sortedPools) { - logDebug("poolName:%s,tasksetNum:%d,minShares:%d,runningTasks:%d".format( - pool.poolName,pool.activeTaskSetsQueue.length,pool.minShare,pool.runningTasks)) - } - for (pool <- sortedPools) { - val task = pool.receiveOffer(execId,host,avaiableCpus) - if(task != None) { - pool.runningTasks += 1 - return task - } - } - return None - } - - override def checkSpeculatableTasks(): Boolean = { - var shouldRevive = false - for (pool <- pools) { - shouldRevive |= pool.checkSpeculatableTasks() - } - return shouldRevive - } - - def loadPoolProperties() { - //first check if the file exists - val file = new File(schedulerAllocFile) - if (file.exists()) { - val xml = XML.loadFile(file) - for (poolNode <- (xml \\ POOL_POOLS_PROPERTY)) { - - val poolName = (poolNode \ POOL_POOL_NAME_PROPERTY).text - var schedulingMode = POOL_DEFAULT_SCHEDULING_MODE - var minShares = POOL_DEFAULT_MINIMUM_SHARES - var weight = POOL_DEFAULT_WEIGHT - - val xmlSchedulingMode = (poolNode \ POOL_SCHEDULING_MODE_PROPERTY).text - if (xmlSchedulingMode != "") { - try{ - schedulingMode = SchedulingMode.withName(xmlSchedulingMode) - } - catch{ - case e:Exception => logInfo("Error xml schedulingMode, using default schedulingMode") - } - } - - val xmlMinShares = (poolNode \ POOL_MINIMUM_SHARES_PROPERTY).text - if (xmlMinShares != "") { - minShares = xmlMinShares.toInt - } - - val xmlWeight = (poolNode \ POOL_WEIGHT_PROPERTY).text - if (xmlWeight != "") { - weight = xmlWeight.toInt - } - - val pool = new Pool(poolName,schedulingMode,minShares,weight) - pools += pool - poolNameToPool(poolName) = pool - logInfo("Create new pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format( - poolName,schedulingMode,minShares,weight)) - } - } - - if (!poolNameToPool.contains(POOL_DEFAULT_POOL_NAME)) { - val pool = new Pool(POOL_DEFAULT_POOL_NAME, POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT) - pools += pool - poolNameToPool(POOL_DEFAULT_POOL_NAME) = pool - logInfo("Create default pool with name:%s,schedulingMode:%s,minShares:%d,weight:%d".format( - POOL_DEFAULT_POOL_NAME,POOL_DEFAULT_SCHEDULING_MODE,POOL_DEFAULT_MINIMUM_SHARES,POOL_DEFAULT_WEIGHT)) - } - } - } diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala index e0917ca1ca..d5482f71ad 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -1,74 +1,106 @@ package spark.scheduler.cluster import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import spark.Logging import spark.scheduler.cluster.SchedulingMode.SchedulingMode + /** - * An Schedulable entity that represent collection of TaskSetManager + * An Schedulable entity that represent collection of Pools or TaskSetManagers */ + private[spark] class Pool( val poolName: String, val schedulingMode: SchedulingMode, - initMinShare:Int, - initWeight:Int) + initMinShare: Int, + initWeight: Int) extends Schedulable with Logging { - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + var schedulableQueue = new ArrayBuffer[Schedulable] + var schedulableNameToSchedulable = new HashMap[String, Schedulable] var weight = initWeight var minShare = initMinShare var runningTasks = 0 - val priority = 0 - val stageId = 0 + var priority = 0 + var stageId = 0 + var name = poolName + var parent:Schedulable = null var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { - schedulingMode match { + schedulingMode match { case SchedulingMode.FAIR => - val schedule = new FairSchedulingAlgorithm() - schedule + new FairSchedulingAlgorithm() case SchedulingMode.FIFO => - val schedule = new FIFOSchedulingAlgorithm() - schedule + new FIFOSchedulingAlgorithm() } } - def addTaskSetManager(manager:TaskSetManager) { - activeTaskSetsQueue += manager + override def addSchedulable(schedulable: Schedulable) { + schedulableQueue += schedulable + schedulableNameToSchedulable(schedulable.name) = schedulable + schedulable.parent= this } - def removeTaskSetManager(manager:TaskSetManager) { - activeTaskSetsQueue -= manager + override def removeSchedulable(schedulable: Schedulable) { + schedulableQueue -= schedulable + schedulableNameToSchedulable -= schedulable.name } - def removeExecutor(executorId: String, host: String) { - activeTaskSetsQueue.foreach(_.executorLost(executorId,host)) + override def getSchedulableByName(schedulableName: String): Schedulable = { + if (schedulableNameToSchedulable.contains(schedulableName)) { + return schedulableNameToSchedulable(schedulableName) + } + for (schedulable <- schedulableQueue) { + var sched = schedulable.getSchedulableByName(schedulableName) + if (sched != null) { + return sched + } + } + return null } - def checkSpeculatableTasks(): Boolean = { + override def executorLost(executorId: String, host: String) { + schedulableQueue.foreach(_.executorLost(executorId, host)) + } + + override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (ts <- activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() + for (schedulable <- schedulableQueue) { + shouldRevive |= schedulable.checkSpeculatableTasks() } return shouldRevive } - def receiveOffer(execId:String,host:String,availableCpus:Double):Option[TaskDescription] = { - val sortedActiveTasksSetQueue = activeTaskSetsQueue.sortWith(taskSetSchedulingAlgorithm.comparator) - for (manager <- sortedActiveTasksSetQueue) { - logDebug("poolname:%s,taskSetId:%s,taskNum:%d,minShares:%d,weight:%d,runningTasks:%d".format( - poolName,manager.taskSet.id,manager.numTasks,manager.minShare,manager.weight,manager.runningTasks)) + override def receiveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) + for (manager <- sortedSchedulableQueue) { + logInfo("parentName:%s,schedulableName:%s,minShares:%d,weight:%d,runningTasks:%d".format( + manager.parent.name, manager.name, manager.minShare, manager.weight, manager.runningTasks)) } - - for (manager <- sortedActiveTasksSetQueue) { - val task = manager.slaveOffer(execId,host,availableCpus) - if (task != None) { - manager.runningTasks += 1 - return task - } + for (manager <- sortedSchedulableQueue) { + val task = manager.receiveOffer(execId, host, availableCpus) + if (task != None) { + return task + } } return None } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala index 8dfc369c03..54e8ae95f9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -1,13 +1,26 @@ package spark.scheduler.cluster +import scala.collection.mutable.ArrayBuffer + /** * An interface for schedulable entities. * there are two type of Schedulable entities(Pools and TaskSetManagers) */ private[spark] trait Schedulable { - def weight:Int - def minShare:Int - def runningTasks:Int - def priority:Int - def stageId:Int + var parent: Schedulable + def weight: Int + def minShare: Int + def runningTasks: Int + def priority: Int + def stageId: Int + def name: String + + def increaseRunningTasks(taskNum: Int): Unit + def decreaseRunningTasks(taskNum: Int): Unit + def addSchedulable(schedulable: Schedulable): Unit + def removeSchedulable(schedulable: Schedulable): Unit + def getSchedulableByName(name: String): Schedulable + def executorLost(executorId: String, host: String): Unit + def receiveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] + def checkSpeculatableTasks(): Boolean } diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala new file mode 100644 index 0000000000..47a426a45b --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala @@ -0,0 +1,115 @@ +package spark.scheduler.cluster + +import java.io.{File, FileInputStream, FileOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.control.Breaks._ +import scala.xml._ + +import spark.Logging +import spark.scheduler.cluster.SchedulingMode.SchedulingMode + +import java.util.Properties + +/** + * An interface to build Schedulable tree + * buildPools: build the tree nodes(pools) + * addTaskSetManager: build the leaf nodes(TaskSetManagers) + */ +private[spark] trait SchedulableBuilder { + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) +} + +private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging { + + override def buildPools() { + //nothing + } + + override def addTaskSetManager(manager: Schedulable, properties: Properties) { + rootPool.addSchedulable(manager) + } +} + +private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging { + + val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified") + val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" + val DEFAULT_POOL_NAME = "default" + val MINIMUM_SHARES_PROPERTY = "minShare" + val SCHEDULING_MODE_PROPERTY = "schedulingMode" + val WEIGHT_PROPERTY = "weight" + val POOL_NAME_PROPERTY = "@name" + val POOLS_PROPERTY = "pool" + val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO + val DEFAULT_MINIMUM_SHARE = 2 + val DEFAULT_WEIGHT = 1 + + override def buildPools() { + val file = new File(schedulerAllocFile) + if (file.exists()) { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ POOLS_PROPERTY)) { + + val poolName = (poolNode \ POOL_NAME_PROPERTY).text + var schedulingMode = DEFAULT_SCHEDULING_MODE + var minShare = DEFAULT_MINIMUM_SHARE + var weight = DEFAULT_WEIGHT + + val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text + if (xmlSchedulingMode != "") { + try { + schedulingMode = SchedulingMode.withName(xmlSchedulingMode) + } catch { + case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode") + } + } + + val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text + if (xmlMinShare != "") { + minShare = xmlMinShare.toInt + } + + val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text + if (xmlWeight != "") { + weight = xmlWeight.toInt + } + + val pool = new Pool(poolName, schedulingMode, minShare, weight) + rootPool.addSchedulable(pool) + logInfo("Create new pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + poolName, schedulingMode, minShare, weight)) + } + } + + //finally create "default" pool + if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { + val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(pool) + logInfo("Create default pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } +} + + override def addTaskSetManager(manager: Schedulable, properties: Properties) { + var poolName = DEFAULT_POOL_NAME + var parentPool = rootPool.getSchedulableByName(poolName) + if (properties != null) { + poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + //we will create a new pool that user has configured in app instead of being defined in xml file + parentPool = new Pool(poolName,DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logInfo("Create pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } + } + parentPool.addSchedulable(manager) + logInfo("Added task set " + manager.name + " tasks to pool "+poolName) + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index ac2237a7ef..a5d6285c99 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -2,11 +2,11 @@ package spark.scheduler.cluster /** * An interface for sort algorithm - * FIFO: FIFO algorithm for TaskSetManagers - * FS: FS algorithm for Pools, and FIFO or FS for TaskSetManagers + * FIFO: FIFO algorithm between TaskSetManagers + * FS: FS algorithm between Pools, and FIFO or FS within Pools */ private[spark] trait SchedulingAlgorithm { - def comparator(s1: Schedulable,s2: Schedulable): Boolean + def comparator(s1: Schedulable, s2: Schedulable): Boolean } private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { @@ -15,40 +15,41 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { val priority2 = s2.priority var res = Math.signum(priority1 - priority2) if (res == 0) { - val stageId1 = s1.stageId - val stageId2 = s2.stageId - res = Math.signum(stageId1 - stageId2) + val stageId1 = s1.stageId + val stageId2 = s2.stageId + res = Math.signum(stageId1 - stageId2) } - if (res < 0) + if (res < 0) { return true - else + } else { return false + } } } private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { - def comparator(s1: Schedulable, s2:Schedulable): Boolean = { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { val minShare1 = s1.minShare val minShare2 = s2.minShare val runningTasks1 = s1.runningTasks val runningTasks2 = s2.runningTasks val s1Needy = runningTasks1 < minShare1 val s2Needy = runningTasks2 < minShare2 - val minShareRatio1 = runningTasks1.toDouble / Math.max(minShare1,1.0).toDouble - val minShareRatio2 = runningTasks2.toDouble / Math.max(minShare2,1.0).toDouble + val minShareRatio1 = runningTasks1.toDouble / Math.max(minShare1, 1.0).toDouble + val minShareRatio2 = runningTasks2.toDouble / Math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true - if (s1Needy && !s2Needy) + if (s1Needy && !s2Needy) { res = true - else if(!s1Needy && s2Needy) + } else if (!s1Needy && s2Needy) { res = false - else if (s1Needy && s2Needy) + } else if (s1Needy && s2Needy) { res = minShareRatio1 <= minShareRatio2 - else + } else { res = taskToWeightRatio1 <= taskToWeightRatio2 - + } return res } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 7ec2f69da5..baaaa41a37 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -32,8 +32,6 @@ private[spark] class TaskSetManager( // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 - val TASKSET_MINIMUM_SHARES = 1 - val TASKSET_WEIGHT = 1 // Quantile of tasks at which to start speculation val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble @@ -41,12 +39,6 @@ private[spark] class TaskSetManager( // Serializer for closures and tasks. val ser = SparkEnv.get.closureSerializer.newInstance() - var weight = TASKSET_WEIGHT - var minShare = TASKSET_MINIMUM_SHARES - var runningTasks = 0 - val priority = taskSet.priority - val stageId = taskSet.stageId - val tasks = taskSet.tasks val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -55,6 +47,14 @@ private[spark] class TaskSetManager( val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksFinished = 0 + var weight = 1 + var minShare = 0 + var runningTasks = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent:Schedulable = null + // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis @@ -198,7 +198,7 @@ private[spark] class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + override def receiveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) @@ -230,10 +230,11 @@ private[spark] class TaskSetManager( val serializedTask = Task.serializeWithDependencies( task, sched.sc.addedFiles, sched.sc.addedJars, ser) val timeTaken = System.currentTimeMillis - startTime + increaseRunningTasks(1) logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId,taskSet.id,execId, taskName, serializedTask)) + return Some(new TaskDescription(taskId, taskSet.id, execId, taskName, serializedTask)) } case _ => } @@ -264,7 +265,7 @@ private[spark] class TaskSetManager( } val index = info.index info.markSuccessful() - sched.taskFinished(this) + decreaseRunningTasks(1) if (!finished(index)) { tasksFinished += 1 logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( @@ -293,7 +294,7 @@ private[spark] class TaskSetManager( } val index = info.index info.markFailed() - sched.taskFinished(this) + decreaseRunningTasks(1) if (!finished(index)) { logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) copiesRunning(index) -= 1 @@ -308,6 +309,7 @@ private[spark] class TaskSetManager( finished(index) = true tasksFinished += 1 sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) return case ef: ExceptionFailure => @@ -365,10 +367,38 @@ private[spark] class TaskSetManager( causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.listener.taskSetFailed(taskSet, message) + decreaseRunningTasks(runningTasks) sched.taskSetFinished(this) } - def executorLost(execId: String, hostname: String) { + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable:Schedulable) { + //nothing + } + + override def removeSchedulable(schedulable:Schedulable) { + //nothing + } + + override def executorLost(execId: String, hostname: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) val newHostsAlive = sched.hostsAlive // If some task has preferred locations only on hostname, and there are no more executors there, @@ -409,7 +439,7 @@ private[spark] class TaskSetManager( * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that * we don't scan the whole task set. It might also help to make this sorted by launch time. */ - def checkSpeculatableTasks(): Boolean = { + override def checkSpeculatableTasks(): Boolean = { // Can't speculate if we only have one task, or if all tasks have finished. if (numTasks == 1 || tasksFinished == numTasks) { return false diff --git a/core/src/test/resources/fairscheduler.xml b/core/src/test/resources/fairscheduler.xml new file mode 100644 index 0000000000..5a688b0ebb --- /dev/null +++ b/core/src/test/resources/fairscheduler.xml @@ -0,0 +1,14 @@ + + + 2 + 1 + FIFO + + + 3 + 1 + FIFO + + + + diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala new file mode 100644 index 0000000000..2eda48196b --- /dev/null +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -0,0 +1,207 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ + +import java.util.Properties + +class DummyTaskSetManager( + initPriority: Int, + initStageId: Int, + initNumTasks: Int) + extends Schedulable { + + var parent: Schedulable = null + var weight = 1 + var minShare = 2 + var runningTasks = 0 + var priority = initPriority + var stageId = initStageId + var name = "TaskSet_"+stageId + var numTasks = initNumTasks + var tasksFinished = 0 + + def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable) { + } + + def removeSchedulable(schedulable: Schedulable) { + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + } + + def receiveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] = { + if (tasksFinished + runningTasks < numTasks) { + increaseRunningTasks(1) + return Some(new TaskDescription(0, stageId.toString, execId, "task 0:0", null)) + } + return None + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def taskFinished() { + decreaseRunningTasks(1) + tasksFinished +=1 + if (tasksFinished == numTasks) { + parent.removeSchedulable(this) + } + } + + def abort() { + decreaseRunningTasks(runningTasks) + parent.removeSchedulable(this) + } +} + +class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { + + def receiveOffer(rootPool: Pool) : Option[TaskDescription] = { + rootPool.receiveOffer("execId_1", "hostname_1", 1) + } + + def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { + receiveOffer(rootPool) match { + case Some(task) => + assert(task.taskSetId.toInt === expectedTaskSetId) + case _ => + } + } + + test("FIFO Scheduler Test") { + val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + val taskSetManager0 = new DummyTaskSetManager(0, 0, 2) + val taskSetManager1 = new DummyTaskSetManager(0, 1, 2) + val taskSetManager2 = new DummyTaskSetManager(0, 2, 2) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + schedulableBuilder.addTaskSetManager(taskSetManager1, null) + schedulableBuilder.addTaskSetManager(taskSetManager2, null) + + checkTaskSetId(rootPool, 0) + receiveOffer(rootPool) + checkTaskSetId(rootPool, 1) + receiveOffer(rootPool) + taskSetManager1.abort() + checkTaskSetId(rootPool, 2) + } + + test("Fair Scheduler Test") { + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName("default") != null) + assert(rootPool.getSchedulableByName("1") != null) + assert(rootPool.getSchedulableByName("2") != null) + assert(rootPool.getSchedulableByName("3") != null) + assert(rootPool.getSchedulableByName("1").minShare === 2) + assert(rootPool.getSchedulableByName("1").weight === 1) + assert(rootPool.getSchedulableByName("2").minShare === 3) + assert(rootPool.getSchedulableByName("2").weight === 1) + assert(rootPool.getSchedulableByName("3").minShare === 2) + assert(rootPool.getSchedulableByName("3").weight === 1) + + val properties1 = new Properties() + properties1.setProperty("spark.scheduler.cluster.fair.pool","1") + val properties2 = new Properties() + properties2.setProperty("spark.scheduler.cluster.fair.pool","2") + + val taskSetManager10 = new DummyTaskSetManager(1, 0, 1) + val taskSetManager11 = new DummyTaskSetManager(1, 1, 1) + val taskSetManager12 = new DummyTaskSetManager(1, 2, 2) + schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) + + val taskSetManager23 = new DummyTaskSetManager(2, 3, 2) + val taskSetManager24 = new DummyTaskSetManager(2, 4, 2) + schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) + schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 1) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 4) + + taskSetManager12.taskFinished() + assert(rootPool.getSchedulableByName("1").runningTasks === 3) + taskSetManager24.abort() + assert(rootPool.getSchedulableByName("2").runningTasks === 2) + } + + test("Nested Pool Test") { + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) + val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + rootPool.addSchedulable(pool0) + rootPool.addSchedulable(pool1) + + val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) + val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) + pool0.addSchedulable(pool00) + pool0.addSchedulable(pool01) + + val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) + val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + pool1.addSchedulable(pool10) + pool1.addSchedulable(pool11) + + val taskSetManager000 = new DummyTaskSetManager(0, 0, 5) + val taskSetManager001 = new DummyTaskSetManager(0, 1, 5) + pool00.addSchedulable(taskSetManager000) + pool00.addSchedulable(taskSetManager001) + + val taskSetManager010 = new DummyTaskSetManager(1, 2, 5) + val taskSetManager011 = new DummyTaskSetManager(1, 3, 5) + pool01.addSchedulable(taskSetManager010) + pool01.addSchedulable(taskSetManager011) + + val taskSetManager100 = new DummyTaskSetManager(2, 4, 5) + val taskSetManager101 = new DummyTaskSetManager(2, 5, 5) + pool10.addSchedulable(taskSetManager100) + pool10.addSchedulable(taskSetManager101) + + val taskSetManager110 = new DummyTaskSetManager(3, 6, 5) + val taskSetManager111 = new DummyTaskSetManager(3, 7, 5) + pool11.addSchedulable(taskSetManager110) + pool11.addSchedulable(taskSetManager111) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 6) + checkTaskSetId(rootPool, 2) + } +} -- cgit v1.2.3 From 8436bd5d4a96480ac1871330a28d9d712e64959d Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Fri, 19 Apr 2013 02:17:22 +0800 Subject: remove TaskSetQueueManager and update code style --- core/src/main/scala/spark/SparkContext.scala | 10 ++-------- .../main/scala/spark/scheduler/DAGSchedulerEvent.scala | 1 - .../spark/scheduler/cluster/SchedulableBuilder.scala | 14 +++++++------- .../spark/scheduler/cluster/TaskSetQueuesManager.scala | 16 ---------------- 4 files changed, 9 insertions(+), 32 deletions(-) delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 5d9a0357ad..eef25ef588 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -39,7 +39,7 @@ import spark.partial.PartialResult import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} import spark.scheduler._ import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler, TaskSetQueuesManager} +import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.BlockManagerUI import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -75,11 +75,6 @@ class SparkContext( System.setProperty("spark.driver.port", "0") } - //Set the default task scheduler - if (System.getProperty("spark.cluster.taskscheduler") == null) { - System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.FIFOTaskSetQueuesManager") - } - private val isLocal = (master == "local" || master.startsWith("local[")) // Create the Spark execution environment (cache, map output tracker, etc) @@ -599,8 +594,7 @@ class SparkContext( val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler - ,localProperties.value) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 11fec568c6..303c211e2a 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -1,6 +1,5 @@ package spark.scheduler - import java.util.Properties import spark.scheduler.cluster.TaskInfo diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala index 47a426a45b..18cc15c2a5 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala @@ -86,14 +86,14 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends Schedula } } - //finally create "default" pool - if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { - val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(pool) - logInfo("Create default pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( - DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + //finally create "default" pool + if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { + val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(pool) + logInfo("Create default pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format( + DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + } } -} override def addTaskSetManager(manager: Schedulable, properties: Properties) { var poolName = DEFAULT_POOL_NAME diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala deleted file mode 100644 index 86971d47e6..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetQueuesManager.scala +++ /dev/null @@ -1,16 +0,0 @@ -package spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer - -/** - * An interface for managing TaskSet queue/s that allows plugging different policy for - * offering tasks to resources - */ -private[spark] trait TaskSetQueuesManager { - def addTaskSetManager(manager: TaskSetManager): Unit - def removeTaskSetManager(manager: TaskSetManager): Unit - def taskFinished(manager: TaskSetManager): Unit - def removeExecutor(executorId: String, host: String): Unit - def receiveOffer(execId: String, host:String, avaiableCpus:Double):Option[TaskDescription] - def checkSpeculatableTasks(): Boolean -} -- cgit v1.2.3 From ac2e8e8720f10efd640a67ad85270719ab2d43e9 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Fri, 19 Apr 2013 00:13:19 +0530 Subject: Add some basic documentation --- .../scala/spark/deploy/yarn/ClientArguments.scala | 6 +++-- docs/running-on-yarn.md | 31 +++++++++++++++------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala index 53b305f7df..2e69fe3fb0 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala @@ -94,9 +94,11 @@ class ClientArguments(val args: Array[String]) { " Mutliple invocations are possible, each will be passed in order.\n" + " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + - " --user USERNAME Run the ApplicationMaster as a different user\n" + " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + + " --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n" ) System.exit(exitCode) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c2957e6cb4..26424bbe52 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -5,18 +5,25 @@ title: Launching Spark on YARN Experimental support for running over a [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html) -cluster was added to Spark in version 0.6.0. Because YARN depends on version -2.0 of the Hadoop libraries, this currently requires checking out a separate -branch of Spark, called `yarn`, which you can do as follows: +cluster was added to Spark in version 0.6.0. This was merged into master as part of 0.7 effort. +To build spark core with YARN support, please use the hadoop2-yarn profile. +Ex: mvn -Phadoop2-yarn clean install - git clone git://github.com/mesos/spark - cd spark - git checkout -b yarn --track origin/yarn +# Building spark core consolidated jar. + +Currently, only sbt can buid a consolidated jar which contains the entire spark code - which is required for launching jars on yarn. +To do this via sbt - though (right now) is a manual process of enabling it in project/SparkBuild.scala. +Please comment out the + HADOOP_VERSION, HADOOP_MAJOR_VERSION and HADOOP_YARN +variables before the line 'For Hadoop 2 YARN support' +Next, uncomment the subsequent 3 variable declaration lines (for these three variables) which enable hadoop yarn support. + +Currnetly, it is a TODO to add support for maven assembly. # Preparations -- In order to distribute Spark within the cluster, it must be packaged into a single JAR file. This can be done by running `sbt/sbt assembly` +- Building spark core assembled jar (see above). - Your application code must be packaged into a separate JAR file. If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt package`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different. @@ -30,8 +37,11 @@ The command to launch the YARN Client is as follows: --class \ --args \ --num-workers \ + --master-memory \ --worker-memory \ - --worker-cores + --worker-cores \ + --user \ + --queue For example: @@ -40,8 +50,9 @@ For example: --class spark.examples.SparkPi \ --args standalone \ --num-workers 3 \ + --master-memory 4g \ --worker-memory 2g \ - --worker-cores 2 + --worker-cores 1 The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. @@ -49,3 +60,5 @@ The above starts a YARN Client programs which periodically polls the Application - When your application instantiates a Spark context it must use a special "standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "standalone" as an argument to your program, as shown in the example above. - YARN does not support requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. +- Currently, we have not yet integrated with hadoop security. If --user is present, the hadoop_user specified will be used to run the tasks on the cluster. If unspecified, current user will be used (which should be valid in cluster). + Once hadoop security support is added, and if hadoop cluster is enabled with security, additional restrictions would apply via delegation tokens passed. -- cgit v1.2.3 From b2a3f24dde7a69587a5fea50d3e1e4e8f02a2dc3 Mon Sep 17 00:00:00 2001 From: koeninger Date: Sun, 21 Apr 2013 00:29:37 -0500 Subject: first attempt at an RDD to pull data from JDBC sources --- core/src/main/scala/spark/rdd/JdbcRDD.scala | 79 +++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 core/src/main/scala/spark/rdd/JdbcRDD.scala diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala new file mode 100644 index 0000000000..c8a5d76012 --- /dev/null +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -0,0 +1,79 @@ +package spark.rdd + +import java.sql.{Connection, ResultSet} + +import spark.{Logging, Partition, RDD, SparkContext, TaskContext} +import spark.util.NextIterator + +/** + An RDD that executes an SQL query on a JDBC connection and reads results. + @param getConnection a function that returns an open Connection. + The RDD takes care of closing the connection. + @param sql the text of the query. + The query must contain two ? placeholders for parameters used to partition the results. + E.g. "select title, author from books where ? <= id and id <= ?" + @param lowerBound the minimum value of the first placeholder + @param upperBound the maximum value of the second placeholder + The lower and upper bounds are inclusive. + @param numPartitions the amount of parallelism. + Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + the query would be executed twice, once with (1, 10) and once with (11, 20) + @param mapRow a function from a ResultSet to a single row of the desired result type(s). + This should only call getInt, getString, etc; the RDD takes care of calling next. + The default maps a ResultSet to an array of Object. +*/ +class JdbcRDD[T: ClassManifest]( + sc: SparkContext, + getConnection: () => Connection, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = + ParallelCollectionRDD.slice(lowerBound to upperBound, numPartitions). + filter(! _.isEmpty). + zipWithIndex. + map(x => new JdbcPartition(x._2, x._1.head, x._1.last)). + toArray + + override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + val part = thePart.asInstanceOf[JdbcPartition] + val conn = getConnection() + context.addOnCompleteCallback{ () => closeIfNeeded() } + val stmt = conn.prepareStatement(sql) + stmt.setLong(1, part.lower) + stmt.setLong(2, part.upper) + val rs = stmt.executeQuery() + + override def getNext: T = { + if (rs.next()) { + mapRow(rs) + } else { + finished = true + null.asInstanceOf[T] + } + } + + override def close() { + try { + logInfo("closing connection") + conn.close() + } catch { + case e: Exception => logWarning("Exception closing connection", e) + } + } + } + +} + +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + +object JdbcRDD { + val resultSetToObjectArray = (rs: ResultSet) => + Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) +} -- cgit v1.2.3 From 7acab3ab45df421601ee9a076a61de00561a0308 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 22 Apr 2013 08:01:13 +0530 Subject: Fix review comments, add a new api to SparkHadoopUtil to create appropriate Configuration. Modify an example to show how to use SplitInfo --- core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala | 5 +++++ .../hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala | 6 +++++- core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala | 5 +++++ core/src/main/scala/spark/SparkContext.scala | 14 +++++++++----- core/src/main/scala/spark/Utils.scala | 8 +++----- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 7 ++++--- examples/src/main/scala/spark/examples/SparkHdfsLR.scala | 10 ++++++++-- project/SparkBuild.scala | 10 ++++++---- 8 files changed, 45 insertions(+), 20 deletions(-) diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala index d4badbc5c4..a0fb4fe25d 100644 --- a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala @@ -1,4 +1,6 @@ package spark.deploy +import org.apache.hadoop.conf.Configuration + /** * Contains util methods to interact with Hadoop from spark. @@ -15,4 +17,7 @@ object SparkHadoopUtil { // Add support, if exists - for now, simply run func ! func(args) } + + // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems + def newConfiguration(): Configuration = new Configuration() } diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala index 66e5ad8491..ab1ab9d8a7 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala @@ -12,7 +12,7 @@ import java.security.PrivilegedExceptionAction */ object SparkHadoopUtil { - val yarnConf = new YarnConfiguration(new Configuration()) + val yarnConf = newConfiguration() def getUserNameFromEnvironment(): String = { // defaulting to env if -D is not present ... @@ -56,4 +56,8 @@ object SparkHadoopUtil { def setYarnMode(env: HashMap[String, String]) { env("SPARK_YARN_MODE") = "true" } + + // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems + // Always create a new config, dont reuse yarnConf. + def newConfiguration(): Configuration = new YarnConfiguration(new Configuration()) } diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala index d4badbc5c4..a0fb4fe25d 100644 --- a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala @@ -1,4 +1,6 @@ package spark.deploy +import org.apache.hadoop.conf.Configuration + /** * Contains util methods to interact with Hadoop from spark. @@ -15,4 +17,7 @@ object SparkHadoopUtil { // Add support, if exists - for now, simply run func ! func(args) } + + // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems + def newConfiguration(): Configuration = new Configuration() } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index e853bce2c4..5f5ec0b0f4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} import org.apache.mesos.MesosNativeLibrary -import spark.deploy.LocalSparkCluster +import spark.deploy.{SparkHadoopUtil, LocalSparkCluster} import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} @@ -102,7 +102,9 @@ class SparkContext( // Add each JAR given through the constructor - if (jars != null) jars.foreach { addJar(_) } + if (jars != null) { + jars.foreach { addJar(_) } + } // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() @@ -114,7 +116,9 @@ class SparkContext( executorEnvs(key) = value } } - if (environment != null) executorEnvs ++= environment + if (environment != null) { + executorEnvs ++= environment + } // Create and start the scheduler private var taskScheduler: TaskScheduler = { @@ -207,7 +211,7 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { - val conf = new Configuration() + val conf = SparkHadoopUtil.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) @@ -711,7 +715,7 @@ class SparkContext( */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) if (!useExisting) { if (fs.exists(path)) { throw new Exception("Checkpoint directory '" + path + "' already exists.") diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 3e54fa7a7e..9f48cbe490 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -4,7 +4,6 @@ import java.io._ import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ @@ -208,7 +207,7 @@ private object Utils extends Logging { case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val uri = new URI(url) - val conf = new Configuration() + val conf = SparkHadoopUtil.newConfiguration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) @@ -317,7 +316,6 @@ private object Utils extends Logging { * Get the local machine's hostname. */ def localHostName(): String = { - // customHostname.getOrElse(InetAddress.getLocalHost.getHostName) customHostname.getOrElse(localIpAddressHostname) } @@ -337,6 +335,7 @@ private object Utils extends Logging { retval } + /* // Used by DEBUG code : remove when all testing done private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") def checkHost(host: String, message: String = "") { @@ -358,12 +357,11 @@ private object Utils extends Logging { Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) } } + */ // Once testing is complete in various modes, replace with this ? - /* def checkHost(host: String, message: String = "") {} def checkHostPort(hostPort: String, message: String = "") {} - */ def getUserNameFromEnvironment(): String = { SparkHadoopUtil.getUserNameFromEnvironment diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 24d527f38f..79d00edee7 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -8,6 +8,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat +import spark.deploy.SparkHadoopUtil private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -65,7 +66,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val outputDir = new Path(path) - val fs = outputDir.getFileSystem(new Configuration()) + val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration()) val finalOutputName = splitIdToFile(ctx.splitId) val finalOutputPath = new Path(outputDir, finalOutputName) @@ -103,7 +104,7 @@ private[spark] object CheckpointRDD extends Logging { } def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() @@ -125,7 +126,7 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala index 0f42f405a0..3d080a0257 100644 --- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala @@ -4,6 +4,8 @@ import java.util.Random import scala.math.exp import spark.util.Vector import spark._ +import spark.deploy.SparkHadoopUtil +import spark.scheduler.InputFormatInfo /** * Logistic regression based classification. @@ -32,9 +34,13 @@ object SparkHdfsLR { System.err.println("Usage: SparkHdfsLR ") System.exit(1) } + val inputPath = args(1) + val conf = SparkHadoopUtil.newConfiguration() val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val lines = sc.textFile(args(1)) + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(), + InputFormatInfo.computePreferredLocations( + Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)))) + val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).cache() val ITERATIONS = args(2).toInt diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 91e3123bc5..0a5b89d927 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -47,10 +47,8 @@ object SparkBuild extends Build { scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"), unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, - // retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), - // For some reason this fails on some nodes and works on others - not yet debugged why - // testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), + testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), // shared between both core and streaming. resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), @@ -170,7 +168,11 @@ object SparkBuild extends Build { Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION) }), unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / - ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") "src/hadoop2-yarn/scala" else "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala" ) + ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") { + "src/hadoop2-yarn/scala" + } else { + "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala" + } ) } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings -- cgit v1.2.3 From 0dc1e2d60f89f07f54e0985d37cdcd32ad388f6a Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 22 Apr 2013 09:22:45 -0600 Subject: Examaple of cumulative counting using updateStateByKey --- ...etworkWordCumulativeCountUpdateStateByKey.scala | 63 ++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala new file mode 100644 index 0000000000..db62246387 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala @@ -0,0 +1,63 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCumulativeCountUpdateStateByKey + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.NetworkWordCumulativeCountUpdateStateByKey local[2] localhost 9999` + */ +object NetworkWordCumulativeCountUpdateStateByKey { + private def className[A](a: A)(implicit m: Manifest[A]) = m.toString + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: NetworkWordCountUpdateStateByKey \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + val currentCount = values.foldLeft(0)(_ + _) + //println("currentCount: " + currentCount) + + val previousCount = state.getOrElse(0) + //println("previousCount: " + previousCount) + + val cumulative = Some(currentCount + previousCount) + //println("Cumulative: " + cumulative) + + cumulative + } + + // Create the context with a 10 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(10), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + ssc.checkpoint(".") + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.socketTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordDstream = words.map(x => (x, 1)) + + // Update the cumulative count using updateStateByKey + // This will give a Dstream made of state (which is the cumulative count of the words) + val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + + stateDstream.foreach(rdd => { + rdd.foreach(rddVal => { + println("Current Count: " + rddVal) + }) + }) + + ssc.start() + } +} -- cgit v1.2.3 From dfac0aa5c2e5f46955b008b1e8d9ee5d8069efa5 Mon Sep 17 00:00:00 2001 From: koeninger Date: Mon, 22 Apr 2013 21:12:52 -0500 Subject: prevent mysql driver from pulling entire resultset into memory. explicitly close resultset and statement. --- core/src/main/scala/spark/rdd/JdbcRDD.scala | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala index c8a5d76012..4c3054465c 100644 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -15,7 +15,7 @@ import spark.util.NextIterator @param lowerBound the minimum value of the first placeholder @param upperBound the maximum value of the second placeholder The lower and upper bounds are inclusive. - @param numPartitions the amount of parallelism. + @param numPartitions the number of partitions. Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, the query would be executed twice, once with (1, 10) and once with (11, 20) @param mapRow a function from a ResultSet to a single row of the desired result type(s). @@ -40,10 +40,15 @@ class JdbcRDD[T: ClassManifest]( toArray override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + context.addOnCompleteCallback{ () => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() - context.addOnCompleteCallback{ () => closeIfNeeded() } - val stmt = conn.prepareStatement(sql) + val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + // force mysql driver to stream rather than pull entire resultset into memory + if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { + stmt.setFetchSize(Integer.MIN_VALUE) + logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") + } stmt.setLong(1, part.lower) stmt.setLong(2, part.upper) val rs = stmt.executeQuery() @@ -59,8 +64,18 @@ class JdbcRDD[T: ClassManifest]( override def close() { try { - logInfo("closing connection") - conn.close() + if (null != rs && ! rs.isClosed()) rs.close() + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt && ! stmt.isClosed()) stmt.close() + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } + try { + if (null != conn && ! stmt.isClosed()) conn.close() + logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) } -- cgit v1.2.3 From b11058f42c1c9c66ea94d3732c2efbdb57cb42b6 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 23 Apr 2013 22:48:32 +0530 Subject: Ensure that maven package adds yarn jars as part of shaded jar for hadoop2-yarn profile --- repl-bin/pom.xml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index f9d84fd3c4..b66d193b5d 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -189,17 +189,17 @@ org.apache.hadoop hadoop-client - provided + runtime org.apache.hadoop hadoop-yarn-api - provided + runtime org.apache.hadoop hadoop-yarn-common - provided + runtime -- cgit v1.2.3 From 8faf5c51c3ea0b3ad83418552b50db596fefc558 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 24 Apr 2013 02:31:57 +0530 Subject: Patch from Thomas Graves to improve the YARN Client, and move to more production ready hadoop yarn branch --- core/pom.xml | 5 ++ .../scala/spark/deploy/yarn/Client.scala | 72 +++------------------- pom.xml | 9 ++- project/SparkBuild.scala | 5 +- repl-bin/pom.xml | 5 ++ 5 files changed, 30 insertions(+), 66 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 9baa447662..7f65ce5c00 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -297,6 +297,11 @@ hadoop-yarn-common provided + + org.apache.hadoop + hadoop-yarn-client + provided + diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala index c007dae98c..7a881e26df 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala @@ -7,6 +7,7 @@ import org.apache.hadoop.net.NetUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.client.YarnClientImpl import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import scala.collection.mutable.HashMap @@ -16,19 +17,19 @@ import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import spark.deploy.SparkHadoopUtil -class Client(conf: Configuration, args: ClientArguments) extends Logging { +class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { def this(args: ClientArguments) = this(new Configuration(), args) - var applicationsManager: ClientRMProtocol = null var rpc: YarnRPC = YarnRPC.create(conf) val yarnConf: YarnConfiguration = new YarnConfiguration(conf) def run() { - connectToASM() + init(yarnConf) + start() logClusterResourceDetails() - val newApp = getNewApplication() + val newApp = super.getNewApplication() val appId = newApp.getApplicationId() verifyClusterResources(newApp) @@ -47,64 +48,17 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging { System.exit(0) } - - def connectToASM() { - val rmAddress: InetSocketAddress = NetUtils.createSocketAddr( - yarnConf.get(YarnConfiguration.RM_ADDRESS, YarnConfiguration.DEFAULT_RM_ADDRESS) - ) - logInfo("Connecting to ResourceManager at" + rmAddress) - applicationsManager = rpc.getProxy(classOf[ClientRMProtocol], rmAddress, conf) - .asInstanceOf[ClientRMProtocol] - } def logClusterResourceDetails() { - val clusterMetrics: YarnClusterMetrics = getYarnClusterMetrics + val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers) -/* - val clusterNodeReports: List[NodeReport] = getNodeReports - logDebug("Got Cluster node info from ASM") - for (node <- clusterNodeReports) { - logDebug("Got node report from ASM for, nodeId=" + node.getNodeId + ", nodeAddress=" + node.getHttpAddress + - ", nodeRackName=" + node.getRackName + ", nodeNumContainers=" + node.getNumContainers + ", nodeHealthStatus=" + node.getNodeHealthStatus) - } -*/ - - val queueInfo: QueueInfo = getQueueInfo(args.amQueue) + val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size + ", queueChildQueueCount=" + queueInfo.getChildQueues.size) } - def getYarnClusterMetrics: YarnClusterMetrics = { - val request: GetClusterMetricsRequest = Records.newRecord(classOf[GetClusterMetricsRequest]) - val response: GetClusterMetricsResponse = applicationsManager.getClusterMetrics(request) - return response.getClusterMetrics - } - - def getNodeReports: List[NodeReport] = { - val request: GetClusterNodesRequest = Records.newRecord(classOf[GetClusterNodesRequest]) - val response: GetClusterNodesResponse = applicationsManager.getClusterNodes(request) - return response.getNodeReports.toList - } - - def getQueueInfo(queueName: String): QueueInfo = { - val request: GetQueueInfoRequest = Records.newRecord(classOf[GetQueueInfoRequest]) - request.setQueueName(queueName) - request.setIncludeApplications(true) - request.setIncludeChildQueues(false) - request.setRecursive(false) - Records.newRecord(classOf[GetQueueInfoRequest]) - return applicationsManager.getQueueInfo(request).getQueueInfo - } - - def getNewApplication(): GetNewApplicationResponse = { - logInfo("Requesting new Application") - val request = Records.newRecord(classOf[GetNewApplicationRequest]) - val response = applicationsManager.getNewApplication(request) - logInfo("Got new ApplicationId: " + response.getApplicationId()) - return response - } def verifyClusterResources(app: GetNewApplicationResponse) = { val maxMem = app.getMaximumResourceCapability().getMemory() @@ -265,23 +219,15 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging { } def submitApp(appContext: ApplicationSubmissionContext) = { - // Create the request to send to the applications manager - val appRequest = Records.newRecord(classOf[SubmitApplicationRequest]) - .asInstanceOf[SubmitApplicationRequest] - appRequest.setApplicationSubmissionContext(appContext) // Submit the application to the applications manager logInfo("Submitting application to ASM") - applicationsManager.submitApplication(appRequest) + super.submitApplication(appContext) } def monitorApplication(appId: ApplicationId): Boolean = { while(true) { Thread.sleep(1000) - val reportRequest = Records.newRecord(classOf[GetApplicationReportRequest]) - .asInstanceOf[GetApplicationReportRequest] - reportRequest.setApplicationId(appId) - val reportResponse = applicationsManager.getApplicationReport(reportRequest) - val report = reportResponse.getApplicationReport() + val report = super.getApplicationReport(appId) logInfo("Application report from ASM: \n" + "\t application identifier: " + appId.toString() + "\n" + diff --git a/pom.xml b/pom.xml index ecbfaf9b47..0e95520d50 100644 --- a/pom.xml +++ b/pom.xml @@ -564,7 +564,9 @@ hadoop2-yarn 2 - 2.0.2-alpha + + 0.23.7 + @@ -599,6 +601,11 @@ hadoop-yarn-common ${yarn.version} + + org.apache.hadoop + hadoop-yarn-client + ${yarn.version} + org.apache.avro diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0a5b89d927..819e940403 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -20,7 +20,7 @@ object SparkBuild extends Build { //val HADOOP_YARN = false // For Hadoop 2 YARN support - val HADOOP_VERSION = "2.0.2-alpha" + val HADOOP_VERSION = "0.23.7" val HADOOP_MAJOR_VERSION = "2" val HADOOP_YARN = true @@ -156,7 +156,8 @@ object SparkBuild extends Build { Seq( "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION, "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION, - "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION + "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION, + "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION ) } else { Seq( diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index b66d193b5d..46f38c2772 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -201,6 +201,11 @@ hadoop-yarn-common runtime + + org.apache.hadoop + hadoop-yarn-client + runtime + -- cgit v1.2.3 From 5b85c715c8e4241f5e07237ed62729f1a7a800a8 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 24 Apr 2013 02:57:51 +0530 Subject: Revert back to 2.0.2-alpha : 0.23.7 has protocol changes which break against cloudera --- pom.xml | 4 ++-- project/SparkBuild.scala | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 0e95520d50..12e310a038 100644 --- a/pom.xml +++ b/pom.xml @@ -565,8 +565,8 @@ 2 - 0.23.7 - + + 2.0.2-alpha diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 819e940403..0a761f1c13 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -20,7 +20,8 @@ object SparkBuild extends Build { //val HADOOP_YARN = false // For Hadoop 2 YARN support - val HADOOP_VERSION = "0.23.7" + // val HADOOP_VERSION = "0.23.7" + val HADOOP_VERSION = "2.0.2-alpha" val HADOOP_MAJOR_VERSION = "2" val HADOOP_YARN = true -- cgit v1.2.3 From 31ce6c66d6f29302d0f0f2c70e494fad0ba71e4d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 23 Apr 2013 17:48:59 -0700 Subject: Added a BlockObjectWriter interface in block manager so ShuffleMapTask doesn't need to build up an array buffer for each shuffle bucket. --- core/src/main/scala/spark/SparkEnv.scala | 18 ++++++--- .../scala/spark/scheduler/ShuffleMapTask.scala | 30 +++++++++------ .../main/scala/spark/storage/BlockException.scala | 5 +++ .../main/scala/spark/storage/BlockManager.scala | 36 ++++++++++++++---- .../scala/spark/storage/BlockObjectWriter.scala | 27 ++++++++++++++ core/src/main/scala/spark/storage/DiskStore.scala | 43 +++++++++++++++++++--- .../main/scala/spark/storage/ThreadingTest.scala | 2 +- 7 files changed, 129 insertions(+), 32 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockException.scala create mode 100644 core/src/main/scala/spark/storage/BlockObjectWriter.scala diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 7157fd2688..c10bedb8f6 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -22,6 +22,7 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, + val shuffleSerializer: Serializer, val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, @@ -82,7 +83,7 @@ object SparkEnv extends Logging { } val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - + def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { logInfo("Registering " + name) @@ -96,18 +97,22 @@ object SparkEnv extends Logging { } } + val closureSerializer = instantiateClass[Serializer]( + "spark.closure.serializer", "spark.JavaSerializer") + + val shuffleSerializer = instantiateClass[Serializer]( + "spark.shuffle.serializer", "spark.JavaSerializer") + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new spark.storage.BlockManagerMasterActor(isLocal))) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) + val blockManager = new BlockManager( + executorId, actorSystem, blockManagerMaster, serializer, shuffleSerializer) val connectionManager = blockManager.connectionManager val broadcastManager = new BroadcastManager(isDriver) - val closureSerializer = instantiateClass[Serializer]( - "spark.closure.serializer", "spark.JavaSerializer") - val cacheManager = new CacheManager(blockManager) // Have to assign trackerActor after initialization as MapOutputTrackerActor @@ -144,6 +149,7 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, + shuffleSerializer, cacheManager, mapOutputTracker, shuffleFetcher, @@ -153,5 +159,5 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir) } - + } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 36d087a4d0..97b668cd58 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -122,27 +122,33 @@ private[spark] class ShuffleMapTask( val taskContext = new TaskContext(stageId, partition, attemptId) metrics = Some(taskContext.taskMetrics) try { - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + // Obtain all the block writers for shuffle blocks. + val blockManager = SparkEnv.get.blockManager + val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId => + val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId + blockManager.getBlockWriter(blockId) + } + + // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = dep.partitioner.getPartition(pair._1) - buckets(bucketId) += pair + buckets(bucketId).write(pair) } + // Close the bucket writers and get the sizes of each block. val compressedSizes = new Array[Byte](numOutputSplits) - - var totalBytes = 0l - - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = buckets(i).iterator - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + var i = 0 + var totalBytes = 0L + while (i < numOutputSplits) { + buckets(i).close() + val size = buckets(i).size() totalBytes += size compressedSizes(i) = MapOutputTracker.compressSize(size) + i += 1 } + + // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala new file mode 100644 index 0000000000..f275d476df --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockException.scala @@ -0,0 +1,5 @@ +package spark.storage + +private[spark] +case class BlockException(blockId: String, message: String) extends Exception(message) + diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e972..2f97bad916 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -24,16 +24,13 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer -private[spark] -case class BlockException(blockId: String, message: String, ex: Exception = null) -extends Exception(message) - private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, val serializer: Serializer, + val shuffleSerializer: Serializer, maxMemory: Long) extends Logging { @@ -78,7 +75,7 @@ class BlockManager( private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: BlockStore = + private[storage] val diskStore: DiskStore = new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) val connectionManager = new ConnectionManager(0) @@ -126,8 +123,17 @@ class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + serializer: Serializer, shuffleSerializer: Serializer) = { + this(execId, actorSystem, master, serializer, shuffleSerializer, + BlockManager.getMaxMemoryFromSystemProperties) + } + + /** + * Construct a BlockManager with a memory limit set based on system properties. + */ + def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, + serializer: Serializer, maxMemory: Long) = { + this(execId, actorSystem, master, serializer, serializer, maxMemory) } /** @@ -485,6 +491,21 @@ class BlockManager( put(blockId, elements, level, tellMaster) } + /** + * A short circuited method to get a block writer that can write data directly to disk. + * This is currently used for writing shuffle files out. + */ + def getBlockWriter(blockId: String): BlockObjectWriter = { + val writer = diskStore.getBlockWriter(blockId) + writer.registerCloseEventHandler(() => { + // TODO(rxin): This doesn't handle error cases. + val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) + blockInfo.put(blockId, myInfo) + myInfo.markReady(writer.size()) + }) + writer + } + /** * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ @@ -574,7 +595,6 @@ class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required if (level.replication > 1) { val remoteStartTime = System.currentTimeMillis diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala new file mode 100644 index 0000000000..657a7e9143 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala @@ -0,0 +1,27 @@ +package spark.storage + +import java.nio.ByteBuffer + + +abstract class BlockObjectWriter(val blockId: String) { + + // TODO(rxin): What if there is an exception when the block is being written out? + + var closeEventHandler: () => Unit = _ + + def registerCloseEventHandler(handler: () => Unit) { + closeEventHandler = handler + } + + def write(value: Any) + + def writeAll(value: Iterator[Any]) { + value.foreach(write) + } + + def close() { + closeEventHandler() + } + + def size(): Long +} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad..493936fdbe 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -1,7 +1,7 @@ package spark.storage import java.nio.ByteBuffer -import java.io.{File, FileOutputStream, RandomAccessFile} +import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} import java.nio.channels.FileChannel.MapMode import java.util.{Random, Date} import java.text.SimpleDateFormat @@ -10,9 +10,9 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import scala.collection.mutable.ArrayBuffer +import spark.Utils import spark.executor.ExecutorExitCode -import spark.Utils /** * Stores BlockManager blocks on disk. @@ -20,6 +20,33 @@ import spark.Utils private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { + class DiskBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) { + + private val f: File = createFile(blockId /*, allowAppendExisting */) + private val bs: OutputStream = blockManager.wrapForCompression(blockId, + new FastBufferedOutputStream(new FileOutputStream(f))) + private val objOut = blockManager.shuffleSerializer.newInstance().serializeStream(bs) + + private var _size: Long = -1L + + override def write(value: Any) { + objOut.writeObject(value) + } + + override def close() { + objOut.close() + bs.close() + super.close() + } + + override def size(): Long = { + if (_size < 0) { + _size = f.length() + } + _size + } + } + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt @@ -31,6 +58,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() + def getBlockWriter(blockId: String): BlockObjectWriter = { + new DiskBlockObjectWriter(blockId) + } + override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -65,8 +96,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) objOut.writeAll(values.iterator) objOut.close() val length = file.length() + + val timeTaken = System.currentTimeMillis - startTime logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime))) + blockId, Utils.memoryBytesToString(length), timeTaken)) if (returnValues) { // Return a byte buffer for the contents of the file @@ -106,9 +139,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getFile(blockId).exists() } - private def createFile(blockId: String): File = { + private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { val file = getFile(blockId) - if (file.exists()) { + if (!allowAppendExisting && file.exists()) { throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 5c406e68cb..3875e7459e 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -78,7 +78,7 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) + "", actorSystem, blockManagerMaster, serializer, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) -- cgit v1.2.3 From adcda84f9646f12e6d5fb4f1e5e3a1b0a98b7c9f Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 24 Apr 2013 08:57:25 +0530 Subject: Pull latest SparkBuild.scala from master and merge conflicts --- project/SparkBuild.scala | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0a761f1c13..0c2598ab35 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,15 +43,22 @@ object SparkBuild extends Build { def sharedSettings = Defaults.defaultSettings ++ Seq( organization := "org.spark-project", - version := "0.7.1-SNAPSHOT", - scalaVersion := "2.9.2", + version := "0.8.0-SNAPSHOT", + scalaVersion := "2.9.3", scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"), unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), - // shared between both core and streaming. + // Fork new JVMs for tests and set Java options for those + fork := true, + javaOptions += "-Xmx1g", + + // Only allow one test at a time, even across projects, since they run in the same JVM + concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), + + // Shared between both core and streaming. resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), // For Sonatype publishing @@ -100,13 +107,12 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "io.netty" % "netty" % "3.5.3.Final", - "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", - "org.scalatest" %% "scalatest" % "1.8" % "test", - "org.scalacheck" %% "scalacheck" % "1.9" % "test", - "com.novocode" % "junit-interface" % "0.8" % "test", + "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", + "org.scalatest" %% "scalatest" % "1.9.1" % "test", + "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", + "com.novocode" % "junit-interface" % "0.9" % "test", "org.easymock" % "easymock" % "3.1" % "test" ), - parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }, @@ -137,8 +143,8 @@ object SparkBuild extends Build { "log4j" % "log4j" % "1.2.16", "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, - "com.ning" % "compress-lzf" % "0.8.4", "commons-daemon" % "commons-daemon" % "1.0.10", + "com.ning" % "compress-lzf" % "0.8.4", "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", @@ -149,25 +155,26 @@ object SparkBuild extends Build { "colt" % "colt" % "1.2.0", "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", - "cc.spray" %% "spray-json" % "1.1.1", + "cc.spray" % "spray-json_2.9.2" % "1.1.1", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ ( if (HADOOP_MAJOR_VERSION == "2") { if (HADOOP_YARN) { Seq( - "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION, - "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION, - "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION, - "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION + // Exclude rule required for all ? + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) ) } else { Seq( - "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION, - "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION + "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) ) } } else { - Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION) + Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) ) }), unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") { @@ -189,7 +196,7 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.8") + libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From d09db1c051d255157f38f400fe9301fa438c5f41 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 24 Apr 2013 09:15:29 +0530 Subject: concurrentRestrictions fails for this PR - but works for master, probably some version change --- project/SparkBuild.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0c2598ab35..947ac47f6b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -55,9 +55,6 @@ object SparkBuild extends Build { fork := true, javaOptions += "-Xmx1g", - // Only allow one test at a time, even across projects, since they run in the same JVM - concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), - // Shared between both core and streaming. resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), -- cgit v1.2.3 From 3b594a4e3b94de49a09dc679a30d857e3f41df69 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 24 Apr 2013 10:18:25 +0530 Subject: Do not add signature files - results in validation errors when using assembled file --- project/SparkBuild.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0c2598ab35..b3f410bfa6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -214,6 +214,7 @@ object SparkBuild extends Build { def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard + case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } -- cgit v1.2.3 From aa618ed2a2df209da3f93a025928366959c37d04 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Apr 2013 14:52:49 -0700 Subject: Allow changing the serializer on a per shuffle basis. --- .../scala/spark/BlockStoreShuffleFetcher.scala | 13 ++++-- core/src/main/scala/spark/Dependency.scala | 4 +- core/src/main/scala/spark/PairRDDFunctions.scala | 11 ++--- core/src/main/scala/spark/ShuffleFetcher.scala | 7 ++- core/src/main/scala/spark/SparkEnv.scala | 20 ++++----- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 +++--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 11 +++-- core/src/main/scala/spark/rdd/SubtractedRDD.scala | 18 +++++--- .../scala/spark/scheduler/ShuffleMapTask.scala | 8 ++-- .../main/scala/spark/serializer/Serializer.scala | 50 +++++++++++++++++++++- .../main/scala/spark/storage/BlockManager.scala | 42 +++++++++--------- .../scala/spark/storage/BlockManagerWorker.scala | 18 ++++---- core/src/main/scala/spark/storage/DiskStore.scala | 18 ++++---- .../main/scala/spark/storage/ThreadingTest.scala | 2 +- 14 files changed, 153 insertions(+), 82 deletions(-) diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index c27ed36406..2156efbd45 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -1,14 +1,19 @@ package spark -import executor.{ShuffleReadMetrics, TaskMetrics} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import spark.executor.{ShuffleReadMetrics, TaskMetrics} +import spark.serializer.Serializer import spark.storage.{DelegateBlockFetchTracker, BlockManagerId} -import util.{CompletionIterator, TimedIterator} +import spark.util.{CompletionIterator, TimedIterator} + private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = { + + override def fetch[K, V]( + shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = { + logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager @@ -48,7 +53,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker itr.setDelegate(blockFetcherItr) CompletionIterator[(K,V), Iterator[(K,V)]](itr, { diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 5eea907322..2af44aa383 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * @param shuffleId the shuffle id * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output + * @param serializerClass class name of the serializer to use */ class ShuffleDependency[K, V]( @transient rdd: RDD[(K, V)], - val partitioner: Partitioner) + val partitioner: Partitioner, + val serializerClass: String = null) extends Dependency(rdd) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07efba9e8d..1b9b9d21d8 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -52,7 +52,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, - mapSideCombine: Boolean = true): RDD[(K, C)] = { + mapSideCombine: Boolean = true, + serializerClass: String = null): RDD[(K, C)] = { if (getKeyClass().isArray) { if (mapSideCombine) { throw new SparkException("Cannot use map-side combining with array keys.") @@ -67,13 +68,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) - val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner) + val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass) partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V](self, partitioner) + val values = new ShuffledRDD[K, V](self, partitioner, serializerClass) values.mapPartitions(aggregator.combineValuesByKey(_), true) } } @@ -469,7 +470,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Return an RDD with the pairs from `this` whose keys are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ @@ -644,7 +645,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * Return an RDD with the keys of each tuple. */ def keys: RDD[K] = self.map(_._1) - + /** * Return an RDD with the values of each tuple. */ diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala index 442e9f0269..49addc0c10 100644 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/spark/ShuffleFetcher.scala @@ -1,13 +1,16 @@ package spark -import executor.TaskMetrics +import spark.executor.TaskMetrics +import spark.serializer.Serializer + private[spark] abstract class ShuffleFetcher { /** * Fetch the shuffle outputs for a given ShuffleDependency. * @return An iterator over the elements of the fetched shuffle outputs. */ - def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)] + def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, + serializer: Serializer = Serializer.default): Iterator[(K,V)] /** Stop the fetcher */ def stop() {} diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index c10bedb8f6..8a751fbd6e 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -3,13 +3,14 @@ package spark import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} import akka.remote.RemoteActorRefProvider -import serializer.Serializer import spark.broadcast.BroadcastManager import spark.storage.BlockManager import spark.storage.BlockManagerMaster import spark.network.ConnectionManager +import spark.serializer.Serializer import spark.util.AkkaUtils + /** * Holds all the runtime environment objects for a running Spark instance (either master or worker), * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently @@ -22,7 +23,6 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, - val shuffleSerializer: Serializer, val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, @@ -82,7 +82,11 @@ object SparkEnv extends Logging { Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") + val serializer = Serializer.setDefault( + System.getProperty("spark.serializer", "spark.JavaSerializer")) + + val closureSerializer = Serializer.get( + System.getProperty("spark.closure.serializer", "spark.JavaSerializer")) def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { @@ -97,17 +101,10 @@ object SparkEnv extends Logging { } } - val closureSerializer = instantiateClass[Serializer]( - "spark.closure.serializer", "spark.JavaSerializer") - - val shuffleSerializer = instantiateClass[Serializer]( - "spark.shuffle.serializer", "spark.JavaSerializer") - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new spark.storage.BlockManagerMasterActor(isLocal))) - val blockManager = new BlockManager( - executorId, actorSystem, blockManagerMaster, serializer, shuffleSerializer) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager @@ -149,7 +146,6 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, - shuffleSerializer, cacheManager, mapOutputTracker, shuffleFetcher, diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index a6235491ca..9e996e9958 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} +import spark.serializer.Serializer private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -54,7 +55,8 @@ private[spark] class CoGroupAggregator class CoGroupedRDD[K]( @transient var rdds: Seq[RDD[(K, _)]], part: Partitioner, - val mapSideCombine: Boolean = true) + val mapSideCombine: Boolean = true, + val serializerClass: String = null) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { private val aggr = new CoGroupAggregator @@ -68,9 +70,9 @@ class CoGroupedRDD[K]( logInfo("Adding shuffle dependency with " + rdd) if (mapSideCombine) { val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) + new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass) } else { - new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part) + new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass) } } } @@ -112,6 +114,7 @@ class CoGroupedRDD[K]( } } + val ser = Serializer.get(serializerClass) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent @@ -124,12 +127,12 @@ class CoGroupedRDD[K]( val fetcher = SparkEnv.get.shuffleFetcher if (mapSideCombine) { // With map side combine on, for each key, the shuffle fetcher returns a list of values. - fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach { + fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { case (key, values) => getSeq(key)(depNum) ++= values } } else { // With map side combine off, for each key the shuffle fetcher returns a single value. - fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach { + fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach { case (key, value) => getSeq(key)(depNum) += value } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 4e33b7dd5c..8175e23eff 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -2,6 +2,8 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} import spark.SparkContext._ +import spark.serializer.Serializer + private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx @@ -12,13 +14,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * The resulting RDD from a shuffle (e.g. repartitioning of data). * @param prev the parent RDD. * @param part the partitioner used to partition the RDD + * @param serializerClass class name of the serializer to use. * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( @transient prev: RDD[(K, V)], - part: Partitioner) - extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { + part: Partitioner, + serializerClass: String = null) + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) { override val partitioner = Some(part) @@ -28,6 +32,7 @@ class ShuffledRDD[K, V]( override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics) + SparkEnv.get.shuffleFetcher.fetch[K, V]( + shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass)) } } diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala index 481e03b349..f60c35c38e 100644 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala @@ -11,6 +11,7 @@ import spark.Partition import spark.SparkEnv import spark.ShuffleDependency import spark.OneToOneDependency +import spark.serializer.Serializer /** * An optimized version of cogroup for set difference/subtraction. @@ -31,7 +32,9 @@ import spark.OneToOneDependency private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( @transient var rdd1: RDD[(K, V)], @transient var rdd2: RDD[(K, W)], - part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { + part: Partitioner, + val serializerClass: String = null) + extends RDD[(K, V)](rdd1.context, Nil) { override def getDependencies: Seq[Dependency[_]] = { Seq(rdd1, rdd2).map { rdd => @@ -40,7 +43,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part) + new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass) } } } @@ -65,6 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] + val serializer = Serializer.get(serializerClass) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -77,12 +81,16 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM } } def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { for (t <- rdd.iterator(itsSplit, context)) op(t.asInstanceOf[(K, V)]) - case ShuffleCoGroupSplitDep(shuffleId) => - for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics)) + } + case ShuffleCoGroupSplitDep(shuffleId) => { + val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, + context.taskMetrics, serializer) + for (t <- iter) op(t.asInstanceOf[(K, V)]) + } } // the first dep is rdd1; add all values to the map integrate(partition.deps(0), t => getSeq(t._1) += t._2) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 97b668cd58..d9b26c9db9 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -13,9 +13,11 @@ import com.ning.compress.lzf.LZFInputStream import com.ning.compress.lzf.LZFOutputStream import spark._ -import executor.ShuffleWriteMetrics +import spark.executor.ShuffleWriteMetrics +import spark.serializer.Serializer import spark.storage._ -import util.{TimeStampedHashMap, MetadataCleaner} +import spark.util.{TimeStampedHashMap, MetadataCleaner} + private[spark] object ShuffleMapTask { @@ -126,7 +128,7 @@ private[spark] class ShuffleMapTask( val blockManager = SparkEnv.get.blockManager val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId => val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId - blockManager.getBlockWriter(blockId) + blockManager.getBlockWriter(blockId, Serializer.get(dep.serializerClass)) } // Write the map output to its associated buckets. diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala index aca86ab6f0..77b1a1a434 100644 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ b/core/src/main/scala/spark/serializer/Serializer.scala @@ -1,10 +1,14 @@ package spark.serializer -import java.nio.ByteBuffer import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap + import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + import spark.util.ByteBufferInputStream + /** * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are @@ -14,6 +18,48 @@ trait Serializer { def newInstance(): SerializerInstance } + +/** + * A singleton object that can be used to fetch serializer objects based on the serializer + * class name. If a previous instance of the serializer object has been created, the get + * method returns that instead of creating a new one. + */ +object Serializer { + + private val serializers = new ConcurrentHashMap[String, Serializer] + private var _default: Serializer = _ + + def default = _default + + def setDefault(clsName: String): Serializer = { + _default = get(clsName) + _default + } + + def get(clsName: String): Serializer = { + if (clsName == null) { + default + } else { + var serializer = serializers.get(clsName) + if (serializer != null) { + // If the serializer has been created previously, reuse that. + serializer + } else this.synchronized { + // Otherwise, create a new one. But make sure no other thread has attempted + // to create another new one at the same time. + serializer = serializers.get(clsName) + if (serializer == null) { + val clsLoader = Thread.currentThread.getContextClassLoader + serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + serializers.put(clsName, serializer) + } + serializer + } + } + } +} + + /** * An instance of a serializer, for use by one thread at a time. */ @@ -45,6 +91,7 @@ trait SerializerInstance { } } + /** * A stream for writing serialized objects. */ @@ -61,6 +108,7 @@ trait SerializationStream { } } + /** * A stream for reading serialized objects. */ diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 2f97bad916..9f7985e2e8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -29,8 +29,7 @@ class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, - val serializer: Serializer, - val shuffleSerializer: Serializer, + val defaultSerializer: Serializer, maxMemory: Long) extends Logging { @@ -123,17 +122,8 @@ class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, shuffleSerializer: Serializer) = { - this(execId, actorSystem, master, serializer, shuffleSerializer, - BlockManager.getMaxMemoryFromSystemProperties) - } - - /** - * Construct a BlockManager with a memory limit set based on system properties. - */ - def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, maxMemory: Long) = { - this(execId, actorSystem, master, serializer, serializer, maxMemory) + serializer: Serializer) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) } /** @@ -479,9 +469,10 @@ class BlockManager( * fashion as they're received. Expects a size in bytes to be provided for each block fetched, * so that we can control the maxMegabytesInFlight for the fetch. */ - def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) + def getMultiple( + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) : BlockFetcherIterator = { - return new BlockFetcherIterator(this, blocksByAddress) + return new BlockFetcherIterator(this, blocksByAddress, serializer) } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -495,8 +486,8 @@ class BlockManager( * A short circuited method to get a block writer that can write data directly to disk. * This is currently used for writing shuffle files out. */ - def getBlockWriter(blockId: String): BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId) + def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { + val writer = diskStore.getBlockWriter(blockId, serializer) writer.registerCloseEventHandler(() => { // TODO(rxin): This doesn't handle error cases. val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) @@ -850,7 +841,10 @@ class BlockManager( if (shouldCompress(blockId)) new LZFInputStream(s) else s } - def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = { + def dataSerialize( + blockId: String, + values: Iterator[Any], + serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() @@ -862,7 +856,10 @@ class BlockManager( * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = { + def dataDeserialize( + blockId: String, + bytes: ByteBuffer, + serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) serializer.newInstance().deserializeStream(stream).asIterator @@ -916,7 +913,8 @@ object BlockManager extends Logging { class BlockFetcherIterator( private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer ) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { import blockManager._ @@ -979,8 +977,8 @@ class BlockFetcherIterator( "Unexpected message " + blockMessage.getType + " received from " + cmId) } val blockId = blockMessage.getId - results.put(new FetchResult( - blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData))) + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) _remoteBytesRead += req.size logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d2985559c1..15225f93a6 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -19,7 +19,7 @@ import spark.network._ */ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { initLogging() - + blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { @@ -51,7 +51,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends logDebug("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) return None - } + } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) logDebug("Received [" + gB + "]") @@ -90,28 +90,26 @@ private[spark] object BlockManagerWorker extends Logging { private var blockManagerWorker: BlockManagerWorker = null private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 - + initLogging() - + def startBlockManagerWorker(manager: BlockManager) { blockManagerWorker = new BlockManagerWorker(manager) } - + def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer + val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) val resultMessage = connectionManager.sendMessageReliablySync( toConnManagerId, blockMessageArray.toBufferMessage) return (resultMessage != None) } - + def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer + val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) val responseMessage = connectionManager.sendMessageReliablySync( diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 493936fdbe..70ad887c3b 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -1,17 +1,18 @@ package spark.storage -import java.nio.ByteBuffer import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} +import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode import java.util.{Random, Date} import java.text.SimpleDateFormat -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - import scala.collection.mutable.ArrayBuffer +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + import spark.Utils import spark.executor.ExecutorExitCode +import spark.serializer.Serializer /** @@ -20,12 +21,13 @@ import spark.executor.ExecutorExitCode private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { - class DiskBlockObjectWriter(blockId: String) extends BlockObjectWriter(blockId) { + class DiskBlockObjectWriter(blockId: String, serializer: Serializer) + extends BlockObjectWriter(blockId) { private val f: File = createFile(blockId /*, allowAppendExisting */) private val bs: OutputStream = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(new FileOutputStream(f))) - private val objOut = blockManager.shuffleSerializer.newInstance().serializeStream(bs) + private val objOut = serializer.newInstance().serializeStream(bs) private var _size: Long = -1L @@ -58,8 +60,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() - def getBlockWriter(blockId: String): BlockObjectWriter = { - new DiskBlockObjectWriter(blockId) + def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { + new DiskBlockObjectWriter(blockId, serializer) } override def getSize(blockId: String): Long = { @@ -92,7 +94,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = createFile(blockId) val fileOut = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.serializer.newInstance().serializeStream(fileOut) + val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) objOut.writeAll(values.iterator) objOut.close() val length = file.length() diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 3875e7459e..5c406e68cb 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -78,7 +78,7 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, serializer, 1024 * 1024) + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) -- cgit v1.2.3 From ba6ffa6a5f39765e1652735d1b16b54c2fc78674 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Apr 2013 17:38:07 -0700 Subject: Allow the specification of a shuffle serializer in the read path (for local block reads). --- .../scala/spark/scheduler/ShuffleMapTask.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 29 +++++++++++----------- core/src/main/scala/spark/storage/DiskStore.scala | 8 ++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index d9b26c9db9..826f14d658 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -128,7 +128,7 @@ private[spark] class ShuffleMapTask( val blockManager = SparkEnv.get.blockManager val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId => val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId - blockManager.getBlockWriter(blockId, Serializer.get(dep.serializerClass)) + blockManager.getDiskBlockWriter(blockId, Serializer.get(dep.serializerClass)) } // Write the map output to its associated buckets. diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 9f7985e2e8..fa02dd54b8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -268,23 +268,24 @@ class BlockManager( return locations } + /** + * A short-circuited method to get blocks directly from disk. This is used for getting + * shuffle blocks. It is safe to do so without a lock on block info since disk store + * never deletes (recent) items. + */ + def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + diskStore.getValues(blockId, serializer) match { + case Some(iterator) => Some(iterator) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + /** * Get block from local block manager. */ def getLocal(blockId: String): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) - - // As an optimization for map output fetches, if the block is for a shuffle, return it - // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (blockId.startsWith("shuffle_")) { - return diskStore.getValues(blockId) match { - case Some(iterator) => - Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { @@ -486,7 +487,7 @@ class BlockManager( * A short circuited method to get a block writer that can write data directly to disk. * This is currently used for writing shuffle files out. */ - def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { + def getDiskBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { val writer = diskStore.getBlockWriter(blockId, serializer) writer.registerCloseEventHandler(() => { // TODO(rxin): This doesn't handle error cases. @@ -1042,7 +1043,7 @@ class BlockFetcherIterator( // any memory that might exceed our maxBytesInFlight startTime = System.currentTimeMillis for (id <- localBlockIds) { - getLocal(id) match { + getLocalFromDisk(id, serializer) match { case Some(iter) => { results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight logDebug("Got local block " + id) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 70ad887c3b..7f512b162a 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -127,6 +127,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } + /** + * A version of getValues that allows a custom serializer. This is used as part of the + * shuffle short-circuit code. + */ + def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) + } + override def remove(blockId: String): Boolean = { val file = getFile(blockId) if (file.exists()) { -- cgit v1.2.3 From 01d9ba503878d4191eaa8080e86c631d3c705cce Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 25 Apr 2013 00:11:27 -0700 Subject: Add back line removed during YARN merge --- project/SparkBuild.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b3f410bfa6..44758ad87e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -48,6 +48,7 @@ object SparkBuild extends Build { scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"), unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), -- cgit v1.2.3 From eef9ea1993270d5f07e52e807e8d149e54079aad Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 24 Apr 2013 15:08:20 -0700 Subject: Update unit test memory to 2 GB --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 44758ad87e..f32c47e71f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -54,7 +54,7 @@ object SparkBuild extends Build { // Fork new JVMs for tests and set Java options for those fork := true, - javaOptions += "-Xmx1g", + javaOptions += "-Xmx2g", // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), -- cgit v1.2.3 From 6e6b5204ea015fc7cc2c3e16e0032be3074413be Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 24 Apr 2013 18:53:12 -0700 Subject: Create an empty directory when checkpointing a 0-partition RDD (fixes a test failure on Hadoop 2.0) --- core/src/main/scala/spark/RDDCheckpointData.scala | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index d00092e984..57e0405fb4 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -1,6 +1,7 @@ package spark import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration import rdd.{CheckpointRDD, CoalescedRDD} import scheduler.{ResultTask, ShuffleMapTask} @@ -62,14 +63,20 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) } } + // Create the output path for the checkpoint + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) + val fs = path.getFileSystem(new Configuration()) + if (!fs.mkdirs(path)) { + throw new SparkException("Failed to create checkpoint path " + path) + } + // Save to file, and reload it as an RDD - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) - val newRDD = new CheckpointRDD[T](rdd.context, path) + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) + val newRDD = new CheckpointRDD[T](rdd.context, path.toString) // Change the dependencies and partitions of the RDD RDDCheckpointData.synchronized { - cpFile = Some(path) + cpFile = Some(path.toString) cpRDD = Some(newRDD) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed -- cgit v1.2.3 From a72134a6ac04e2e49679bbd5ba1266daf909bec8 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 25 Apr 2013 10:39:28 -0700 Subject: SPARK-739 Have quickstart standlone job use README --- docs/quick-start.md | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/docs/quick-start.md b/docs/quick-start.md index 2d961b29cb..335643536a 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -113,8 +113,8 @@ import SparkContext._ object SimpleJob { def main(args: Array[String]) { - val logFile = "/var/log/syslog" // Should be some file on your system - val sc = new SparkContext("local", "Simple Job", "$YOUR_SPARK_HOME", + val logFile = "$YOUR_SPARK_HOME/README.md" // Should be some file on your system + val sc = new SparkContext("local", "Simple Job", "YOUR_SPARK_HOME", List("target/scala-{{site.SCALA_VERSION}}/simple-project_{{site.SCALA_VERSION}}-1.0.jar")) val logData = sc.textFile(logFile, 2).cache() val numAs = logData.filter(line => line.contains("a")).count() @@ -124,7 +124,7 @@ object SimpleJob { } {% endhighlight %} -This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the job. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the job, the directory where Spark is installed, and a name for the jar file containing the job's sources. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes. +This job simply counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the job. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the job, the directory where Spark is installed, and a name for the jar file containing the job's sources. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes. This file depends on the Spark API, so we'll also include an sbt configuration file, `simple.sbt` which explains that Spark is a dependency. This file also adds two repositories which host Spark dependencies: @@ -156,7 +156,7 @@ $ find . $ sbt package $ sbt run ... -Lines with a: 8422, Lines with b: 1836 +Lines with a: 46, Lines with b: 23 {% endhighlight %} This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. @@ -173,7 +173,7 @@ import spark.api.java.function.Function; public class SimpleJob { public static void main(String[] args) { - String logFile = "/var/log/syslog"; // Should be some file on your system + String logFile = "$YOUR_SPARK_HOME/README.md"; // Should be some file on your system JavaSparkContext sc = new JavaSparkContext("local", "Simple Job", "$YOUR_SPARK_HOME", new String[]{"target/simple-project-1.0.jar"}); JavaRDD logData = sc.textFile(logFile).cache(); @@ -191,7 +191,7 @@ public class SimpleJob { } {% endhighlight %} -This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Note that like in the Scala example, we initialize a SparkContext, though we use the special `JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by `JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes that extend `spark.api.java.function.Function`. The [Java programming guide](java-programming-guide.html) describes these differences in more detail. +This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. As with the Scala example, we initialize a SparkContext, though we use the special `JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by `JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes that extend `spark.api.java.function.Function`. The [Java programming guide](java-programming-guide.html) describes these differences in more detail. To build the job, we also write a Maven `pom.xml` file that lists Spark as a dependency. Note that Spark artifacts are tagged with a Scala version. @@ -239,7 +239,7 @@ Now, we can execute the job using Maven: $ mvn package $ mvn exec:java -Dexec.mainClass="SimpleJob" ... -Lines with a: 8422, Lines with b: 1836 +Lines with a: 46, Lines with b: 23 {% endhighlight %} This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. @@ -253,7 +253,7 @@ As an example, we'll create a simple Spark job, `SimpleJob.py`: """SimpleJob.py""" from pyspark import SparkContext -logFile = "/var/log/syslog" # Should be some file on your system +logFile = "$YOUR_SPARK_HOME/README.md" # Should be some file on your system sc = SparkContext("local", "Simple job") logData = sc.textFile(logFile).cache() @@ -265,7 +265,8 @@ print "Lines with a: %i, lines with b: %i" % (numAs, numBs) This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. -Like in the Scala and Java examples, we use a SparkContext to create RDDs. +Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. +As with the Scala and Java examples, we use a SparkContext to create RDDs. We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference. For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide.html). `SimpleJob` is simple enough that we do not need to specify any code dependencies. @@ -276,7 +277,7 @@ We can run this job using the `pyspark` script: $ cd $SPARK_HOME $ ./pyspark SimpleJob.py ... -Lines with a: 8422, Lines with b: 1836 +Lines with a: 46, Lines with b: 23 {% endhighlight python %} This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. -- cgit v1.2.3 From 1b169f190c5c5210d088faced86dee1007295ac8 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 25 Apr 2013 19:52:12 -0700 Subject: Exclude old versions of Netty, which had a different Maven organization --- project/SparkBuild.scala | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f32c47e71f..7bd6c4c235 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -129,6 +129,9 @@ object SparkBuild extends Build { val slf4jVersion = "1.6.1" + val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson") + val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + def coreSettings = sharedSettings ++ Seq( name := "spark-core", resolvers ++= Seq( @@ -149,33 +152,33 @@ object SparkBuild extends Build { "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", - "com.typesafe.akka" % "akka-actor" % "2.0.3", - "com.typesafe.akka" % "akka-remote" % "2.0.3", - "com.typesafe.akka" % "akka-slf4j" % "2.0.3", + "com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty), + "com.typesafe.akka" % "akka-remote" % "2.0.3" excludeAll(excludeNetty), + "com.typesafe.akka" % "akka-slf4j" % "2.0.3" excludeAll(excludeNetty), "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", - "cc.spray" % "spray-can" % "1.0-M2.1", - "cc.spray" % "spray-server" % "1.0-M2.1", - "cc.spray" % "spray-json_2.9.2" % "1.1.1", + "cc.spray" % "spray-can" % "1.0-M2.1" excludeAll(excludeNetty), + "cc.spray" % "spray-server" % "1.0-M2.1" excludeAll(excludeNetty), + "cc.spray" % "spray-json_2.9.2" % "1.1.1" excludeAll(excludeNetty), "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ ( if (HADOOP_MAJOR_VERSION == "2") { if (HADOOP_YARN) { Seq( // Exclude rule required for all ? - "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), - "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), - "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), - "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty), + "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty), + "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty), + "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty) ) } else { Seq( - "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), - "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) + "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty), + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty) ) } } else { - Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) ) + Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty) ) }), unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") { @@ -205,10 +208,10 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( - "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty), "com.github.sgroschupf" % "zkclient" % "0.1", - "org.twitter4j" % "twitter4j-stream" % "3.0.3", - "com.typesafe.akka" % "akka-zeromq" % "2.0.3" + "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), + "com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty) ) ) ++ assemblySettings ++ extraAssemblySettings -- cgit v1.2.3 From c9c4954d994c5ba824e71c1c5cd8d5de531caf78 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 26 Apr 2013 16:57:46 -0700 Subject: Add an interface to zip iterators of multiple RDDs The current code supports 2, 3 or 4 arguments but can be extended to more arguments if required. --- core/src/main/scala/spark/RDD.scala | 22 ++++ .../scala/spark/rdd/MapZippedPartitionsRDD.scala | 118 +++++++++++++++++++++ .../scala/spark/MapZippedPartitionsSuite.scala | 34 ++++++ 3 files changed, 174 insertions(+) create mode 100644 core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala create mode 100644 core/src/test/scala/spark/MapZippedPartitionsSuite.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index ccd9d0364a..8e7e1457c1 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -35,6 +35,9 @@ import spark.rdd.ShuffledRDD import spark.rdd.SubtractedRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD +import spark.rdd.MapZippedPartitionsRDD2 +import spark.rdd.MapZippedPartitionsRDD3 +import spark.rdd.MapZippedPartitionsRDD4 import spark.storage.StorageLevel import SparkContext._ @@ -436,6 +439,25 @@ abstract class RDD[T: ClassManifest]( */ def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + def zipAndMapPartitions[B: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B]) => Iterator[V], + rdd2: RDD[B]) = + new MapZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + + def zipAndMapPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V], + rdd2: RDD[B], + rdd3: RDD[C]) = + new MapZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + + def zipAndMapPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + rdd2: RDD[B], + rdd3: RDD[C], + rdd4: RDD[D]) = + new MapZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + + // Actions (launch a job to return a value to the user program) /** diff --git a/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala new file mode 100644 index 0000000000..6653b3b444 --- /dev/null +++ b/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala @@ -0,0 +1,118 @@ +package spark.rdd + +import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import java.io.{ObjectOutputStream, IOException} + +private[spark] class MapZippedPartition( + idx: Int, + @transient rdds: Seq[RDD[_]] + ) extends Partition { + + override val index: Int = idx + var partitionValues = rdds.map(rdd => rdd.partitions(idx)) + def partitions = partitionValues + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + partitionValues = rdds.map(rdd => rdd.partitions(idx)) + oos.defaultWriteObject() + } +} + +abstract class MapZippedPartitionsBaseRDD[V: ClassManifest]( + sc: SparkContext, + var rdds: Seq[RDD[_]]) + extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + + override def getPartitions: Array[Partition] = { + val sizes = rdds.map(x => x.partitions.size) + if (!sizes.forall(x => x == sizes(0))) { + throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + } + val array = new Array[Partition](sizes(0)) + for (i <- 0 until sizes(0)) { + array(i) = new MapZippedPartition(i, rdds) + } + array + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + val splits = s.asInstanceOf[MapZippedPartition].partitions + val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2)) + preferredLocations.reduce((x, y) => x.intersect(y)) + } + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } +} + +class MapZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B]) + extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[MapZippedPartition].partitions + f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} + +class MapZippedPartitionsRDD3[A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C]) + extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[MapZippedPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + } +} + +class MapZippedPartitionsRDD4[A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C], + var rdd4: RDD[D]) + extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[MapZippedPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context), + rdd4.iterator(partitions(3), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + rdd4 = null + } +} diff --git a/core/src/test/scala/spark/MapZippedPartitionsSuite.scala b/core/src/test/scala/spark/MapZippedPartitionsSuite.scala new file mode 100644 index 0000000000..f65a646416 --- /dev/null +++ b/core/src/test/scala/spark/MapZippedPartitionsSuite.scala @@ -0,0 +1,34 @@ +package spark + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import SparkContext._ + + +object MapZippedPartitionsSuite { + def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { + Iterator(i.toArray.size, s.toArray.size, d.toArray.size) + } +} + +class MapZippedPartitionsSuite extends FunSuite with LocalSparkContext { + test("print sizes") { + sc = new SparkContext("local", "test") + val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) + val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) + val data3 = sc.makeRDD(Array(1.0, 2.0), 2) + + val zippedRDD = data1.zipAndMapPartitions(MapZippedPartitionsSuite.procZippedData, data2, data3) + + val obtainedSizes = zippedRDD.collect() + val expectedSizes = Array(2, 3, 1, 2, 3, 1) + assert(obtainedSizes.size == 6) + assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) + } +} -- cgit v1.2.3 From 0cc6642b7c6fbb4167956b668603f2ea6fb5ac8e Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 28 Apr 2013 05:11:03 -0700 Subject: Rename to zipPartitions and style changes --- core/src/main/scala/spark/RDD.scala | 24 +++++++++++----------- .../scala/spark/rdd/MapZippedPartitionsRDD.scala | 22 +++++++++++--------- .../scala/spark/MapZippedPartitionsSuite.scala | 2 +- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 8e7e1457c1..bded55238f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -439,22 +439,22 @@ abstract class RDD[T: ClassManifest]( */ def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) - def zipAndMapPartitions[B: ClassManifest, V: ClassManifest]( - f: (Iterator[T], Iterator[B]) => Iterator[V], - rdd2: RDD[B]) = + def zipPartitions[B: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B]) => Iterator[V], + rdd2: RDD[B]) = new MapZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) - def zipAndMapPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]( - f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V], - rdd2: RDD[B], - rdd3: RDD[C]) = + def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V], + rdd2: RDD[B], + rdd3: RDD[C]) = new MapZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) - def zipAndMapPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]( - f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], - rdd2: RDD[B], - rdd3: RDD[C], - rdd4: RDD[D]) = + def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]( + f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + rdd2: RDD[B], + rdd3: RDD[C], + rdd4: RDD[D]) = new MapZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) diff --git a/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala index 6653b3b444..3520fd24b0 100644 --- a/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala @@ -4,13 +4,13 @@ import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} private[spark] class MapZippedPartition( - idx: Int, - @transient rdds: Seq[RDD[_]] - ) extends Partition { + idx: Int, + @transient rdds: Seq[RDD[_]]) + extends Partition { override val index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues + def partitions = partitionValues @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { @@ -68,7 +68,8 @@ class MapZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManife } } -class MapZippedPartitionsRDD3[A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( +class MapZippedPartitionsRDD3 + [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], @@ -78,8 +79,8 @@ class MapZippedPartitionsRDD3[A: ClassManifest, B: ClassManifest, C: ClassManife override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[MapZippedPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context)) } @@ -91,7 +92,8 @@ class MapZippedPartitionsRDD3[A: ClassManifest, B: ClassManifest, C: ClassManife } } -class MapZippedPartitionsRDD4[A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( +class MapZippedPartitionsRDD4 + [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], @@ -102,8 +104,8 @@ class MapZippedPartitionsRDD4[A: ClassManifest, B: ClassManifest, C: ClassManife override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[MapZippedPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context), rdd4.iterator(partitions(3), context)) } diff --git a/core/src/test/scala/spark/MapZippedPartitionsSuite.scala b/core/src/test/scala/spark/MapZippedPartitionsSuite.scala index f65a646416..834b517cbc 100644 --- a/core/src/test/scala/spark/MapZippedPartitionsSuite.scala +++ b/core/src/test/scala/spark/MapZippedPartitionsSuite.scala @@ -24,7 +24,7 @@ class MapZippedPartitionsSuite extends FunSuite with LocalSparkContext { val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) val data3 = sc.makeRDD(Array(1.0, 2.0), 2) - val zippedRDD = data1.zipAndMapPartitions(MapZippedPartitionsSuite.procZippedData, data2, data3) + val zippedRDD = data1.zipPartitions(MapZippedPartitionsSuite.procZippedData, data2, data3) val obtainedSizes = zippedRDD.collect() val expectedSizes = Array(2, 3, 1, 2, 3, 1) -- cgit v1.2.3 From afee9024430ef79cc0840a5e5788b60c8c53f9d2 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 28 Apr 2013 22:26:45 +0530 Subject: Attempt to fix streaming test failures after yarn branch merge --- bagel/src/test/scala/bagel/BagelSuite.scala | 1 + core/src/test/scala/spark/LocalSparkContext.scala | 3 ++- repl/src/test/scala/spark/repl/ReplSuite.scala | 1 + .../main/scala/spark/streaming/Checkpoint.scala | 30 +++++++++++++++++----- .../spark/streaming/util/MasterFailureTest.scala | 8 +++++- .../spark/streaming/BasicOperationsSuite.scala | 1 + .../scala/spark/streaming/CheckpointSuite.scala | 4 ++- .../test/scala/spark/streaming/FailureSuite.scala | 2 ++ .../scala/spark/streaming/InputStreamsSuite.scala | 1 + .../spark/streaming/WindowOperationsSuite.scala | 1 + 10 files changed, 42 insertions(+), 10 deletions(-) diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 25db395c22..a09c978068 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -23,6 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } test("halting by voting") { diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala index ff00dd05dd..76d5258b02 100644 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -27,6 +27,7 @@ object LocalSparkContext { sc.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ @@ -38,4 +39,4 @@ object LocalSparkContext { } } -} \ No newline at end of file +} diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index 43559b96d3..1c64f9b98d 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -32,6 +32,7 @@ class ReplSuite extends FunSuite { interp.sparkContext.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") return out.toString } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index e303e33e5e..7bd104b8d5 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -38,28 +38,43 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) private[streaming] class CheckpointWriter(checkpointDir: String) extends Logging { val file = new Path(checkpointDir, "graph") + // The file to which we actually write - and then "move" to file. + private val writeFile = new Path(file.getParent, file.getName + ".next") + private val bakFile = new Path(file.getParent, file.getName + ".bk") + + @volatile private var stopped = false + val conf = new Configuration() var fs = file.getFileSystem(conf) val maxAttempts = 3 val executor = Executors.newFixedThreadPool(1) + // Removed code which validates whether there is only one CheckpointWriter per path 'file' since + // I did not notice any errors - reintroduce it ? + class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() while (attempts < maxAttempts) { + if (stopped) { + logInfo("Already stopped, ignore checkpoint attempt for " + file) + return + } attempts += 1 try { logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) - } - val fos = fs.create(file) + // This is inherently thread unsafe .. so alleviating it by writing to '.new' and then doing moves : which should be pretty fast. + val fos = fs.create(writeFile) fos.write(bytes) fos.close() - fos.close() + if (fs.exists(file) && fs.rename(file, bakFile)) { + logDebug("Moved existing checkpoint file to " + bakFile) + } + // paranoia + fs.delete(file, false) + fs.rename(writeFile, file) + val finishTime = System.currentTimeMillis(); logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") @@ -84,6 +99,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } def stop() { + stopped = true executor.shutdown() } } diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index f673e5be15..e7a3f92bc0 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -74,6 +74,7 @@ object MasterFailureTest extends Logging { val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Long], state: Option[Long]) => { + logInfo("UpdateFunc .. state = " + state.getOrElse(0L) + ", values = " + values) Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) } st.flatMap(_.split(" ")) @@ -159,6 +160,7 @@ object MasterFailureTest extends Logging { // Setup the streaming computation with the given operation System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) ssc.checkpoint(checkpointDir.toString) val inputStream = ssc.textFileStream(testDir.toString) @@ -205,6 +207,7 @@ object MasterFailureTest extends Logging { // (iii) Its not timed out yet System.clearProperty("spark.streaming.clock") System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") ssc.start() val startTime = System.currentTimeMillis() while (!killed && !isLastOutputGenerated && !isTimedOut) { @@ -357,13 +360,16 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) // Write the data to a local file and then move it to the target test directory val localFile = new File(localTestDir, (i+1).toString) val hadoopFile = new Path(testDir, (i+1).toString) + val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString) FileUtils.writeStringToFile(localFile, input(i).toString + "\n") var tries = 0 var done = false while (!done && tries < maxTries) { tries += 1 try { - fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile) + fs.rename(tempHadoopFile, hadoopFile) done = true } catch { case ioe: IOException => { diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index cf2ed8b1d4..e7352deb81 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -15,6 +15,7 @@ class BasicOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } test("map") { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index cac86deeaf..607dea77ec 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -31,6 +31,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } var ssc: StreamingContext = null @@ -325,6 +326,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { ) ssc = new StreamingContext(checkpointDir) System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") ssc.start() val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) // the first element will be re-processed data of the last batch before restart @@ -350,4 +352,4 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] outputStream.output } -} \ No newline at end of file +} diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index a5fa7ab92d..4529e774e9 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -22,10 +22,12 @@ class FailureSuite extends FunSuite with BeforeAndAfter with Logging { val batchDuration = Milliseconds(1000) before { + logInfo("BEFORE ...") FileUtils.deleteDirectory(new File(directory)) } after { + logInfo("AFTER ...") FileUtils.deleteDirectory(new File(directory)) } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 67dca2ac31..0acb6db6f2 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -41,6 +41,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 1b66f3bda2..80d827706f 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -16,6 +16,7 @@ class WindowOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") } val largerSlideInput = Seq( -- cgit v1.2.3 From 7fa6978a1e8822cf377fbb1e8a8d23adc4ebe12e Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 28 Apr 2013 23:08:10 +0530 Subject: Allow CheckpointWriter pending tasks to finish --- streaming/src/main/scala/spark/streaming/Checkpoint.scala | 13 +++++++------ streaming/src/main/scala/spark/streaming/DStreamGraph.scala | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 7bd104b8d5..4bbad908d0 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -42,7 +42,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { private val writeFile = new Path(file.getParent, file.getName + ".next") private val bakFile = new Path(file.getParent, file.getName + ".bk") - @volatile private var stopped = false + private var stopped = false val conf = new Configuration() var fs = file.getFileSystem(conf) @@ -57,10 +57,6 @@ class CheckpointWriter(checkpointDir: String) extends Logging { var attempts = 0 val startTime = System.currentTimeMillis() while (attempts < maxAttempts) { - if (stopped) { - logInfo("Already stopped, ignore checkpoint attempt for " + file) - return - } attempts += 1 try { logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") @@ -99,8 +95,13 @@ class CheckpointWriter(checkpointDir: String) extends Logging { } def stop() { - stopped = true + synchronized { + if (stopped) return ; + stopped = true + } executor.shutdown() + val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) + logInfo("CheckpointWriter executor terminated ? " + terminated) } } diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index adb7f3a24d..3b331956f5 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -54,8 +54,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { throw new Exception("Batch duration already set as " + batchDuration + ". cannot set it again.") } + batchDuration = duration } - batchDuration = duration } def remember(duration: Duration) { -- cgit v1.2.3 From 9bd439502e371e1ff9d6184c7182bc414104e39e Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sun, 28 Apr 2013 23:09:08 +0530 Subject: Remove spurious commit --- project/SparkBuild.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7c004df6fb..7bd6c4c235 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -56,6 +56,9 @@ object SparkBuild extends Build { fork := true, javaOptions += "-Xmx2g", + // Only allow one test at a time, even across projects, since they run in the same JVM + concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), + // Shared between both core and streaming. resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), -- cgit v1.2.3 From 3a89a76b874298853cf47510ab33e863abf117d7 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 29 Apr 2013 00:04:12 +0530 Subject: Make log message more descriptive to aid in debugging --- streaming/src/main/scala/spark/streaming/Checkpoint.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 4bbad908d0..66e67cbfa1 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -100,8 +100,10 @@ class CheckpointWriter(checkpointDir: String) extends Logging { stopped = true } executor.shutdown() + val startTime = System.currentTimeMillis() val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) - logInfo("CheckpointWriter executor terminated ? " + terminated) + val endTime = System.currentTimeMillis() + logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.") } } -- cgit v1.2.3 From 430c531464a5372237c97394f8f4b6ec344394c0 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 29 Apr 2013 00:24:30 +0530 Subject: Remove debug statements --- streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala | 1 - streaming/src/test/scala/spark/streaming/FailureSuite.scala | 2 -- 2 files changed, 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index e7a3f92bc0..426a9b6f71 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -74,7 +74,6 @@ object MasterFailureTest extends Logging { val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Long], state: Option[Long]) => { - logInfo("UpdateFunc .. state = " + state.getOrElse(0L) + ", values = " + values) Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) } st.flatMap(_.split(" ")) diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 4529e774e9..a5fa7ab92d 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -22,12 +22,10 @@ class FailureSuite extends FunSuite with BeforeAndAfter with Logging { val batchDuration = Milliseconds(1000) before { - logInfo("BEFORE ...") FileUtils.deleteDirectory(new File(directory)) } after { - logInfo("AFTER ...") FileUtils.deleteDirectory(new File(directory)) } -- cgit v1.2.3 From 6e84635ab904ee2798f1d6acd3a8ed5e01563e54 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 28 Apr 2013 15:58:40 -0700 Subject: Rename classes from MapZipped* to Zipped* --- .../scala/spark/rdd/MapZippedPartitionsRDD.scala | 120 --------------------- .../main/scala/spark/rdd/ZippedPartitionsRDD.scala | 120 +++++++++++++++++++++ .../scala/spark/MapZippedPartitionsSuite.scala | 34 ------ .../test/scala/spark/ZippedPartitionsSuite.scala | 34 ++++++ 4 files changed, 154 insertions(+), 154 deletions(-) delete mode 100644 core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala create mode 100644 core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala delete mode 100644 core/src/test/scala/spark/MapZippedPartitionsSuite.scala create mode 100644 core/src/test/scala/spark/ZippedPartitionsSuite.scala diff --git a/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala deleted file mode 100644 index 3520fd24b0..0000000000 --- a/core/src/main/scala/spark/rdd/MapZippedPartitionsRDD.scala +++ /dev/null @@ -1,120 +0,0 @@ -package spark.rdd - -import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - -private[spark] class MapZippedPartition( - idx: Int, - @transient rdds: Seq[RDD[_]]) - extends Partition { - - override val index: Int = idx - var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - partitionValues = rdds.map(rdd => rdd.partitions(idx)) - oos.defaultWriteObject() - } -} - -abstract class MapZippedPartitionsBaseRDD[V: ClassManifest]( - sc: SparkContext, - var rdds: Seq[RDD[_]]) - extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { - - override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { - throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") - } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new MapZippedPartition(i, rdds) - } - array - } - - override def getPreferredLocations(s: Partition): Seq[String] = { - val splits = s.asInstanceOf[MapZippedPartition].partitions - val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2)) - preferredLocations.reduce((x, y) => x.intersect(y)) - } - - override def clearDependencies() { - super.clearDependencies() - rdds = null - } -} - -class MapZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B]) - extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[MapZippedPartition].partitions - f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } -} - -class MapZippedPartitionsRDD3 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B], - var rdd3: RDD[C]) - extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[MapZippedPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), - rdd3.iterator(partitions(2), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - rdd3 = null - } -} - -class MapZippedPartitionsRDD4 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B], - var rdd3: RDD[C], - var rdd4: RDD[D]) - extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[MapZippedPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), - rdd3.iterator(partitions(2), context), - rdd4.iterator(partitions(3), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - rdd3 = null - rdd4 = null - } -} diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala new file mode 100644 index 0000000000..3520fd24b0 --- /dev/null +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -0,0 +1,120 @@ +package spark.rdd + +import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import java.io.{ObjectOutputStream, IOException} + +private[spark] class MapZippedPartition( + idx: Int, + @transient rdds: Seq[RDD[_]]) + extends Partition { + + override val index: Int = idx + var partitionValues = rdds.map(rdd => rdd.partitions(idx)) + def partitions = partitionValues + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + partitionValues = rdds.map(rdd => rdd.partitions(idx)) + oos.defaultWriteObject() + } +} + +abstract class MapZippedPartitionsBaseRDD[V: ClassManifest]( + sc: SparkContext, + var rdds: Seq[RDD[_]]) + extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + + override def getPartitions: Array[Partition] = { + val sizes = rdds.map(x => x.partitions.size) + if (!sizes.forall(x => x == sizes(0))) { + throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") + } + val array = new Array[Partition](sizes(0)) + for (i <- 0 until sizes(0)) { + array(i) = new MapZippedPartition(i, rdds) + } + array + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + val splits = s.asInstanceOf[MapZippedPartition].partitions + val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2)) + preferredLocations.reduce((x, y) => x.intersect(y)) + } + + override def clearDependencies() { + super.clearDependencies() + rdds = null + } +} + +class MapZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B]) + extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[MapZippedPartition].partitions + f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } +} + +class MapZippedPartitionsRDD3 + [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C]) + extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[MapZippedPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + } +} + +class MapZippedPartitionsRDD4 + [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( + sc: SparkContext, + f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], + var rdd1: RDD[A], + var rdd2: RDD[B], + var rdd3: RDD[C], + var rdd4: RDD[D]) + extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + + override def compute(s: Partition, context: TaskContext): Iterator[V] = { + val partitions = s.asInstanceOf[MapZippedPartition].partitions + f(rdd1.iterator(partitions(0), context), + rdd2.iterator(partitions(1), context), + rdd3.iterator(partitions(2), context), + rdd4.iterator(partitions(3), context)) + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + rdd3 = null + rdd4 = null + } +} diff --git a/core/src/test/scala/spark/MapZippedPartitionsSuite.scala b/core/src/test/scala/spark/MapZippedPartitionsSuite.scala deleted file mode 100644 index 834b517cbc..0000000000 --- a/core/src/test/scala/spark/MapZippedPartitionsSuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -package spark - -import scala.collection.immutable.NumericRange - -import org.scalatest.FunSuite -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -import SparkContext._ - - -object MapZippedPartitionsSuite { - def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { - Iterator(i.toArray.size, s.toArray.size, d.toArray.size) - } -} - -class MapZippedPartitionsSuite extends FunSuite with LocalSparkContext { - test("print sizes") { - sc = new SparkContext("local", "test") - val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) - val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) - val data3 = sc.makeRDD(Array(1.0, 2.0), 2) - - val zippedRDD = data1.zipPartitions(MapZippedPartitionsSuite.procZippedData, data2, data3) - - val obtainedSizes = zippedRDD.collect() - val expectedSizes = Array(2, 3, 1, 2, 3, 1) - assert(obtainedSizes.size == 6) - assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) - } -} diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala new file mode 100644 index 0000000000..834b517cbc --- /dev/null +++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala @@ -0,0 +1,34 @@ +package spark + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import SparkContext._ + + +object MapZippedPartitionsSuite { + def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { + Iterator(i.toArray.size, s.toArray.size, d.toArray.size) + } +} + +class MapZippedPartitionsSuite extends FunSuite with LocalSparkContext { + test("print sizes") { + sc = new SparkContext("local", "test") + val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) + val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) + val data3 = sc.makeRDD(Array(1.0, 2.0), 2) + + val zippedRDD = data1.zipPartitions(MapZippedPartitionsSuite.procZippedData, data2, data3) + + val obtainedSizes = zippedRDD.collect() + val expectedSizes = Array(2, 3, 1, 2, 3, 1) + assert(obtainedSizes.size == 6) + assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) + } +} -- cgit v1.2.3 From 15acd49f07c3cde0a381f4abe139b17791a910b4 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 28 Apr 2013 16:03:22 -0700 Subject: Actually rename classes to ZippedPartitions* (the previous commit only renamed the file) --- core/src/main/scala/spark/RDD.scala | 18 +++++++-------- .../main/scala/spark/rdd/ZippedPartitionsRDD.scala | 26 +++++++++++----------- .../test/scala/spark/ZippedPartitionsSuite.scala | 6 ++--- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index bded55238f..4310f745f3 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -35,9 +35,9 @@ import spark.rdd.ShuffledRDD import spark.rdd.SubtractedRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD -import spark.rdd.MapZippedPartitionsRDD2 -import spark.rdd.MapZippedPartitionsRDD3 -import spark.rdd.MapZippedPartitionsRDD4 +import spark.rdd.ZippedPartitionsRDD2 +import spark.rdd.ZippedPartitionsRDD3 +import spark.rdd.ZippedPartitionsRDD4 import spark.storage.StorageLevel import SparkContext._ @@ -441,21 +441,21 @@ abstract class RDD[T: ClassManifest]( def zipPartitions[B: ClassManifest, V: ClassManifest]( f: (Iterator[T], Iterator[B]) => Iterator[V], - rdd2: RDD[B]) = - new MapZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + rdd2: RDD[B]): RDD[V] = + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]( f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V], rdd2: RDD[B], - rdd3: RDD[C]) = - new MapZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + rdd3: RDD[C]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]( f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], rdd2: RDD[B], rdd3: RDD[C], - rdd4: RDD[D]) = - new MapZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + rdd4: RDD[D]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) // Actions (launch a job to return a value to the user program) diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index 3520fd24b0..b3113c1969 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -3,7 +3,7 @@ package spark.rdd import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} -private[spark] class MapZippedPartition( +private[spark] class ZippedPartitions( idx: Int, @transient rdds: Seq[RDD[_]]) extends Partition { @@ -20,7 +20,7 @@ private[spark] class MapZippedPartition( } } -abstract class MapZippedPartitionsBaseRDD[V: ClassManifest]( +abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( sc: SparkContext, var rdds: Seq[RDD[_]]) extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { @@ -32,13 +32,13 @@ abstract class MapZippedPartitionsBaseRDD[V: ClassManifest]( } val array = new Array[Partition](sizes(0)) for (i <- 0 until sizes(0)) { - array(i) = new MapZippedPartition(i, rdds) + array(i) = new ZippedPartitions(i, rdds) } array } override def getPreferredLocations(s: Partition): Seq[String] = { - val splits = s.asInstanceOf[MapZippedPartition].partitions + val splits = s.asInstanceOf[ZippedPartitions].partitions val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2)) preferredLocations.reduce((x, y) => x.intersect(y)) } @@ -49,15 +49,15 @@ abstract class MapZippedPartitionsBaseRDD[V: ClassManifest]( } } -class MapZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( +class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( sc: SparkContext, f: (Iterator[A], Iterator[B]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B]) - extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[MapZippedPartition].partitions + val partitions = s.asInstanceOf[ZippedPartitions].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) } @@ -68,17 +68,17 @@ class MapZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManife } } -class MapZippedPartitionsRDD3 +class ZippedPartitionsRDD3 [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], var rdd3: RDD[C]) - extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[MapZippedPartition].partitions + val partitions = s.asInstanceOf[ZippedPartitions].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context)) @@ -92,7 +92,7 @@ class MapZippedPartitionsRDD3 } } -class MapZippedPartitionsRDD4 +class ZippedPartitionsRDD4 [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], @@ -100,10 +100,10 @@ class MapZippedPartitionsRDD4 var rdd2: RDD[B], var rdd3: RDD[C], var rdd4: RDD[D]) - extends MapZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[MapZippedPartition].partitions + val partitions = s.asInstanceOf[ZippedPartitions].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context), diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala index 834b517cbc..5f60aa75d7 100644 --- a/core/src/test/scala/spark/ZippedPartitionsSuite.scala +++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala @@ -11,20 +11,20 @@ import org.scalacheck.Prop._ import SparkContext._ -object MapZippedPartitionsSuite { +object ZippedPartitionsSuite { def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { Iterator(i.toArray.size, s.toArray.size, d.toArray.size) } } -class MapZippedPartitionsSuite extends FunSuite with LocalSparkContext { +class ZippedPartitionsSuite extends FunSuite with LocalSparkContext { test("print sizes") { sc = new SparkContext("local", "test") val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) val data3 = sc.makeRDD(Array(1.0, 2.0), 2) - val zippedRDD = data1.zipPartitions(MapZippedPartitionsSuite.procZippedData, data2, data3) + val zippedRDD = data1.zipPartitions(ZippedPartitionsSuite.procZippedData, data2, data3) val obtainedSizes = zippedRDD.collect() val expectedSizes = Array(2, 3, 1, 2, 3, 1) -- cgit v1.2.3 From 604d3bf56ce2f77ad391b10842ec1c51daf91a97 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sun, 28 Apr 2013 16:31:07 -0700 Subject: Rename partition class and add scala doc --- core/src/main/scala/spark/RDD.scala | 6 ++++++ core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala | 12 ++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 4310f745f3..09e52ebf3e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -439,6 +439,12 @@ abstract class RDD[T: ClassManifest]( */ def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. + */ def zipPartitions[B: ClassManifest, V: ClassManifest]( f: (Iterator[T], Iterator[B]) => Iterator[V], rdd2: RDD[B]): RDD[V] = diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index b3113c1969..fc3f29ffcd 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -3,7 +3,7 @@ package spark.rdd import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} -private[spark] class ZippedPartitions( +private[spark] class ZippedPartitionsPartition( idx: Int, @transient rdds: Seq[RDD[_]]) extends Partition { @@ -32,13 +32,13 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( } val array = new Array[Partition](sizes(0)) for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitions(i, rdds) + array(i) = new ZippedPartitionsPartition(i, rdds) } array } override def getPreferredLocations(s: Partition): Seq[String] = { - val splits = s.asInstanceOf[ZippedPartitions].partitions + val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2)) preferredLocations.reduce((x, y) => x.intersect(y)) } @@ -57,7 +57,7 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest] extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitions].partitions + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) } @@ -78,7 +78,7 @@ class ZippedPartitionsRDD3 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitions].partitions + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context)) @@ -103,7 +103,7 @@ class ZippedPartitionsRDD4 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitions].partitions + val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), rdd3.iterator(partitions(2), context), -- cgit v1.2.3 From bce4089f22f5e17811f63368b164fae66774095f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 28 Apr 2013 22:23:48 -0700 Subject: Fix BlockManagerSuite to deal with clearing spark.hostPort --- core/src/test/scala/spark/storage/BlockManagerSuite.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index b8c0f6fb76..77f444bcad 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -15,6 +15,8 @@ import org.scalatest.time.SpanSugar._ import spark.JavaSerializer import spark.KryoSerializer import spark.SizeEstimator +import spark.Utils +import spark.util.AkkaUtils import spark.util.ByteBufferInputStream class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { @@ -31,7 +33,12 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val serializer = new KryoSerializer before { - actorSystem = ActorSystem("test") + val hostname = Utils.localHostName + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", hostname, 0) + this.actorSystem = actorSystem + System.setProperty("spark.driver.port", boundPort.toString) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + master = new BlockManagerMaster( actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true)))) @@ -44,6 +51,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } after { + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + if (store != null) { store.stop() store = null -- cgit v1.2.3 From 0f45347c7b7243dbf54569f057a3605f96d614af Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 28 Apr 2013 22:29:27 -0700 Subject: More unit test fixes --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 3 +++ core/src/test/scala/spark/storage/BlockManagerSuite.scala | 5 ++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 3abc584b6a..e95818db61 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -81,6 +81,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", "localhost:" + boundPort) + val masterTracker = new MapOutputTracker() masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 77f444bcad..5a11a4483b 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -33,11 +33,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val serializer = new KryoSerializer before { - val hostname = Utils.localHostName - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", hostname, 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) this.actorSystem = actorSystem System.setProperty("spark.driver.port", boundPort.toString) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + System.setProperty("spark.hostPort", "localhost:" + boundPort) master = new BlockManagerMaster( actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true)))) -- cgit v1.2.3 From 224fbac0612d5c35259cc9f4963dcd4a65ecc832 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 29 Apr 2013 10:10:14 -0700 Subject: Spark-742: TaskMetrics should not employ per-record timing. This patch does three things: 1. Makes TimedIterator a trait with two implementations (one a no-op) 2. Makes the default behavior to use the no-op implementation 3. Removes DelegateBlockFetchTracker. This is just cleanup, but it seems like the triat doesn't really reduce complexity in any way. In the future we can add other implementations, e.g. ones which perform sampling. --- .../scala/spark/BlockStoreShuffleFetcher.scala | 23 ++++++++++--------- .../main/scala/spark/executor/TaskMetrics.scala | 2 +- .../spark/storage/DelegateBlockFetchTracker.scala | 12 ---------- core/src/main/scala/spark/util/TimedIterator.scala | 26 ++++++++++++++++++---- .../scala/spark/scheduler/SparkListenerSuite.scala | 2 +- 5 files changed, 37 insertions(+), 28 deletions(-) delete mode 100644 core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index c27ed36406..83c22b1f14 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -4,8 +4,8 @@ import executor.{ShuffleReadMetrics, TaskMetrics} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import spark.storage.{DelegateBlockFetchTracker, BlockManagerId} -import util.{CompletionIterator, TimedIterator} +import spark.storage.BlockManagerId +import util.{NoOpTimedIterator, SystemTimedIterator, CompletionIterator, TimedIterator} private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = { @@ -49,17 +49,20 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } val blockFetcherItr = blockManager.getMultiple(blocksByAddress) - val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker - itr.setDelegate(blockFetcherItr) + val itr = if (System.getProperty("per.record.shuffle.metrics", "false").toBoolean) { + new SystemTimedIterator(blockFetcherItr.flatMap(unpackBlock)) + } else { + new NoOpTimedIterator(blockFetcherItr.flatMap(unpackBlock)) + } CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics shuffleMetrics.shuffleReadMillis = itr.getNetMillis - shuffleMetrics.remoteFetchTime = itr.remoteFetchTime - shuffleMetrics.fetchWaitTime = itr.fetchWaitTime - shuffleMetrics.remoteBytesRead = itr.remoteBytesRead - shuffleMetrics.totalBlocksFetched = itr.totalBlocks - shuffleMetrics.localBlocksFetched = itr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks + shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime + shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime + shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead + shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks + shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks + shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks metrics.shuffleReadMetrics = Some(shuffleMetrics) }) } diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index 93bbb6b458..45f6d43971 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -51,7 +51,7 @@ class ShuffleReadMetrics extends Serializable { /** * Total time to read shuffle data */ - var shuffleReadMillis: Long = _ + var shuffleReadMillis: Option[Long] = _ /** * Total time that is spent blocked waiting for shuffle to fetch data diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala deleted file mode 100644 index f6c28dce52..0000000000 --- a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala +++ /dev/null @@ -1,12 +0,0 @@ -package spark.storage - -private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker { - var delegate : BlockFetchTracker = _ - def setDelegate(d: BlockFetchTracker) {delegate = d} - def totalBlocks = delegate.totalBlocks - def numLocalBlocks = delegate.numLocalBlocks - def numRemoteBlocks = delegate.numRemoteBlocks - def remoteFetchTime = delegate.remoteFetchTime - def fetchWaitTime = delegate.fetchWaitTime - def remoteBytesRead = delegate.remoteBytesRead -} diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala index 539b01f4ce..49f1276b4e 100644 --- a/core/src/main/scala/spark/util/TimedIterator.scala +++ b/core/src/main/scala/spark/util/TimedIterator.scala @@ -1,13 +1,21 @@ package spark.util /** - * A utility for tracking the total time an iterator takes to iterate through its elements. + * A utility for tracking the the time an iterator takes to iterate through its elements. + */ +trait TimedIterator { + def getNetMillis: Option[Long] + def getAverageTimePerItem: Option[Double] +} + +/** + * A TimedIterator which uses System.currentTimeMillis() on every call to next(). * * In general, this should only be used if you expect it to take a considerable amount of time * (eg. milliseconds) to get each element -- otherwise, the timing won't be very accurate, * and you are probably just adding more overhead */ -class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] { +class SystemTimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] with TimedIterator { private var netMillis = 0l private var nElems = 0 def hasNext = { @@ -26,7 +34,17 @@ class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] { r } - def getNetMillis = netMillis - def getAverageTimePerItem = netMillis / nElems.toDouble + def getNetMillis = Some(netMillis) + def getAverageTimePerItem = Some(netMillis / nElems.toDouble) } + +/** + * A TimedIterator which doesn't perform any timing measurements. + */ +class NoOpTimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] with TimedIterator { + def hasNext = sub.hasNext + def next = sub.next + def getNetMillis = None + def getAverageTimePerItem = None +} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala index 2f5af10e69..5ccab369db 100644 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -57,7 +57,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get sm.totalBlocksFetched should be > (0) - sm.shuffleReadMillis should be > (0l) + sm.shuffleReadMillis.get should be > (0l) sm.localBlocksFetched should be > (0) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0l) -- cgit v1.2.3 From 540be6b1544d26c7db79ec84a98fc6696c7c6434 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 29 Apr 2013 11:32:07 -0700 Subject: Modified version of the fix which just removes all per-record tracking. --- .../scala/spark/BlockStoreShuffleFetcher.scala | 9 +--- .../main/scala/spark/executor/TaskMetrics.scala | 5 --- core/src/main/scala/spark/util/TimedIterator.scala | 50 ---------------------- .../scala/spark/scheduler/SparkListenerSuite.scala | 1 - 4 files changed, 2 insertions(+), 63 deletions(-) delete mode 100644 core/src/main/scala/spark/util/TimedIterator.scala diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 83c22b1f14..ce61d27448 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -5,7 +5,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import spark.storage.BlockManagerId -import util.{NoOpTimedIterator, SystemTimedIterator, CompletionIterator, TimedIterator} +import util.CompletionIterator private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = { @@ -49,14 +49,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin } val blockFetcherItr = blockManager.getMultiple(blocksByAddress) - val itr = if (System.getProperty("per.record.shuffle.metrics", "false").toBoolean) { - new SystemTimedIterator(blockFetcherItr.flatMap(unpackBlock)) - } else { - new NoOpTimedIterator(blockFetcherItr.flatMap(unpackBlock)) - } + val itr = blockFetcherItr.flatMap(unpackBlock) CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleReadMillis = itr.getNetMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index 45f6d43971..a7c56c2371 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -48,11 +48,6 @@ class ShuffleReadMetrics extends Serializable { */ var localBlocksFetched: Int = _ - /** - * Total time to read shuffle data - */ - var shuffleReadMillis: Option[Long] = _ - /** * Total time that is spent blocked waiting for shuffle to fetch data */ diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala deleted file mode 100644 index 49f1276b4e..0000000000 --- a/core/src/main/scala/spark/util/TimedIterator.scala +++ /dev/null @@ -1,50 +0,0 @@ -package spark.util - -/** - * A utility for tracking the the time an iterator takes to iterate through its elements. - */ -trait TimedIterator { - def getNetMillis: Option[Long] - def getAverageTimePerItem: Option[Double] -} - -/** - * A TimedIterator which uses System.currentTimeMillis() on every call to next(). - * - * In general, this should only be used if you expect it to take a considerable amount of time - * (eg. milliseconds) to get each element -- otherwise, the timing won't be very accurate, - * and you are probably just adding more overhead - */ -class SystemTimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] with TimedIterator { - private var netMillis = 0l - private var nElems = 0 - def hasNext = { - val start = System.currentTimeMillis() - val r = sub.hasNext - val end = System.currentTimeMillis() - netMillis += (end - start) - r - } - def next = { - val start = System.currentTimeMillis() - val r = sub.next - val end = System.currentTimeMillis() - netMillis += (end - start) - nElems += 1 - r - } - - def getNetMillis = Some(netMillis) - def getAverageTimePerItem = Some(netMillis / nElems.toDouble) - -} - -/** - * A TimedIterator which doesn't perform any timing measurements. - */ -class NoOpTimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] with TimedIterator { - def hasNext = sub.hasNext - def next = sub.next - def getNetMillis = None - def getAverageTimePerItem = None -} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala index 5ccab369db..42a87d8b90 100644 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -57,7 +57,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get sm.totalBlocksFetched should be > (0) - sm.shuffleReadMillis.get should be > (0l) sm.localBlocksFetched should be > (0) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0l) -- cgit v1.2.3 From 016ce1fa9c9ebbe45559b1cbd95a3674510fe880 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 29 Apr 2013 12:02:27 -0700 Subject: Using full package name for util --- core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index ce61d27448..2987dbbe58 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -5,7 +5,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import spark.storage.BlockManagerId -import util.CompletionIterator +import spark.util.CompletionIterator private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = { -- cgit v1.2.3 From f1f92c88eb2960a16d33bf7dd291c8ce58f665de Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 29 Apr 2013 17:08:45 -0700 Subject: Build against Hadoop 1 by default --- project/SparkBuild.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7bd6c4c235..f2410085d8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -11,8 +11,9 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - //val HADOOP_VERSION = "1.0.4" - //val HADOOP_MAJOR_VERSION = "1" + val HADOOP_VERSION = "1.0.4" + val HADOOP_MAJOR_VERSION = "1" + val HADOOP_YARN = false // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" //val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1" @@ -20,10 +21,9 @@ object SparkBuild extends Build { //val HADOOP_YARN = false // For Hadoop 2 YARN support - // val HADOOP_VERSION = "0.23.7" - val HADOOP_VERSION = "2.0.2-alpha" - val HADOOP_MAJOR_VERSION = "2" - val HADOOP_YARN = true + //val HADOOP_VERSION = "2.0.2-alpha" + //val HADOOP_MAJOR_VERSION = "2" + //val HADOOP_YARN = true lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming) -- cgit v1.2.3 From 7007201201981c6fb002e3008d97a6d6248f4dba Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Apr 2013 23:07:03 -0700 Subject: Added a shuffle block manager so it is easier in the future to consolidate shuffle output files. --- .../scala/spark/scheduler/ShuffleMapTask.scala | 50 +++++++++++------ .../main/scala/spark/storage/BlockManager.scala | 16 ++++-- .../scala/spark/storage/BlockObjectWriter.scala | 43 ++++++++++---- core/src/main/scala/spark/storage/DiskStore.scala | 65 +++++++++++++++------- .../scala/spark/storage/ShuffleBlockManager.scala | 52 +++++++++++++++++ 5 files changed, 175 insertions(+), 51 deletions(-) create mode 100644 core/src/main/scala/spark/storage/ShuffleBlockManager.scala diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 51ec89eb74..124d2d7e26 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -86,8 +86,14 @@ private[spark] class ShuffleMapTask( protected def this() = this(0, null, null, 0, null) - // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. - private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + // Data locality is on a per host basis, not hyper specific to container (host:port). + // Unique on set of hosts. + // TODO(rxin): The above statement seems problematic. Even if partitions are on the same host, + // the worker would still need to serialize / deserialize those data when they are in + // different jvm processes. Often that is very costly ... + @transient + private val preferredLocs: Seq[String] = + if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq { // DEBUG code @@ -131,31 +137,32 @@ private[spark] class ShuffleMapTask( val taskContext = new TaskContext(stageId, partition, attemptId) metrics = Some(taskContext.taskMetrics) + + val blockManager = SparkEnv.get.blockManager + var shuffle: ShuffleBlockManager#Shuffle = null + var buckets: ShuffleWriterGroup = null + try { // Obtain all the block writers for shuffle blocks. - val blockManager = SparkEnv.get.blockManager - val buckets = Array.tabulate[BlockObjectWriter](numOutputSplits) { bucketId => - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + bucketId - blockManager.getDiskBlockWriter(blockId, Serializer.get(dep.serializerClass)) - } + val ser = Serializer.get(dep.serializerClass) + shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) + buckets = shuffle.acquireWriters(partition) // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, taskContext)) { val pair = elem.asInstanceOf[(Any, Any)] val bucketId = dep.partitioner.getPartition(pair._1) - buckets(bucketId).write(pair) + buckets.writers(bucketId).write(pair) } - // Close the bucket writers and get the sizes of each block. - val compressedSizes = new Array[Byte](numOutputSplits) - var i = 0 + // Commit the writes. Get the size of each bucket block (total block size). var totalBytes = 0L - while (i < numOutputSplits) { - buckets(i).close() - val size = buckets(i).size() + val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.size() totalBytes += size - compressedSizes(i) = MapOutputTracker.compressSize(size) - i += 1 + MapOutputTracker.compressSize(size) } // Update shuffle metrics. @@ -164,7 +171,18 @@ private[spark] class ShuffleMapTask( metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) return new MapStatus(blockManager.blockManagerId, compressedSizes) + } catch { case e: Exception => + // If there is an exception from running the task, revert the partial writes + // and throw the exception upstream to Spark. + if (buckets != null) { + buckets.writers.foreach(_.revertPartialWrites()) + } + throw e } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && buckets != null) { + shuffle.releaseWriters(buckets) + } // Execute the callbacks on task completion. taskContext.executeOnCompleteCallbacks() } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 9190c96c71..b94d729923 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -88,6 +88,8 @@ class BlockManager( } } + val shuffleBlockManager = new ShuffleBlockManager(this) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) @@ -391,7 +393,7 @@ class BlockManager( // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (blockId.startsWith("shuffle_")) { + if (ShuffleBlockManager.isShuffle(blockId)) { return diskStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) @@ -508,12 +510,12 @@ class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * This is currently used for writing shuffle files out. + * This is currently used for writing shuffle files out. Callers should handle error + * cases. */ def getDiskBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { val writer = diskStore.getBlockWriter(blockId, serializer) writer.registerCloseEventHandler(() => { - // TODO(rxin): This doesn't handle error cases. val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) blockInfo.put(blockId, myInfo) myInfo.markReady(writer.size()) @@ -872,7 +874,7 @@ class BlockManager( } def shouldCompress(blockId: String): Boolean = { - if (blockId.startsWith("shuffle_")) { + if (ShuffleBlockManager.isShuffle(blockId)) { compressShuffle } else if (blockId.startsWith("broadcast_")) { compressBroadcast @@ -887,7 +889,11 @@ class BlockManager( * Wrap an output stream for compression if block compression is enabled for its block type */ def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) new LZFOutputStream(s) else s + if (shouldCompress(blockId)) { + (new LZFOutputStream(s)).setFinishBlockOnFlush(true) + } else { + s + } } /** diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala index 657a7e9143..42e2b07d5c 100644 --- a/core/src/main/scala/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala @@ -3,25 +3,48 @@ package spark.storage import java.nio.ByteBuffer +/** + * An interface for writing JVM objects to some underlying storage. This interface allows + * appending data to an existing block, and can guarantee atomicity in the case of faults + * as it allows the caller to revert partial writes. + * + * This interface does not support concurrent writes. + */ abstract class BlockObjectWriter(val blockId: String) { - // TODO(rxin): What if there is an exception when the block is being written out? - var closeEventHandler: () => Unit = _ - def registerCloseEventHandler(handler: () => Unit) { - closeEventHandler = handler + def open(): BlockObjectWriter + + def close() { + closeEventHandler() } - def write(value: Any) + def isOpen: Boolean - def writeAll(value: Iterator[Any]) { - value.foreach(write) + def registerCloseEventHandler(handler: () => Unit) { + closeEventHandler = handler } - def close() { - closeEventHandler() - } + /** + * Flush the partial writes and commit them as a single atomic block. Return the + * number of bytes written for this commit. + */ + def commit(): Long + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. + */ + def revertPartialWrites() + + /** + * Writes an object. + */ + def write(value: Any) + /** + * Size of the valid writes, in bytes. + */ def size(): Long } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index b527a3c708..f23cd5475f 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -12,7 +12,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark.Utils import spark.executor.ExecutorExitCode -import spark.serializer.Serializer +import spark.serializer.{Serializer, SerializationStream} /** @@ -21,35 +21,58 @@ import spark.serializer.Serializer private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { - private val mapMode = MapMode.READ_ONLY - private var mapOpenMode = "r" - class DiskBlockObjectWriter(blockId: String, serializer: Serializer) extends BlockObjectWriter(blockId) { private val f: File = createFile(blockId /*, allowAppendExisting */) - private val bs: OutputStream = blockManager.wrapForCompression(blockId, - new FastBufferedOutputStream(new FileOutputStream(f))) - private val objOut = serializer.newInstance().serializeStream(bs) - - private var _size: Long = -1L - override def write(value: Any) { - objOut.writeObject(value) + private var repositionableStream: FastBufferedOutputStream = null + private var bs: OutputStream = null + private var objOut: SerializationStream = null + private var validLength = 0L + + override def open(): DiskBlockObjectWriter = { + println("------------------------------------------------- opening " + f) + repositionableStream = new FastBufferedOutputStream(new FileOutputStream(f)) + bs = blockManager.wrapForCompression(blockId, repositionableStream) + objOut = serializer.newInstance().serializeStream(bs) + this } override def close() { objOut.close() bs.close() + objOut = null + bs = null + repositionableStream = null + // Invoke the close callback handler. super.close() } - override def size(): Long = { - if (_size < 0) { - _size = f.length() - } - _size + override def isOpen: Boolean = objOut != null + + // Flush the partial writes, and set valid length to be the length of the entire file. + // Return the number of bytes written for this commit. + override def commit(): Long = { + bs.flush() + repositionableStream.position() + } + + override def revertPartialWrites() { + // Flush the outstanding writes and delete the file. + objOut.close() + bs.close() + objOut = null + bs = null + repositionableStream = null + f.delete() + } + + override def write(value: Any) { + objOut.writeObject(value) } + + override def size(): Long = validLength } val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @@ -90,9 +113,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def getFileBytes(file: File): ByteBuffer = { val length = file.length() - val channel = new RandomAccessFile(file, mapOpenMode).getChannel() + val channel = new RandomAccessFile(file, "r").getChannel() val buffer = try { - channel.map(mapMode, 0, length) + channel.map(MapMode.READ_ONLY, 0, length) } finally { channel.close() } @@ -230,12 +253,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir) ) + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") try { - localDirs.foreach(localDir => if (! Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)) + localDirs.foreach { localDir => + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala new file mode 100644 index 0000000000..2b1138e7a0 --- /dev/null +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -0,0 +1,52 @@ +package spark.storage + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer + +import spark.serializer.Serializer + + +private[spark] +class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) + + +private[spark] +class ShuffleBlockManager(blockManager: BlockManager) { + + val shuffles = new ConcurrentHashMap[Int, Shuffle] + + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): Shuffle = { + new Shuffle(shuffleId, numBuckets, serializer) + } + + class Shuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) { + + // Get a group of writers for a map task. + def acquireWriters(mapId: Int): ShuffleWriterGroup = { + val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) + blockManager.getDiskBlockWriter(blockId, serializer).open() + } + new ShuffleWriterGroup(mapId, writers) + } + + def releaseWriters(group: ShuffleWriterGroup) = { + // Nothing really to release here. + } + } +} + + +private[spark] +object ShuffleBlockManager { + + // Returns the block id for a given shuffle block. + def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { + "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId + } + + // Returns true if the block is a shuffle block. + def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") +} -- cgit v1.2.3 From 1055785a836ab2361239f0937a1a22fee953e029 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Apr 2013 23:33:56 -0700 Subject: Allow specifying the shuffle write file buffer size. The default buffer size is 8KB in FastBufferedOutputStream, which is too small and would cause a lot of disk seeks. --- core/src/main/scala/spark/storage/BlockManager.scala | 5 +++-- core/src/main/scala/spark/storage/DiskStore.scala | 10 +++++----- core/src/main/scala/spark/storage/ShuffleBlockManager.scala | 3 ++- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index b94d729923..6e0ca9204d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -513,8 +513,9 @@ class BlockManager( * This is currently used for writing shuffle files out. Callers should handle error * cases. */ - def getDiskBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer) + def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + : BlockObjectWriter = { + val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) writer.registerCloseEventHandler(() => { val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) blockInfo.put(blockId, myInfo) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index f23cd5475f..4cddcc86fc 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -21,7 +21,7 @@ import spark.serializer.{Serializer, SerializationStream} private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { - class DiskBlockObjectWriter(blockId: String, serializer: Serializer) + class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) extends BlockObjectWriter(blockId) { private val f: File = createFile(blockId /*, allowAppendExisting */) @@ -32,7 +32,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private var validLength = 0L override def open(): DiskBlockObjectWriter = { - println("------------------------------------------------- opening " + f) repositionableStream = new FastBufferedOutputStream(new FileOutputStream(f)) bs = blockManager.wrapForCompression(blockId, repositionableStream) objOut = serializer.newInstance().serializeStream(bs) @@ -55,7 +54,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Return the number of bytes written for this commit. override def commit(): Long = { bs.flush() - repositionableStream.position() + validLength = repositionableStream.position() + validLength } override def revertPartialWrites() { @@ -86,8 +86,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() - def getBlockWriter(blockId: String, serializer: Serializer): BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer) + def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int): BlockObjectWriter = { + new DiskBlockObjectWriter(blockId, serializer, bufferSize) } override def getSize(blockId: String): Long = { diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 2b1138e7a0..2b22dad459 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -25,9 +25,10 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Get a group of writers for a map task. def acquireWriters(mapId: Int): ShuffleWriterGroup = { + val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer).open() + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() } new ShuffleWriterGroup(mapId, writers) } -- cgit v1.2.3 From e46d547ccd43c0fb3a79a30a7c43a78afba6f93f Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 30 Apr 2013 16:15:56 +0530 Subject: Fix issues reported by Reynold --- .../scala/spark/network/ConnectionManager.scala | 64 ++++++++++++++++++---- run | 7 ++- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0c6bdb1559..a79fce8697 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -188,6 +188,38 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } ) } + // MUST be called within selector loop - else deadlock. + private def triggerForceCloseByException(key: SelectionKey, e: Exception) { + try { + key.interestOps(0) + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + // Pushing to connect threadpool + handleConnectExecutor.execute(new Runnable { + override def run() { + try { + conn.callOnExceptionCallback(e) + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + try { + conn.close() + } catch { + // ignore exceptions + case e: Exception => logDebug("Ignoring exception", e) + } + } + }) + } + + def run() { try { while(!selectorThread.isInterrupted) { @@ -235,18 +267,26 @@ private[spark] class ConnectionManager(port: Int) extends Logging { while (selectedKeys.hasNext()) { val key = selectedKeys.next selectedKeys.remove() - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) + try { + if (key.isValid) { + if (key.isAcceptable) { + acceptConnection(key) + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) + } + } + } catch { + // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) } } } diff --git a/run b/run index 756f8703f2..0a58ac4a36 100755 --- a/run +++ b/run @@ -95,6 +95,7 @@ export JAVA_OPTS CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" +REPL_BIN_DIR="$FWDIR/repl-bin" EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" STREAMING_DIR="$FWDIR/streaming" @@ -125,8 +126,8 @@ if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH+=":$FWDIR/lib_managed/bundles/*" fi CLASSPATH+=":$REPL_DIR/lib/*" -if [ -e repl-bin/target ]; then - for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do +if [ -e $REPL_BIN_DIR/target ]; then + for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do CLASSPATH+=":$jar" done fi @@ -134,7 +135,6 @@ CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do CLASSPATH+=":$jar" done -export CLASSPATH # Needed for spark-shell # Figure out the JAR file that our examples were packaged into. This includes a bit of a hack # to avoid the -sources and -doc packages that are built by publish-local. @@ -163,4 +163,5 @@ else EXTRA_ARGS="$JAVA_OPTS" fi +export CLASSPATH # Needed for spark-shell exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@" -- cgit v1.2.3 From 48854e1dbf1d02e1e19f59d0aee0e281d41b3b45 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Tue, 30 Apr 2013 23:59:33 +0530 Subject: If key is not valid, close connection --- .gitignore | 2 -- core/src/main/scala/spark/network/ConnectionManager.scala | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 155e785b01..5bb2f33574 100644 --- a/.gitignore +++ b/.gitignore @@ -29,8 +29,6 @@ project/build/target/ project/plugins/target/ project/plugins/lib_managed/ project/plugins/src_managed/ -logs/ -log/ spark-tests.log streaming-tests.log dependency-reduced-pom.xml diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index a79fce8697..2d9b4be4b3 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -281,6 +281,9 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (key.isWritable) { triggerWrite(key) } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() } } catch { // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. -- cgit v1.2.3 From 538614acfe95b0c064679122af3bc990b669e4e0 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 00:05:32 +0530 Subject: Be more aggressive and defensive in select also --- .../scala/spark/network/ConnectionManager.scala | 83 ++++++++++++++-------- 1 file changed, 55 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 2d9b4be4b3..9b00fddd40 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -254,7 +254,32 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - val selectedKeysCount = selector.select() + val selectedKeysCount = + try { + selector.select() + } catch { + case e: CancelledKeyException => { + // Some keys within the selectors list are invalid/closed. clear them. + val allKeys = selector.keys().iterator() + + while (allKeys.hasNext()) { + val key = allKeys.next() + try { + if (! key.isValid) { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) + } + } + } + } + 0 + } + if (selectedKeysCount == 0) { logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") } @@ -262,34 +287,36 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logInfo("Selector thread was interrupted!") return } - - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) + + if (0 != selectedKeysCount) { + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext()) { + val key = selectedKeys.next + selectedKeys.remove() + try { + if (key.isValid) { + if (key.isAcceptable) { + acceptConnection(key) + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) + } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) } } } -- cgit v1.2.3 From 0f45477be16254971763cbc07feac7460cffd0bd Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 00:10:02 +0530 Subject: Change indentation --- .../scala/spark/network/ConnectionManager.scala | 40 +++++++++++----------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 9b00fddd40..925d076951 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -255,30 +255,30 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selectedKeysCount = - try { - selector.select() - } catch { - case e: CancelledKeyException => { - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext()) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) + try { + selector.select() + } catch { + case e: CancelledKeyException => { + // Some keys within the selectors list are invalid/closed. clear them. + val allKeys = selector.keys().iterator() + + while (allKeys.hasNext()) { + val key = allKeys.next() + try { + if (! key.isValid) { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) + } } } } + 0 } - 0 - } if (selectedKeysCount == 0) { logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") -- cgit v1.2.3 From 3b748ced2258246bd9b7c250363645cea27cf622 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 00:30:30 +0530 Subject: Be more aggressive and defensive in all uses of SelectionKey in select loop --- .../scala/spark/network/ConnectionManager.scala | 47 ++++++++++++++-------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 925d076951..03926a6038 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -232,24 +232,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging { while(!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. - if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + try { + if (key.isValid) { + val connection = connectionsByKey.getOrElse(key, null) + if (connection != null) { + val lastOps = key.interestOps() + key.interestOps(ops) + + // hot loop - prevent materialization of string if trace not enabled. + if (isTraceEnabled()) { + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } + } + } else { + logInfo("Key not valid ? " + key) + throw new CancelledKeyException() + } + } catch { + case e: CancelledKeyException => { + logInfo("key already cancelled ? " + key, e) + triggerForceCloseByException(key, e) } } } -- cgit v1.2.3 From c446ac31d7065d227f168a7f27010bdf98ef7ad1 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 00:32:30 +0530 Subject: Spurious commit, reverting gitignore change --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 5bb2f33574..155e785b01 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,8 @@ project/build/target/ project/plugins/target/ project/plugins/lib_managed/ project/plugins/src_managed/ +logs/ +log/ spark-tests.log streaming-tests.log dependency-reduced-pom.xml -- cgit v1.2.3 From 60cabb35cbfd2af0e5ba34c4a416aa2640091acc Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 01:17:14 +0530 Subject: Add addition catch block for exception too --- core/src/main/scala/spark/network/ConnectionManager.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 03926a6038..0eb03630d0 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -264,6 +264,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logInfo("key already cancelled ? " + key, e) triggerForceCloseByException(key, e) } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } } } @@ -271,6 +275,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { try { selector.select() } catch { + // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently. case e: CancelledKeyException => { // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() @@ -287,6 +292,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logInfo("key already cancelled ? " + key, e) triggerForceCloseByException(key, e) } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } } } } @@ -330,6 +339,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logInfo("key already cancelled ? " + key, e) triggerForceCloseByException(key, e) } + case e: Exception => { + logError("Exception processing key " + key, e) + triggerForceCloseByException(key, e) + } } } } -- cgit v1.2.3 From dd7bef31472e8c7dedc93bc1519be5900784c736 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Apr 2013 15:02:32 -0700 Subject: Two minor fixes according to Ryan LeCompte's review. --- core/src/main/scala/spark/storage/BlockManager.scala | 7 ++----- core/src/main/scala/spark/storage/ShuffleBlockManager.scala | 7 ------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 6e0ca9204d..09572b19db 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -296,11 +296,8 @@ class BlockManager( * never deletes (recent) items. */ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer) match { - case Some(iterator) => Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } + diskStore.getValues(blockId, serializer).orElse( + sys.error("Block " + blockId + " not found on disk, though it should be")) } /** diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 2b22dad459..1903df0817 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -1,10 +1,5 @@ package spark.storage -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.ArrayBuffer - import spark.serializer.Serializer @@ -15,8 +10,6 @@ class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) private[spark] class ShuffleBlockManager(blockManager: BlockManager) { - val shuffles = new ConcurrentHashMap[Int, Shuffle] - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): Shuffle = { new Shuffle(shuffleId, numBuckets, serializer) } -- cgit v1.2.3 From 1d54401d7e41095d8cbeeefd42c9d39ee500cd9f Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 30 Apr 2013 23:01:32 -0600 Subject: Modified as per TD's suggestions --- ...etworkWordCumulativeCountUpdateStateByKey.scala | 63 ---------------------- .../examples/StatefulNetworkWordCount.scala | 52 ++++++++++++++++++ 2 files changed, 52 insertions(+), 63 deletions(-) delete mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala create mode 100644 examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala deleted file mode 100644 index db62246387..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/NetworkWordCumulativeCountUpdateStateByKey.scala +++ /dev/null @@ -1,63 +0,0 @@ -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ - -/** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: NetworkWordCumulativeCountUpdateStateByKey - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. - * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run spark.streaming.examples.NetworkWordCumulativeCountUpdateStateByKey local[2] localhost 9999` - */ -object NetworkWordCumulativeCountUpdateStateByKey { - private def className[A](a: A)(implicit m: Manifest[A]) = m.toString - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCountUpdateStateByKey \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.foldLeft(0)(_ + _) - //println("currentCount: " + currentCount) - - val previousCount = state.getOrElse(0) - //println("previousCount: " + previousCount) - - val cumulative = Some(currentCount + previousCount) - //println("Cumulative: " + cumulative) - - cumulative - } - - // Create the context with a 10 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(10), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - ssc.checkpoint(".") - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordDstream = words.map(x => (x, 1)) - - // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) - - stateDstream.foreach(rdd => { - rdd.foreach(rddVal => { - println("Current Count: " + rddVal) - }) - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala new file mode 100644 index 0000000000..b662cb1162 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala @@ -0,0 +1,52 @@ +package spark.streaming.examples + +import spark.streaming._ +import spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: StatefulNetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999` + */ +object StatefulNetworkWordCount { + private def className[A](a: A)(implicit m: Manifest[A]) = m.toString + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: StatefulNetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + val currentCount = values.foldLeft(0)(_ + _) + + val previousCount = state.getOrElse(0) + + Some(currentCount + previousCount) + } + + // Create the context with a 10 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(10), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + ssc.checkpoint(".") + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.socketTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordDstream = words.map(x => (x, 1)) + + // Update the cumulative count using updateStateByKey + // This will give a Dstream made of state (which is the cumulative count of the words) + val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + stateDstream.print() + ssc.start() + } +} -- cgit v1.2.3 From d960e7e0f83385d8f43129d53c189b3036936daf Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 20:24:00 +0530 Subject: a) Add support for hyper local scheduling - specific to a host + port - before trying host local scheduling. b) Add some fixes to test code to ensure it passes (and fixes some other issues). c) Fix bug in task scheduling which incorrectly used availableCores instead of all cores on the node. --- core/src/main/scala/spark/SparkEnv.scala | 21 ++- core/src/main/scala/spark/Utils.scala | 17 ++- .../main/scala/spark/deploy/worker/Worker.scala | 5 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- core/src/main/scala/spark/rdd/ZippedRDD.scala | 2 + .../main/scala/spark/scheduler/DAGScheduler.scala | 7 +- .../main/scala/spark/scheduler/ResultTask.scala | 4 +- .../scala/spark/scheduler/ShuffleMapTask.scala | 4 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 64 ++++++++- .../spark/scheduler/cluster/TaskSetManager.scala | 155 ++++++++++++++------- .../main/scala/spark/storage/BlockManager.scala | 43 ++++-- .../test/scala/spark/MapOutputTrackerSuite.scala | 6 +- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 4 +- .../scala/spark/storage/BlockManagerSuite.scala | 2 + 14 files changed, 244 insertions(+), 99 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index ffb40bab3a..5b4a464010 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -29,7 +29,11 @@ class SparkEnv ( val blockManager: BlockManager, val connectionManager: ConnectionManager, val httpFileServer: HttpFileServer, - val sparkFilesDir: String + val sparkFilesDir: String, + // To be set only as part of initialization of SparkContext. + // (executorId, defaultHostPort) => executorHostPort + // If executorId is NOT found, return defaultHostPort + var executorIdToHostPort: (String, String) => String ) { def stop() { @@ -44,6 +48,17 @@ class SparkEnv ( // down, but let's call it anyway in case it gets fixed in a later release actorSystem.awaitTermination() } + + + def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { + val env = SparkEnv.get + if (env.executorIdToHostPort == null) { + // default to using host, not host port. Relevant to non cluster modes. + return defaultHostPort + } + + env.executorIdToHostPort(executorId, defaultHostPort) + } } object SparkEnv extends Logging { @@ -162,7 +177,7 @@ object SparkEnv extends Logging { blockManager, connectionManager, httpFileServer, - sparkFilesDir) + sparkFilesDir, + null) } - } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 9f48cbe490..279daf04ed 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -357,21 +357,26 @@ private object Utils extends Logging { Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) } } + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + // temp code for debug + System.exit(-1) + } */ // Once testing is complete in various modes, replace with this ? def checkHost(host: String, message: String = "") {} def checkHostPort(hostPort: String, message: String = "") {} - def getUserNameFromEnvironment(): String = { - SparkHadoopUtil.getUserNameFromEnvironment - } - // Used by DEBUG code : remove when all testing done def logErrorWithStack(msg: String) { try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } - // temp code for debug - System.exit(-1) + } + + def getUserNameFromEnvironment(): String = { + SparkHadoopUtil.getUserNameFromEnvironment } // Typically, this will be of order of number of nodes in cluster diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 1a7da0f7bf..3dc2207170 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -54,7 +54,10 @@ private[spark] class Worker( def createWorkDir() { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { - if ( (workDir.exists() && !workDir.isDirectory) || (!workDir.exists() && !workDir.mkdirs()) ) { + // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() + // So attempting to create and then check if directory was created or not. + workDir.mkdirs() + if ( !workDir.exists() || !workDir.isDirectory) { logError("Failed to create work directory " + workDir) System.exit(1) } diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 7348c4f15b..719d4bf03e 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -1,7 +1,7 @@ package spark.rdd -import scala.collection.mutable.HashMap import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} +import spark.storage.BlockManager private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { val index = idx @@ -11,12 +11,7 @@ private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) extends RDD[T](sc, Nil) { - @transient lazy val locations_ = { - val blockManager = SparkEnv.get.blockManager - /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ - val locations = blockManager.getLocations(blockIds) - HashMap(blockIds.zip(locations):_*) - } + @transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get) override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => { new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 35b0e06785..e80250a99b 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -49,6 +49,8 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( override def getPreferredLocations(s: Partition): Seq[String] = { val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions + // TODO: becomes complicated - intersect on hostPort if available, else fallback to host (removing intersected hostPort's). + // Since I am not very sure about this RDD, leaving it to others to comment better ! rdd1.preferredLocations(partition1).intersect(rdd2.preferredLocations(partition2)) } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 1440b93e65..8072c60bb7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -12,7 +12,7 @@ import spark.executor.TaskMetrics import spark.partial.ApproximateActionListener import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.storage.BlockManagerMaster +import spark.storage.{BlockManager, BlockManagerMaster} import spark.util.{MetadataCleaner, TimeStampedHashMap} /** @@ -117,9 +117,8 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray - cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { - locations => locations.map(_.hostPort).toList - }.toArray + val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env) + cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil)) } cacheLocs(rdd.id) } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 89dc6640b2..c43cbe5ed4 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -71,11 +71,11 @@ private[spark] class ResultTask[T, U]( } // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. - val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq { // DEBUG code - preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs)) } override def run(attemptId: Long): U = { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 7dc6da4573..0b848af2f3 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -85,11 +85,11 @@ private[spark] class ShuffleMapTask( protected def this() = this(0, null, null, 0, null) // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. - private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq { // DEBUG code - preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs)) } var split = if (rdd == null) { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index a9d9c5e44c..3c72ce4206 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -79,9 +79,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - val executorsByHostPort = new HashMap[String, HashSet[String]] + private val executorsByHostPort = new HashMap[String, HashSet[String]] - val executorIdToHostPort = new HashMap[String, String] + private val executorIdToHostPort = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -102,6 +102,14 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def initialize(context: SchedulerBackend) { backend = context + // resolve executorId to hostPort mapping. + def executorToHostPort(executorId: String, defaultHostPort: String): String = { + executorIdToHostPort.getOrElse(executorId, defaultHostPort) + } + + // Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler + // Will that be a design violation ? + SparkEnv.get.executorIdToHostPort = executorToHostPort } def newTaskId(): Long = nextTaskId.getAndIncrement() @@ -209,13 +217,30 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + // merge availableCpus into hostToAvailableCpus block ? val availableCpus = offers.map(o => o.cores).toArray + val hostToAvailableCpus = { + val map = new HashMap[String, Int]() + for (offer <- offers) { + val hostPort = offer.hostPort + val cores = offer.cores + // DEBUG code + Utils.checkHostPort(hostPort) + + val host = Utils.parseHostPort(hostPort)._1 + + map.put(host, map.getOrElse(host, 0) + cores) + } + + map + } var launchedTask = false for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { // Split offers based on host local, rack local and off-rack tasks. + val hyperLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val otherOffers = new HashMap[String, ArrayBuffer[Int]]() @@ -224,8 +249,17 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val hostPort = offers(i).hostPort // DEBUG code Utils.checkHostPort(hostPort) + + val numHyperLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i))) + if (numHyperLocalTasks > 0){ + val list = hyperLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int]) + for (j <- 0 until numHyperLocalTasks) list += i + } + val host = Utils.parseHostPort(hostPort)._1 - val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i))) + val numHostLocalTasks = math.max(0, + // Remove hyper local tasks (which are also host local btw !) from this + math.min(manager.numPendingTasksForHost(hostPort) - numHyperLocalTasks, hostToAvailableCpus(host))) if (numHostLocalTasks > 0){ val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) for (j <- 0 until numHostLocalTasks) list += i @@ -233,7 +267,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val numRackLocalTasks = math.max(0, // Remove host local tasks (which are also rack local btw !) from this - math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i))) + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHyperLocalTasks - numHostLocalTasks, hostToAvailableCpus(host))) if (numRackLocalTasks > 0){ val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) for (j <- 0 until numRackLocalTasks) list += i @@ -246,12 +280,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } val offersPriorityList = new ArrayBuffer[Int]( - hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) - // First host local, then rack, then others + hyperLocalOffers.size + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) + + // First hyper local, then host local, then rack, then others + + // numHostLocalOffers contains count of both hyper local and host offers. val numHostLocalOffers = { + val hyperLocalPriorityList = ClusterScheduler.prioritizeContainers(hyperLocalOffers) + offersPriorityList ++= hyperLocalPriorityList + val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers) offersPriorityList ++= hostLocalPriorityList - hostLocalPriorityList.size + + hyperLocalPriorityList.size + hostLocalPriorityList.size } val numRackLocalOffers = { val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) @@ -477,6 +518,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = { + Utils.checkHost(host) + val retval = hostToAliveHostPorts.get(host) if (retval.isDefined) { return Some(retval.get.toSet) @@ -485,6 +528,13 @@ private[spark] class ClusterScheduler(val sc: SparkContext) None } + def isExecutorAliveOnHostPort(hostPort: String): Boolean = { + // Even if hostPort is a host, it does not matter - it is just a specific check. + // But we do have to ensure that only hostPort get into hostPortsAlive ! + // So no check against Utils.checkHostPort + hostPortsAlive.contains(hostPort) + } + // By default, rack is unknown def getRackForHost(value: String): Option[String] = None diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 27e713e2c4..f5c0058554 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -13,14 +13,18 @@ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer -private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { +private[spark] object TaskLocality extends Enumeration("HYPER_LOCAL", "HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { - val HOST_LOCAL, RACK_LOCAL, ANY = Value + // hyper local is expected to be used ONLY within tasksetmanager for now. + val HYPER_LOCAL, HOST_LOCAL, RACK_LOCAL, ANY = Value type TaskLocality = Value def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + // Must not be the constraint. + assert (constraint != TaskLocality.HYPER_LOCAL) + constraint match { case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL @@ -32,7 +36,11 @@ private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL def parse(str: String): TaskLocality = { // better way to do this ? try { - TaskLocality.withName(str) + val retval = TaskLocality.withName(str) + // Must not specify HYPER_LOCAL ! + assert (retval != TaskLocality.HYPER_LOCAL) + + retval } catch { case nEx: NoSuchElementException => { logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL"); @@ -133,35 +141,55 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe addPendingTask(i) } - private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = { - // DEBUG code - _taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations)) - - val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else { - // Expand set to include all 'seen' rack local hosts. - // This works since container allocation/management happens within master - so any rack locality information is updated in msater. - // Best case effort, and maybe sort of kludge for now ... rework it later ? - val hosts = new HashSet[String] - _taskPreferredLocations.foreach(h => { - val rackOpt = scheduler.getRackForHost(h) - if (rackOpt.isDefined) { - val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) - if (hostsOpt.isDefined) { - hosts ++= hostsOpt.get + // Note that it follows the hierarchy. + // if we search for HOST_LOCAL, the output will include HYPER_LOCAL and + // if we search for RACK_LOCAL, it will include HYPER_LOCAL & HOST_LOCAL + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, + taskLocality: TaskLocality.TaskLocality): HashSet[String] = { + + if (TaskLocality.HYPER_LOCAL == taskLocality) { + // straight forward comparison ! Special case it. + val retval = new HashSet[String]() + scheduler.synchronized { + for (location <- _taskPreferredLocations) { + if (scheduler.isExecutorAliveOnHostPort(location)) { + retval += location } } + } - // Ensure that irrespective of what scheduler says, host is always added ! - hosts += h - }) - - hosts + return retval } - val retval = new ArrayBuffer[String] + val taskPreferredLocations = + if (TaskLocality.HOST_LOCAL == taskLocality) { + _taskPreferredLocations + } else { + assert (TaskLocality.RACK_LOCAL == taskLocality) + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new HashSet[String] scheduler.synchronized { for (prefLocation <- taskPreferredLocations) { - val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation) + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) if (aliveLocationsOpt.isDefined) { retval ++= aliveLocationsOpt.get } @@ -175,29 +203,37 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe private def addPendingTask(index: Int) { // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. - val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched) - val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + val hyperLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HYPER_LOCAL) + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) if (rackLocalLocations.size == 0) { // Current impl ensures this. + assert (hyperLocalLocations.size == 0) assert (hostLocalLocations.size == 0) pendingTasksWithNoPrefs += index } else { - // host locality - for (hostPort <- hostLocalLocations) { + // hyper local locality + for (hostPort <- hyperLocalLocations) { // DEBUG Code Utils.checkHostPort(hostPort) val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) hostPortList += index + } + + // host locality (includes hyper local) + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) val host = Utils.parseHostPort(hostPort)._1 val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) hostList += index } - // rack locality + // rack locality (includes hyper local and host local) for (rackLocalHostPort <- rackLocalLocations) { // DEBUG Code Utils.checkHostPort(rackLocalHostPort) @@ -233,6 +269,11 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) } + // Number of pending tasks for a given host Port (which would be hyper local) + def numPendingTasksForHostPort(hostPort: String): Int = { + getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + // Number of pending tasks for a given host (which would be data local) def numPendingTasksForHost(hostPort: String): Int = { getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) @@ -270,7 +311,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe if (speculatableTasks.size > 0) { val localTask = speculatableTasks.find { index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched) + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) val attemptLocs = taskAttempts(index).map(_.hostPort) (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) } @@ -284,7 +325,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { val rackTask = speculatableTasks.find { index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) val attemptLocs = taskAttempts(index).map(_.hostPort) locations.contains(hostPort) && !attemptLocs.contains(hostPort) } @@ -311,6 +352,11 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + val hyperLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) + if (hyperLocalTask != None) { + return hyperLocalTask + } + val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) if (localTask != None) { return localTask @@ -341,30 +387,31 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe return findSpeculativeTask(hostPort, locality) } - // Does a host count as a preferred location for a task? This is true if - // either the task has preferred locations and this host is one, or it has - // no preferred locations (in which we still count the launch as preferred). - private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = { + private def isHyperLocalLocation(task: Task[_], hostPort: String): Boolean = { + Utils.checkHostPort(hostPort) + val locs = task.preferredLocations - // DEBUG code - locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) - if (locs.contains(hostPort) || locs.isEmpty) return true + locs.contains(hostPort) + } + + private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = { + val locs = task.preferredLocations + + // If no preference, consider it as host local + if (locs.isEmpty) return true val host = Utils.parseHostPort(hostPort)._1 - locs.contains(host) + locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined } // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). // This is true if either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). - def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { val locs = task.preferredLocations - // DEBUG code - locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) - val preferredRacks = new HashSet[String]() for (preferredHost <- locs) { val rack = sched.getRackForHost(preferredHost) @@ -395,8 +442,11 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val task = tasks(index) val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch - val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else - if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY + val taskLocality = + if (isHyperLocalLocation(task, hostPort)) TaskLocality.HYPER_LOCAL else + if (isHostLocalLocation(task, hostPort)) TaskLocality.HOST_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else + TaskLocality.ANY val prefStr = taskLocality.toString logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, execId, hostPort, prefStr)) @@ -552,15 +602,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe def executorLost(execId: String, hostPort: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + // If some task has preferred locations only on hostname, and there are no more executors there, // put it in the no-prefs list to avoid the wait from delay scheduling - for (index <- getPendingTasksForHostPort(hostPort)) { - val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true) + + // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to + // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. + // Note: NOT checking hyper local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no hyper local node for the task) + for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { + // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) if (newLocs.isEmpty) { - assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty) pendingTasksWithNoPrefs += index } } + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.executorId == execId) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 6e861ac734..7a0d6ced3e 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -4,7 +4,7 @@ import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue} import scala.collection.JavaConversions._ import akka.actor.{ActorSystem, Cancellable, Props} @@ -271,23 +271,12 @@ class BlockManager( } - /** - * Get locations of the block. - */ - def getLocations(blockId: String): Seq[String] = { - val startTimeMs = System.currentTimeMillis - var managers = master.getLocations(blockId) - val locations = managers.map(_.hostPort) - logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) - return locations - } - /** * Get locations of an array of blocks. */ - def getLocations(blockIds: Array[String]): Array[Seq[String]] = { + def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray + val locations = master.getLocations(blockIds).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -947,6 +936,32 @@ object BlockManager extends Logging { } } } + + def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv): HashMap[String, List[String]] = { + val blockManager = env.blockManager + /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ + val locationBlockIds = blockManager.getLocationBlockIds(blockIds) + + // Convert from block master locations to executor locations (we need that for task scheduling) + val executorLocations = new HashMap[String, List[String]]() + for (i <- 0 until blockIds.length) { + val blockId = blockIds(i) + val blockLocations = locationBlockIds(i) + + val executors = new HashSet[String]() + + for (bkLocation <- blockLocations) { + val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host) + executors += executorHostPort + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + } + + executorLocations.put(blockId, executors.toSeq.toList) + } + + executorLocations + } + } class BlockFetcherIterator( diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 3abc584b6a..875975ca43 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -80,12 +80,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { } test("remote fetch") { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0) + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) val masterTracker = new MapOutputTracker() masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) val slaveTracker = new MapOutputTracker() slaveTracker.trackerActor = slaveSystem.actorFor( "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index c0f8986de8..16554eac6e 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -385,12 +385,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(results === Map(0 -> 42)) } - /** Assert that the supplied TaskSet has exactly the given preferredLocations. */ + /** Assert that the supplied TaskSet has exactly the given preferredLocations. Note, converts taskSet's locations to host only. */ private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) { assert(locations.size === taskSet.tasks.size) for ((expectLocs, taskLocs) <- taskSet.tasks.map(_.preferredLocations).zip(locations)) { - assert(expectLocs === taskLocs) + assert(expectLocs.map(loc => spark.Utils.parseHostPort(loc)._1) === taskLocs) } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index b8c0f6fb76..3fc2825255 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -41,6 +41,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() + // Set some value ... + System.setProperty("spark.hostPort", spark.Utils.localHostName() + ":" + 1111) } after { -- cgit v1.2.3 From 27764a00f40391b94fa05abb11484c442607f6f7 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Wed, 1 May 2013 20:56:05 +0530 Subject: Fix some npe introduced accidentally --- .../main/scala/spark/scheduler/DAGScheduler.scala | 2 +- .../main/scala/spark/storage/BlockManager.scala | 30 ++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8072c60bb7..b18248d2b5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -117,7 +117,7 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray - val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env) + val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster) cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil)) } cacheLocs(rdd.id) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a0d6ced3e..040082e600 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -937,10 +937,16 @@ object BlockManager extends Logging { } } - def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv): HashMap[String, List[String]] = { - val blockManager = env.blockManager - /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ - val locationBlockIds = blockManager.getLocationBlockIds(blockIds) + def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = { + // env == null and blockManagerMaster != null is used in tests + assert (env != null || blockManagerMaster != null) + val locationBlockIds: Seq[Seq[BlockManagerId]] = + if (env != null) { + val blockManager = env.blockManager + blockManager.getLocationBlockIds(blockIds) + } else { + blockManagerMaster.getLocations(blockIds) + } // Convert from block master locations to executor locations (we need that for task scheduling) val executorLocations = new HashMap[String, List[String]]() @@ -950,10 +956,18 @@ object BlockManager extends Logging { val executors = new HashSet[String]() - for (bkLocation <- blockLocations) { - val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host) - executors += executorHostPort - // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + if (env != null) { + for (bkLocation <- blockLocations) { + val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host) + executors += executorHostPort + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + } + } else { + // Typically while testing, etc - revert to simply using host. + for (bkLocation <- blockLocations) { + executors += bkLocation.host + // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort) + } } executorLocations.put(blockId, executors.toSeq.toList) -- cgit v1.2.3 From 848156273178bed5763bcbc91baa788bd4a57f6e Mon Sep 17 00:00:00 2001 From: harshars Date: Mon, 25 Mar 2013 20:09:07 -0700 Subject: Merged Ram's commit on removing RDDs. Conflicts: core/src/main/scala/spark/SparkContext.scala --- core/src/main/scala/spark/SparkContext.scala | 62 +++++++++++++++--------- core/src/test/scala/spark/DistributedSuite.scala | 12 +++++ core/src/test/scala/spark/RDDSuite.scala | 7 +++ 3 files changed, 59 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 5f5ec0b0f4..8bee1d65a2 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -1,47 +1,50 @@ package spark import java.io._ -import java.util.concurrent.atomic.AtomicInteger import java.net.URI +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.generic.Growable -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ +import scala.collection.mutable.{ConcurrentMap, HashMap} + +import akka.actor.Actor._ -import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.SequenceFileInputFormat -import org.apache.hadoop.io.Writable -import org.apache.hadoop.io.IntWritable -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.FloatWritable -import org.apache.hadoop.io.DoubleWritable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.ArrayWritable import org.apache.hadoop.io.BooleanWritable import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.ArrayWritable +import org.apache.hadoop.io.DoubleWritable +import org.apache.hadoop.io.FloatWritable +import org.apache.hadoop.io.IntWritable +import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.FileInputFormat +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.SequenceFileInputFormat import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} + import org.apache.mesos.MesosNativeLibrary -import spark.deploy.{SparkHadoopUtil, LocalSparkCluster} -import spark.partial.ApproximateEvaluator -import spark.partial.PartialResult +import spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import spark.partial.{ApproximateEvaluator, PartialResult} import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} -import spark.scheduler._ +import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler} +import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler} import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import spark.storage.BlockManagerUI +import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, TimeStampedHashMap} -import spark.storage.{StorageStatus, StorageUtils, RDDInfo} + /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -97,7 +100,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() + private[spark] val persistentRdds: ConcurrentMap[Int, RDD[_]] = new ConcurrentHashMap[Int, RDD[_]]() private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) @@ -520,6 +523,21 @@ class SparkContext( env.blockManager.master.getStorageStatus } + def removeRDD(id: Int): Unit = { + val storageStatusList = getExecutorStorageStatus + val groupedRddBlocks = storageStatusList.flatMap(_.blocks).toMap + logInfo("RDD to remove: " + id) + groupedRddBlocks.foreach(x => { + val k = x._1.substring(0,x._1.lastIndexOf('_')) + val rdd_id = "rdd_" + id + logInfo("RDD to check: " + rdd_id) + if(k.equals(rdd_id)) { + env.blockManager.master.removeBlock(x._1) + } + }) + persistentRdds.remove(id) + } + /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. @@ -743,7 +761,7 @@ class SparkContext( /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ private[spark] def cleanup(cleanupTime: Long) { - persistentRdds.clearOldValues(cleanupTime) + // do nothing. this needs to be removed. } } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index c9b4707def..c7f6ab3133 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -252,6 +252,18 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter assert(data2.count === 2) } } + + test("remove RDDs cleanly") { + DistributedSuite.amMaster = true + sc = new SparkContext("local-cluster[3,1,512]", "test") + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + data.count + sc.removeRDD(data.id) + assert(sc.persistentRdds.isEmpty == true) + assert(sc.getRDDStorageInfo.isEmpty == true) + + } } object DistributedSuite { diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 7fbdd44340..88b7ab9f52 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -100,6 +100,13 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(rdd.collect().toList === List(1, 2, 3, 4)) } + test("remove RDD") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1,2,3,4), 2).cache() + sc.removeRDD(rdd.id) + assert(sc.persistentRdds.empty == true) + } + test("caching with failures") { sc = new SparkContext("local", "test") val onlySplit = new Partition { override def index: Int = 0 } -- cgit v1.2.3 From 3227ec8edde05cff27c1f9de8861d18b3cda1aae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 May 2013 16:07:44 -0700 Subject: Cleaned up Ram's code. Moved SparkContext.remove to RDD.unpersist. Also updated unit tests to make sure they are properly testing for concurrency. --- core/src/main/scala/spark/RDD.scala | 17 ++++++++++++ core/src/main/scala/spark/SparkContext.scala | 25 ++++-------------- .../main/scala/spark/storage/BlockManagerUI.scala | 4 +-- core/src/test/scala/spark/DistributedSuite.scala | 30 ++++++++++++++++------ core/src/test/scala/spark/RDDSuite.scala | 27 +++++++++++++++---- 5 files changed, 68 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 09e52ebf3e..c77f9915c0 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -140,6 +140,23 @@ abstract class RDD[T: ClassManifest]( /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): RDD[T] = persist() + /** Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. */ + def unpersist(): RDD[T] = { + logInfo("Removing RDD " + id + " from persistence list") + val rddBlockPrefix = "rdd_" + id + "_" + // Get the list of blocks in block manager, and remove ones that are part of this RDD. + // The runtime complexity is linear to the number of blocks persisted in the cluster. + // It could be expensive if the cluster is large and has a lot of blocks persisted. + sc.getExecutorStorageStatus().flatMap(_.blocks).foreach { case(blockId, status) => + if (blockId.startsWith(rddBlockPrefix)) { + sc.env.blockManager.master.removeBlock(blockId) + } + } + sc.persistentRdds.remove(id) + storageLevel = StorageLevel.NONE + this + } + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8bee1d65a2..b686c595b8 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -100,7 +100,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds: ConcurrentMap[Int, RDD[_]] = new ConcurrentHashMap[Int, RDD[_]]() + private[spark] val persistentRdds: ConcurrentMap[Int, RDD[_]] = new ConcurrentHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) @@ -508,36 +508,21 @@ class SparkContext( * Return information about what RDDs are cached, if they are in mem or on disk, how much space * they take, etc. */ - def getRDDStorageInfo : Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) + def getRDDStorageInfo(): Array[RDDInfo] = { + StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus(), this) } - def getStageInfo: Map[Stage,StageInfo] = { + def getStageInfo(): Map[Stage,StageInfo] = { dagScheduler.stageToInfos } /** * Return information about blocks stored in all of the slaves */ - def getExecutorStorageStatus : Array[StorageStatus] = { + def getExecutorStorageStatus(): Array[StorageStatus] = { env.blockManager.master.getStorageStatus } - def removeRDD(id: Int): Unit = { - val storageStatusList = getExecutorStorageStatus - val groupedRddBlocks = storageStatusList.flatMap(_.blocks).toMap - logInfo("RDD to remove: " + id) - groupedRddBlocks.foreach(x => { - val k = x._1.substring(0,x._1.lastIndexOf('_')) - val rdd_id = "rdd_" + id - logInfo("RDD to check: " + rdd_id) - if(k.equals(rdd_id)) { - env.blockManager.master.removeBlock(x._1) - } - }) - persistentRdds.remove(id) - } - /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 07da572044..c9e4519efe 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -45,7 +45,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, path("") { completeWith { // Request the current storage status from the Master - val storageStatusList = sc.getExecutorStorageStatus + val storageStatusList = sc.getExecutorStorageStatus() // Calculate macro-level statistics val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) @@ -60,7 +60,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, parameter("id") { id => completeWith { val prefix = "rdd_" + id.toString - val storageStatusList = sc.getExecutorStorageStatus + val storageStatusList = sc.getExecutorStorageStatus() val filteredStorageStatusList = StorageUtils. filterStorageStatusByPrefix(storageStatusList, prefix) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index c7f6ab3133..ab3e197035 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -3,8 +3,10 @@ package spark import network.ConnectionManagerId import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers import org.scalatest.prop.Checkers +import org.scalatest.time.{Span, Millis} import org.scalacheck.Arbitrary._ import org.scalacheck.Gen import org.scalacheck.Prop._ @@ -252,24 +254,36 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter assert(data2.count === 2) } } - - test("remove RDDs cleanly") { + + test("unpersist RDDs") { DistributedSuite.amMaster = true sc = new SparkContext("local-cluster[3,1,512]", "test") val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) data.count - sc.removeRDD(data.id) + assert(sc.persistentRdds.isEmpty == false) + data.unpersist() assert(sc.persistentRdds.isEmpty == true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case e: Exception => + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } assert(sc.getRDDStorageInfo.isEmpty == true) - } } object DistributedSuite { // Indicates whether this JVM is marked for failure. var mark = false - + // Set by test to remember if we are in the driver program so we can assert // that we are not. var amMaster = false @@ -286,9 +300,9 @@ object DistributedSuite { // Act like an identity function, but if mark was set to true previously, fail, // crashing the entire JVM. def failOnMarkedIdentity(item: Boolean): Boolean = { - if (mark) { + if (mark) { System.exit(42) - } + } item - } + } } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 88b7ab9f52..cee6312572 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -2,6 +2,8 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} import spark.SparkContext._ import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD} @@ -100,11 +102,26 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(rdd.collect().toList === List(1, 2, 3, 4)) } - test("remove RDD") { - sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1,2,3,4), 2).cache() - sc.removeRDD(rdd.id) - assert(sc.persistentRdds.empty == true) + test("unpersist RDD") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + rdd.count + assert(sc.persistentRdds.isEmpty == false) + rdd.unpersist() + assert(sc.persistentRdds.isEmpty == true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case e: Exception => + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + assert(sc.getRDDStorageInfo.isEmpty == true) } test("caching with failures") { -- cgit v1.2.3 From 34637b97ec7ebdd356653324f15345b00b3a2ac2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 May 2013 16:12:37 -0700 Subject: Added SparkContext.cleanup back. Not sure why it was removed before ... --- core/src/main/scala/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b686c595b8..401e55d615 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -746,7 +746,7 @@ class SparkContext( /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ private[spark] def cleanup(cleanupTime: Long) { - // do nothing. this needs to be removed. + persistentRdds.clearOldValues(cleanupTime) } } -- cgit v1.2.3 From 204eb32e14e8fce5e4b4cf602375ae9b4ed136c9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 May 2013 16:14:58 -0700 Subject: Changed the type of the persistentRdds hashmap back to TimeStampedHashMap. --- core/src/main/scala/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 401e55d615..d7d450d958 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -100,7 +100,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds: ConcurrentMap[Int, RDD[_]] = new ConcurrentHashMap[Int, RDD[_]] + private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) -- cgit v1.2.3 From 207afe4088219a0c7350b3f80eb60e86c97e140f Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 18 Apr 2013 12:08:11 -0700 Subject: Remove spark-repl's extraneous dependency on spark-streaming --- project/SparkBuild.scala | 2 +- repl/pom.xml | 14 -------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f2410085d8..190d723435 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -29,7 +29,7 @@ object SparkBuild extends Build { lazy val core = Project("core", file("core"), settings = coreSettings) - lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming) + lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) diff --git a/repl/pom.xml b/repl/pom.xml index 038da5d988..92a2020b48 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -96,13 +96,6 @@ hadoop1 runtime - - org.spark-project - spark-streaming - ${project.version} - hadoop1 - runtime - org.apache.hadoop hadoop-core @@ -147,13 +140,6 @@ hadoop2 runtime - - org.spark-project - spark-streaming - ${project.version} - hadoop2 - runtime - org.apache.hadoop hadoop-core -- cgit v1.2.3 From 609a817f52d8db05711c0d4529dd1448ed8c4fe0 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 2 May 2013 06:44:33 +0530 Subject: Integrate review comments on pull request --- core/src/main/scala/spark/SparkEnv.scala | 9 ++-- core/src/main/scala/spark/Utils.scala | 4 +- .../main/scala/spark/scheduler/ResultTask.scala | 1 - .../scala/spark/scheduler/ShuffleMapTask.scala | 1 - .../spark/scheduler/cluster/ClusterScheduler.scala | 32 +++++++------- .../spark/scheduler/cluster/TaskSetManager.scala | 50 +++++++++++----------- 6 files changed, 47 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 5b4a464010..2ee25e547d 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -33,8 +33,7 @@ class SparkEnv ( // To be set only as part of initialization of SparkContext. // (executorId, defaultHostPort) => executorHostPort // If executorId is NOT found, return defaultHostPort - var executorIdToHostPort: (String, String) => String - ) { + var executorIdToHostPort: Option[(String, String) => String]) { def stop() { httpFileServer.stop() @@ -52,12 +51,12 @@ class SparkEnv ( def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { val env = SparkEnv.get - if (env.executorIdToHostPort == null) { + if (env.executorIdToHostPort.isEmpty) { // default to using host, not host port. Relevant to non cluster modes. return defaultHostPort } - env.executorIdToHostPort(executorId, defaultHostPort) + env.executorIdToHostPort.get(executorId, defaultHostPort) } } @@ -178,6 +177,6 @@ object SparkEnv extends Logging { connectionManager, httpFileServer, sparkFilesDir, - null) + None) } } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 279daf04ed..0e348f8189 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -335,7 +335,6 @@ private object Utils extends Logging { retval } - /* // Used by DEBUG code : remove when all testing done private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") def checkHost(host: String, message: String = "") { @@ -364,8 +363,8 @@ private object Utils extends Logging { // temp code for debug System.exit(-1) } - */ +/* // Once testing is complete in various modes, replace with this ? def checkHost(host: String, message: String = "") {} def checkHostPort(hostPort: String, message: String = "") {} @@ -374,6 +373,7 @@ private object Utils extends Logging { def logErrorWithStack(msg: String) { try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } } +*/ def getUserNameFromEnvironment(): String = { SparkHadoopUtil.getUserNameFromEnvironment diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index c43cbe5ed4..83166bce22 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -70,7 +70,6 @@ private[spark] class ResultTask[T, U]( rdd.partitions(partition) } - // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 0b848af2f3..4b36e71c32 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -84,7 +84,6 @@ private[spark] class ShuffleMapTask( protected def this() = this(0, null, null, 0, null) - // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 3c72ce4206..49fc449e86 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -73,7 +73,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val activeExecutorIds = new HashSet[String] // TODO: We might want to remove this and merge it with execId datastructures - but later. - // Which hosts in the cluster are alive (contains hostPort's) - used for hyper local and local task locality. + // Which hosts in the cluster are alive (contains hostPort's) - used for instance local and node local task locality. private val hostPortsAlive = new HashSet[String] private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] @@ -109,7 +109,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler // Will that be a design violation ? - SparkEnv.get.executorIdToHostPort = executorToHostPort + SparkEnv.get.executorIdToHostPort = Some(executorToHostPort) } def newTaskId(): Long = nextTaskId.getAndIncrement() @@ -240,7 +240,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { // Split offers based on host local, rack local and off-rack tasks. - val hyperLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val instanceLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val otherOffers = new HashMap[String, ArrayBuffer[Int]]() @@ -250,16 +250,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // DEBUG code Utils.checkHostPort(hostPort) - val numHyperLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i))) - if (numHyperLocalTasks > 0){ - val list = hyperLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int]) - for (j <- 0 until numHyperLocalTasks) list += i + val numInstanceLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i))) + if (numInstanceLocalTasks > 0){ + val list = instanceLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int]) + for (j <- 0 until numInstanceLocalTasks) list += i } val host = Utils.parseHostPort(hostPort)._1 val numHostLocalTasks = math.max(0, - // Remove hyper local tasks (which are also host local btw !) from this - math.min(manager.numPendingTasksForHost(hostPort) - numHyperLocalTasks, hostToAvailableCpus(host))) + // Remove instance local tasks (which are also host local btw !) from this + math.min(manager.numPendingTasksForHost(hostPort) - numInstanceLocalTasks, hostToAvailableCpus(host))) if (numHostLocalTasks > 0){ val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) for (j <- 0 until numHostLocalTasks) list += i @@ -267,7 +267,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val numRackLocalTasks = math.max(0, // Remove host local tasks (which are also rack local btw !) from this - math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHyperLocalTasks - numHostLocalTasks, hostToAvailableCpus(host))) + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numInstanceLocalTasks - numHostLocalTasks, hostToAvailableCpus(host))) if (numRackLocalTasks > 0){ val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) for (j <- 0 until numRackLocalTasks) list += i @@ -280,19 +280,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } val offersPriorityList = new ArrayBuffer[Int]( - hyperLocalOffers.size + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) + instanceLocalOffers.size + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) - // First hyper local, then host local, then rack, then others + // First instance local, then host local, then rack, then others - // numHostLocalOffers contains count of both hyper local and host offers. + // numHostLocalOffers contains count of both instance local and host offers. val numHostLocalOffers = { - val hyperLocalPriorityList = ClusterScheduler.prioritizeContainers(hyperLocalOffers) - offersPriorityList ++= hyperLocalPriorityList + val instanceLocalPriorityList = ClusterScheduler.prioritizeContainers(instanceLocalOffers) + offersPriorityList ++= instanceLocalPriorityList val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers) offersPriorityList ++= hostLocalPriorityList - hyperLocalPriorityList.size + hostLocalPriorityList.size + instanceLocalPriorityList.size + hostLocalPriorityList.size } val numRackLocalOffers = { val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index f5c0058554..5f3faaa5c3 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -13,17 +13,17 @@ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer -private[spark] object TaskLocality extends Enumeration("HYPER_LOCAL", "HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { +private[spark] object TaskLocality extends Enumeration("INSTANCE_LOCAL", "HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { - // hyper local is expected to be used ONLY within tasksetmanager for now. - val HYPER_LOCAL, HOST_LOCAL, RACK_LOCAL, ANY = Value + // instance local is expected to be used ONLY within tasksetmanager for now. + val INSTANCE_LOCAL, HOST_LOCAL, RACK_LOCAL, ANY = Value type TaskLocality = Value def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { // Must not be the constraint. - assert (constraint != TaskLocality.HYPER_LOCAL) + assert (constraint != TaskLocality.INSTANCE_LOCAL) constraint match { case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL @@ -37,8 +37,8 @@ private[spark] object TaskLocality extends Enumeration("HYPER_LOCAL", "HOST_LOCA // better way to do this ? try { val retval = TaskLocality.withName(str) - // Must not specify HYPER_LOCAL ! - assert (retval != TaskLocality.HYPER_LOCAL) + // Must not specify INSTANCE_LOCAL ! + assert (retval != TaskLocality.INSTANCE_LOCAL) retval } catch { @@ -84,7 +84,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis - // List of pending tasks for each node (hyper local to container). These collections are actually + // List of pending tasks for each node (instance local to container). These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put @@ -142,12 +142,12 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } // Note that it follows the hierarchy. - // if we search for HOST_LOCAL, the output will include HYPER_LOCAL and - // if we search for RACK_LOCAL, it will include HYPER_LOCAL & HOST_LOCAL + // if we search for HOST_LOCAL, the output will include INSTANCE_LOCAL and + // if we search for RACK_LOCAL, it will include INSTANCE_LOCAL & HOST_LOCAL private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, taskLocality: TaskLocality.TaskLocality): HashSet[String] = { - if (TaskLocality.HYPER_LOCAL == taskLocality) { + if (TaskLocality.INSTANCE_LOCAL == taskLocality) { // straight forward comparison ! Special case it. val retval = new HashSet[String]() scheduler.synchronized { @@ -203,19 +203,19 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe private def addPendingTask(index: Int) { // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. - val hyperLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HYPER_LOCAL) + val instanceLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.INSTANCE_LOCAL) val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) if (rackLocalLocations.size == 0) { // Current impl ensures this. - assert (hyperLocalLocations.size == 0) + assert (instanceLocalLocations.size == 0) assert (hostLocalLocations.size == 0) pendingTasksWithNoPrefs += index } else { - // hyper local locality - for (hostPort <- hyperLocalLocations) { + // instance local locality + for (hostPort <- instanceLocalLocations) { // DEBUG Code Utils.checkHostPort(hostPort) @@ -223,7 +223,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe hostPortList += index } - // host locality (includes hyper local) + // host locality (includes instance local) for (hostPort <- hostLocalLocations) { // DEBUG Code Utils.checkHostPort(hostPort) @@ -233,7 +233,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe hostList += index } - // rack locality (includes hyper local and host local) + // rack locality (includes instance local and host local) for (rackLocalHostPort <- rackLocalLocations) { // DEBUG Code Utils.checkHostPort(rackLocalHostPort) @@ -247,7 +247,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe allPendingTasks += index } - // Return the pending tasks list for a given host port (hyper local), or an empty list if + // Return the pending tasks list for a given host port (instance local), or an empty list if // there is no map entry for that host private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { // DEBUG Code @@ -269,7 +269,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) } - // Number of pending tasks for a given host Port (which would be hyper local) + // Number of pending tasks for a given host Port (which would be instance local) def numPendingTasksForHostPort(hostPort: String): Int = { getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) } @@ -352,9 +352,9 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - val hyperLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) - if (hyperLocalTask != None) { - return hyperLocalTask + val instanceLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) + if (instanceLocalTask != None) { + return instanceLocalTask } val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) @@ -387,7 +387,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe return findSpeculativeTask(hostPort, locality) } - private def isHyperLocalLocation(task: Task[_], hostPort: String): Boolean = { + private def isInstanceLocalLocation(task: Task[_], hostPort: String): Boolean = { Utils.checkHostPort(hostPort) val locs = task.preferredLocations @@ -443,7 +443,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val taskLocality = - if (isHyperLocalLocation(task, hostPort)) TaskLocality.HYPER_LOCAL else + if (isInstanceLocalLocation(task, hostPort)) TaskLocality.INSTANCE_LOCAL else if (isHostLocalLocation(task, hostPort)) TaskLocality.HOST_LOCAL else if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY @@ -608,8 +608,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. - // Note: NOT checking hyper local list - since host local list is super set of that. We need to ad to no prefs only if - // there is no host local node for the task (not if there is no hyper local node for the task) + // Note: NOT checking instance local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no instance local node for the task) for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) -- cgit v1.2.3 From c047f0e3adae59d7e388a1d42d940c3cd5714f82 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 26 Apr 2013 13:28:21 +0800 Subject: filter out Spark streaming block RDD and sort RDDInfo with id --- .../main/scala/spark/storage/StorageUtils.scala | 33 ++++++++++++++-------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index dec47a9d41..8f52168c24 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -4,9 +4,9 @@ import spark.{Utils, SparkContext} import BlockManagerMasterActor.BlockStatus private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, +case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, blocks: Map[String, BlockStatus]) { - + def memUsed(blockPrefix: String = "") = { blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). reduceOption(_+_).getOrElse(0l) @@ -22,35 +22,40 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) { + numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) + extends Ordered[RDDInfo] { override def toString = { import Utils.memoryBytesToString "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize)) } + + override def compare(that: RDDInfo) = { + this.id - that.id + } } /* Helper methods for storage-related objects */ private[spark] object StorageUtils { - /* Given the current storage status of the BlockManager, returns information for each RDD */ - def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], + /* Given the current storage status of the BlockManager, returns information for each RDD */ + def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], sc: SparkContext) : Array[RDDInfo] = { - rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) } - /* Given a list of BlockStatus objets, returns information for each RDD */ - def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + /* Given a list of BlockStatus objets, returns information for each RDD */ + def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], sc: SparkContext) : Array[RDDInfo] = { // Group by rddId, ignore the partition name - val groupedRddBlocks = infos.groupBy { case(k, v) => + val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => k.substring(0,k.lastIndexOf('_')) }.mapValues(_.values.toArray) // For each RDD, generate an RDDInfo object - groupedRddBlocks.map { case(rddKey, rddBlocks) => + val rddInfos = groupedRddBlocks.map { case(rddKey, rddBlocks) => // Add up memory and disk sizes val memSize = rddBlocks.map(_.memSize).reduce(_ + _) @@ -65,10 +70,14 @@ object StorageUtils { RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize) }.toArray + + scala.util.Sorting.quickSort(rddInfos) + + rddInfos } - /* Removes all BlockStatus object that are not part of a block prefix */ - def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], + /* Removes all BlockStatus object that are not part of a block prefix */ + def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], prefix: String) : Array[StorageStatus] = { storageStatusList.map { status => -- cgit v1.2.3 From 1b5aaeadc72ad5197c00897c41f670ea241d0235 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 2 May 2013 07:30:06 +0530 Subject: Integrate review comments 2 --- .../spark/scheduler/cluster/ClusterScheduler.scala | 78 +++++++++++----------- .../spark/scheduler/cluster/TaskSetManager.scala | 74 ++++++++++---------- .../spark/scheduler/local/LocalScheduler.scala | 2 +- 3 files changed, 77 insertions(+), 77 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 49fc449e86..cf4483f144 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -32,28 +32,28 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong /* - This property controls how aggressive we should be to modulate waiting for host local task scheduling. - To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for host locality of tasks before + This property controls how aggressive we should be to modulate waiting for node local task scheduling. + To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for node locality of tasks before scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order : - host-local, rack-local and then others - But once all available host local (and no pref) tasks are scheduled, instead of waiting for 3 sec before + node-local, rack-local and then others + But once all available node local (and no pref) tasks are scheduled, instead of waiting for 3 sec before scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can modulate that : to also allow rack local nodes or any node. The default is still set to HOST - so that previous behavior is maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap. TODO: rename property ? The value is one of - - HOST_LOCAL (default, no change w.r.t current behavior), + - NODE_LOCAL (default, no change w.r.t current behavior), - RACK_LOCAL and - ANY Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective. Additional Note: For non trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments) based on whether - it is left at default HOST_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY. + it is left at default NODE_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY. If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact. Also, it brings down the variance in running time drastically. */ - val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL")) + val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL")) val activeTaskSets = new HashMap[String, TaskSetManager] var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] @@ -73,7 +73,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val activeExecutorIds = new HashSet[String] // TODO: We might want to remove this and merge it with execId datastructures - but later. - // Which hosts in the cluster are alive (contains hostPort's) - used for instance local and node local task locality. + // Which hosts in the cluster are alive (contains hostPort's) - used for process local and node local task locality. private val hostPortsAlive = new HashSet[String] private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] @@ -217,9 +217,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - // merge availableCpus into hostToAvailableCpus block ? + // merge availableCpus into nodeToAvailableCpus block ? val availableCpus = offers.map(o => o.cores).toArray - val hostToAvailableCpus = { + val nodeToAvailableCpus = { val map = new HashMap[String, Int]() for (offer <- offers) { val hostPort = offer.hostPort @@ -239,9 +239,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { - // Split offers based on host local, rack local and off-rack tasks. - val instanceLocalOffers = new HashMap[String, ArrayBuffer[Int]]() - val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + // Split offers based on node local, rack local and off-rack tasks. + val processLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val nodeLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() val otherOffers = new HashMap[String, ArrayBuffer[Int]]() @@ -250,29 +250,29 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // DEBUG code Utils.checkHostPort(hostPort) - val numInstanceLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i))) - if (numInstanceLocalTasks > 0){ - val list = instanceLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int]) - for (j <- 0 until numInstanceLocalTasks) list += i + val numProcessLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i))) + if (numProcessLocalTasks > 0){ + val list = processLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int]) + for (j <- 0 until numProcessLocalTasks) list += i } val host = Utils.parseHostPort(hostPort)._1 - val numHostLocalTasks = math.max(0, - // Remove instance local tasks (which are also host local btw !) from this - math.min(manager.numPendingTasksForHost(hostPort) - numInstanceLocalTasks, hostToAvailableCpus(host))) - if (numHostLocalTasks > 0){ - val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) - for (j <- 0 until numHostLocalTasks) list += i + val numNodeLocalTasks = math.max(0, + // Remove process local tasks (which are also host local btw !) from this + math.min(manager.numPendingTasksForHost(hostPort) - numProcessLocalTasks, nodeToAvailableCpus(host))) + if (numNodeLocalTasks > 0){ + val list = nodeLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numNodeLocalTasks) list += i } val numRackLocalTasks = math.max(0, - // Remove host local tasks (which are also rack local btw !) from this - math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numInstanceLocalTasks - numHostLocalTasks, hostToAvailableCpus(host))) + // Remove node local tasks (which are also rack local btw !) from this + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numProcessLocalTasks - numNodeLocalTasks, nodeToAvailableCpus(host))) if (numRackLocalTasks > 0){ val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) for (j <- 0 until numRackLocalTasks) list += i } - if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){ + if (numNodeLocalTasks <= 0 && numRackLocalTasks <= 0){ // add to others list - spread even this across cluster. val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) list += i @@ -280,19 +280,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } val offersPriorityList = new ArrayBuffer[Int]( - instanceLocalOffers.size + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) + processLocalOffers.size + nodeLocalOffers.size + rackLocalOffers.size + otherOffers.size) - // First instance local, then host local, then rack, then others + // First process local, then host local, then rack, then others - // numHostLocalOffers contains count of both instance local and host offers. - val numHostLocalOffers = { - val instanceLocalPriorityList = ClusterScheduler.prioritizeContainers(instanceLocalOffers) - offersPriorityList ++= instanceLocalPriorityList + // numNodeLocalOffers contains count of both process local and host offers. + val numNodeLocalOffers = { + val processLocalPriorityList = ClusterScheduler.prioritizeContainers(processLocalOffers) + offersPriorityList ++= processLocalPriorityList - val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers) - offersPriorityList ++= hostLocalPriorityList + val nodeLocalPriorityList = ClusterScheduler.prioritizeContainers(nodeLocalOffers) + offersPriorityList ++= nodeLocalPriorityList - instanceLocalPriorityList.size + hostLocalPriorityList.size + processLocalPriorityList.size + nodeLocalPriorityList.size } val numRackLocalOffers = { val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) @@ -303,8 +303,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var lastLoop = false val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match { - case TaskLocality.HOST_LOCAL => numHostLocalOffers - case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers + case TaskLocality.NODE_LOCAL => numNodeLocalOffers + case TaskLocality.RACK_LOCAL => numRackLocalOffers + numNodeLocalOffers case TaskLocality.ANY => offersPriorityList.size } @@ -343,8 +343,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // prevent more looping launchedTask = false } else if (!lastLoop && !launchedTask) { - // Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL - if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) { + // Do this only if TASK_SCHEDULING_AGGRESSION != NODE_LOCAL + if (TASK_SCHEDULING_AGGRESSION != TaskLocality.NODE_LOCAL) { // fudge launchedTask to ensure we loop once more launchedTask = true // dont loop anymore diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 5f3faaa5c3..ff4790e4cb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -13,21 +13,21 @@ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer -private[spark] object TaskLocality extends Enumeration("INSTANCE_LOCAL", "HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { +private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { - // instance local is expected to be used ONLY within tasksetmanager for now. - val INSTANCE_LOCAL, HOST_LOCAL, RACK_LOCAL, ANY = Value + // process local is expected to be used ONLY within tasksetmanager for now. + val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value type TaskLocality = Value def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { // Must not be the constraint. - assert (constraint != TaskLocality.INSTANCE_LOCAL) + assert (constraint != TaskLocality.PROCESS_LOCAL) constraint match { - case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL - case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL + case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL // For anything else, allow case _ => true } @@ -37,15 +37,15 @@ private[spark] object TaskLocality extends Enumeration("INSTANCE_LOCAL", "HOST_L // better way to do this ? try { val retval = TaskLocality.withName(str) - // Must not specify INSTANCE_LOCAL ! - assert (retval != TaskLocality.INSTANCE_LOCAL) + // Must not specify PROCESS_LOCAL ! + assert (retval != TaskLocality.PROCESS_LOCAL) retval } catch { case nEx: NoSuchElementException => { - logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL"); + logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); // default to preserve earlier behavior - HOST_LOCAL + NODE_LOCAL } } } @@ -84,7 +84,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis - // List of pending tasks for each node (instance local to container). These collections are actually + // List of pending tasks for each node (process local to container). These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put @@ -142,12 +142,12 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } // Note that it follows the hierarchy. - // if we search for HOST_LOCAL, the output will include INSTANCE_LOCAL and - // if we search for RACK_LOCAL, it will include INSTANCE_LOCAL & HOST_LOCAL + // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and + // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, taskLocality: TaskLocality.TaskLocality): HashSet[String] = { - if (TaskLocality.INSTANCE_LOCAL == taskLocality) { + if (TaskLocality.PROCESS_LOCAL == taskLocality) { // straight forward comparison ! Special case it. val retval = new HashSet[String]() scheduler.synchronized { @@ -162,7 +162,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } val taskPreferredLocations = - if (TaskLocality.HOST_LOCAL == taskLocality) { + if (TaskLocality.NODE_LOCAL == taskLocality) { _taskPreferredLocations } else { assert (TaskLocality.RACK_LOCAL == taskLocality) @@ -203,19 +203,19 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe private def addPendingTask(index: Int) { // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. - val instanceLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.INSTANCE_LOCAL) - val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) + val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) if (rackLocalLocations.size == 0) { // Current impl ensures this. - assert (instanceLocalLocations.size == 0) + assert (processLocalLocations.size == 0) assert (hostLocalLocations.size == 0) pendingTasksWithNoPrefs += index } else { - // instance local locality - for (hostPort <- instanceLocalLocations) { + // process local locality + for (hostPort <- processLocalLocations) { // DEBUG Code Utils.checkHostPort(hostPort) @@ -223,7 +223,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe hostPortList += index } - // host locality (includes instance local) + // host locality (includes process local) for (hostPort <- hostLocalLocations) { // DEBUG Code Utils.checkHostPort(hostPort) @@ -233,7 +233,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe hostList += index } - // rack locality (includes instance local and host local) + // rack locality (includes process local and host local) for (rackLocalHostPort <- rackLocalLocations) { // DEBUG Code Utils.checkHostPort(rackLocalHostPort) @@ -247,7 +247,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe allPendingTasks += index } - // Return the pending tasks list for a given host port (instance local), or an empty list if + // Return the pending tasks list for a given host port (process local), or an empty list if // there is no map entry for that host private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { // DEBUG Code @@ -269,7 +269,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) } - // Number of pending tasks for a given host Port (which would be instance local) + // Number of pending tasks for a given host Port (which would be process local) def numPendingTasksForHostPort(hostPort: String): Int = { getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) } @@ -305,13 +305,13 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // task must have a preference for this host/rack/no preferred locations at all. private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL)) + assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set if (speculatableTasks.size > 0) { val localTask = speculatableTasks.find { index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) val attemptLocs = taskAttempts(index).map(_.hostPort) (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) } @@ -352,9 +352,9 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - val instanceLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) - if (instanceLocalTask != None) { - return instanceLocalTask + val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) + if (processLocalTask != None) { + return processLocalTask } val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) @@ -387,7 +387,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe return findSpeculativeTask(hostPort, locality) } - private def isInstanceLocalLocation(task: Task[_], hostPort: String): Boolean = { + private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = { Utils.checkHostPort(hostPort) val locs = task.preferredLocations @@ -433,7 +433,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val locality = if (overrideLocality != null) overrideLocality else { // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY } findTask(hostPort, locality) match { @@ -443,8 +443,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val taskLocality = - if (isInstanceLocalLocation(task, hostPort)) TaskLocality.INSTANCE_LOCAL else - if (isHostLocalLocation(task, hostPort)) TaskLocality.HOST_LOCAL else + if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else + if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY val prefStr = taskLocality.toString @@ -456,7 +456,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) - if (TaskLocality.HOST_LOCAL == taskLocality) { + if (TaskLocality.NODE_LOCAL == taskLocality) { lastPreferredLaunchTime = time } // Serialize and return the task @@ -608,11 +608,11 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. - // Note: NOT checking instance local list - since host local list is super set of that. We need to ad to no prefs only if - // there is no host local node for the task (not if there is no instance local node for the task) + // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no process local node for the task) for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.HOST_LOCAL) + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) if (newLocs.isEmpty) { pendingTasksWithNoPrefs += index } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index f060a940a9..53dd6fbe13 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon def runTask(task: Task[_], idInJob: Int, attemptId: Int) { logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL) + val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { -- cgit v1.2.3 From dfde9ce9dde0a151d42f7aecb826b40a4c98b459 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 2 May 2013 07:41:33 +0530 Subject: comment out debug versions of checkHost, etc from Utils - which were used to test --- core/src/main/scala/spark/Utils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 0e348f8189..c1495d5317 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -335,6 +335,7 @@ private object Utils extends Logging { retval } +/* // Used by DEBUG code : remove when all testing done private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") def checkHost(host: String, message: String = "") { @@ -363,8 +364,8 @@ private object Utils extends Logging { // temp code for debug System.exit(-1) } +*/ -/* // Once testing is complete in various modes, replace with this ? def checkHost(host: String, message: String = "") {} def checkHostPort(hostPort: String, message: String = "") {} @@ -373,7 +374,6 @@ private object Utils extends Logging { def logErrorWithStack(msg: String) { try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } } -*/ def getUserNameFromEnvironment(): String = { SparkHadoopUtil.getUserNameFromEnvironment -- cgit v1.2.3 From 98df9d28536f5208530488a316df9401e16490bd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 May 2013 20:17:09 -0700 Subject: Added removeRdd function in BlockManager. --- core/src/main/scala/spark/RDD.scala | 15 ++++--------- core/src/main/scala/spark/SparkContext.scala | 8 +++---- .../scala/spark/storage/BlockManagerMaster.scala | 16 ++++++++++++++ .../main/scala/spark/storage/BlockManagerUI.scala | 4 ++-- .../scala/spark/storage/BlockManagerSuite.scala | 25 ++++++++++++++++++++++ 5 files changed, 51 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index c77f9915c0..fd14ef17f1 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -107,7 +107,7 @@ abstract class RDD[T: ClassManifest]( // ======================================================================= /** A unique ID for this RDD (within its SparkContext). */ - val id = sc.newRddId() + val id: Int = sc.newRddId() /** A friendly name for this RDD */ var name: String = null @@ -120,7 +120,8 @@ abstract class RDD[T: ClassManifest]( /** * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. */ def persist(newLevel: StorageLevel): RDD[T] = { // TODO: Handle changes of StorageLevel @@ -143,15 +144,7 @@ abstract class RDD[T: ClassManifest]( /** Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. */ def unpersist(): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - val rddBlockPrefix = "rdd_" + id + "_" - // Get the list of blocks in block manager, and remove ones that are part of this RDD. - // The runtime complexity is linear to the number of blocks persisted in the cluster. - // It could be expensive if the cluster is large and has a lot of blocks persisted. - sc.getExecutorStorageStatus().flatMap(_.blocks).foreach { case(blockId, status) => - if (blockId.startsWith(rddBlockPrefix)) { - sc.env.blockManager.master.removeBlock(blockId) - } - } + sc.env.blockManager.master.removeRdd(id) sc.persistentRdds.remove(id) storageLevel = StorageLevel.NONE this diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d7d450d958..2ae4ad8659 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -508,18 +508,18 @@ class SparkContext( * Return information about what RDDs are cached, if they are in mem or on disk, how much space * they take, etc. */ - def getRDDStorageInfo(): Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus(), this) + def getRDDStorageInfo: Array[RDDInfo] = { + StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) } - def getStageInfo(): Map[Stage,StageInfo] = { + def getStageInfo: Map[Stage,StageInfo] = { dagScheduler.stageToInfos } /** * Return information about blocks stored in all of the slaves */ - def getExecutorStorageStatus(): Array[StorageStatus] = { + def getExecutorStorageStatus: Array[StorageStatus] = { env.blockManager.master.getStorageStatus } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 6fae62d373..ac26c16867 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -15,6 +15,7 @@ import akka.util.duration._ import spark.{Logging, SparkException, Utils} + private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt @@ -87,6 +88,21 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi askDriverWithReply(RemoveBlock(blockId)) } + /** + * Remove all blocks belonging to the given RDD. + */ + def removeRdd(rddId: Int) { + val rddBlockPrefix = "rdd_" + rddId + "_" + // Get the list of blocks in block manager, and remove ones that are part of this RDD. + // The runtime complexity is linear to the number of blocks persisted in the cluster. + // It could be expensive if the cluster is large and has a lot of blocks persisted. + getStorageStatus.flatMap(_.blocks).foreach { case(blockId, status) => + if (blockId.startsWith(rddBlockPrefix)) { + removeBlock(blockId) + } + } + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index c9e4519efe..07da572044 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -45,7 +45,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, path("") { completeWith { // Request the current storage status from the Master - val storageStatusList = sc.getExecutorStorageStatus() + val storageStatusList = sc.getExecutorStorageStatus // Calculate macro-level statistics val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) @@ -60,7 +60,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, parameter("id") { id => completeWith { val prefix = "rdd_" + id.toString - val storageStatusList = sc.getExecutorStorageStatus() + val storageStatusList = sc.getExecutorStorageStatus val filteredStorageStatusList = StorageUtils. filterStorageStatusByPrefix(storageStatusList, prefix) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 5a11a4483b..9fe0de665c 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -207,6 +207,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } } + test("removing rdd") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + // Putting a1, a2 and a3 in memory. + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) + master.removeRdd(0) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("nonrddblock") should not be (None) + master.getLocations("nonrddblock") should have size (1) + } + } + test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) store = new BlockManager("", actorSystem, master, serializer, 2000) -- cgit v1.2.3 From 4a318774088f829fe54c3ef0b5f565a845631b4e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 May 2013 20:31:54 -0700 Subject: Added the unpersist api to JavaRDD. --- core/src/main/scala/spark/api/java/JavaRDD.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index e29f1e5899..eb81ed64cd 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -14,12 +14,18 @@ JavaRDDLike[T, JavaRDD[T]] { /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) - /** + /** * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. + * it is computed. This can only be used to assign a new storage level if the RDD does not + * have a storage level set yet.. */ def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + */ + def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) + // Transformations (return a new RDD) /** @@ -31,7 +37,7 @@ JavaRDDLike[T, JavaRDD[T]] { * Return a new RDD containing the distinct elements in this RDD. */ def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions)) - + /** * Return a new RDD containing only the elements that satisfy a predicate. */ @@ -54,7 +60,7 @@ JavaRDDLike[T, JavaRDD[T]] { */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) - + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). @@ -63,7 +69,7 @@ JavaRDDLike[T, JavaRDD[T]] { /** * Return an RDD with the elements from `this` that are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ -- cgit v1.2.3 From c847dd3da29483fede326cb9821b0d33f735137e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 19 Mar 2013 15:08:22 -0700 Subject: Don't accept generated temp directory names that can't be created successfully. --- core/src/main/scala/spark/storage/DiskStore.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c9553d2e0f..215c25132b 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -168,8 +168,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) localDir = new File(rootDir, "spark-local-" + localDirId) if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true + foundLocalDir = localDir.mkdirs() } } catch { case e: Exception => -- cgit v1.2.3 From 11589c39d9f75e9757ba1717c5202f77d30031b2 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Fri, 3 May 2013 12:23:30 +0530 Subject: Fix ZippedRDD as part Matei's suggestion --- core/src/main/scala/spark/rdd/ZippedRDD.scala | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index e80250a99b..51573fe68a 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} @@ -49,9 +49,20 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( override def getPreferredLocations(s: Partition): Seq[String] = { val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions - // TODO: becomes complicated - intersect on hostPort if available, else fallback to host (removing intersected hostPort's). - // Since I am not very sure about this RDD, leaving it to others to comment better ! - rdd1.preferredLocations(partition1).intersect(rdd2.preferredLocations(partition2)) + val pref1 = rdd1.preferredLocations(partition1) + val pref2 = rdd2.preferredLocations(partition2) + + // both partitions are instance local. + val instanceLocalLocations = pref1.intersect(pref2) + + // remove locations which are already handled via instanceLocalLocations, and intersect where both partitions are node local. + val nodeLocalPref1 = pref1.filter(loc => ! instanceLocalLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) + val nodeLocalPref2 = pref2.filter(loc => ! instanceLocalLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) + val nodeLocalLocations = nodeLocalPref1.intersect(nodeLocalPref2) + + + // Can have mix of instance local (hostPort) and node local (host) locations as preference ! + instanceLocalLocations ++ nodeLocalLocations } override def clearDependencies() { -- cgit v1.2.3 From 2bc895a829caa459e032e12e1d117994dd510a5c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 3 May 2013 01:02:16 -0700 Subject: Updated according to Matei's code review comment. --- core/src/main/scala/spark/ShuffleFetcher.scala | 2 +- core/src/main/scala/spark/SparkEnv.scala | 10 +++-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 3 +- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 5 +-- core/src/main/scala/spark/rdd/SubtractedRDD.scala | 4 +- .../scala/spark/scheduler/ShuffleMapTask.scala | 5 +-- .../main/scala/spark/serializer/Serializer.scala | 42 -------------------- .../scala/spark/serializer/SerializerManager.scala | 45 ++++++++++++++++++++++ core/src/main/scala/spark/storage/DiskStore.scala | 36 +++++++++-------- .../scala/spark/storage/ShuffleBlockManager.scala | 34 ++++++++-------- 10 files changed, 98 insertions(+), 88 deletions(-) create mode 100644 core/src/main/scala/spark/serializer/SerializerManager.scala diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala index 49addc0c10..9513a00126 100644 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/spark/ShuffleFetcher.scala @@ -10,7 +10,7 @@ private[spark] abstract class ShuffleFetcher { * @return An iterator over the elements of the fetched shuffle outputs. */ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, - serializer: Serializer = Serializer.default): Iterator[(K,V)] + serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)] /** Stop the fetcher */ def stop() {} diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 8ba52245fa..2fa97cd829 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -7,7 +7,7 @@ import spark.broadcast.BroadcastManager import spark.storage.BlockManager import spark.storage.BlockManagerMaster import spark.network.ConnectionManager -import spark.serializer.Serializer +import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils @@ -21,6 +21,7 @@ import spark.util.AkkaUtils class SparkEnv ( val executorId: String, val actorSystem: ActorSystem, + val serializerManager: SerializerManager, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -92,10 +93,12 @@ object SparkEnv extends Logging { Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializer = Serializer.setDefault( + val serializerManager = new SerializerManager + + val serializer = serializerManager.setDefault( System.getProperty("spark.serializer", "spark.JavaSerializer")) - val closureSerializer = Serializer.get( + val closureSerializer = serializerManager.get( System.getProperty("spark.closure.serializer", "spark.JavaSerializer")) def registerOrLookup(name: String, newActor: => Actor): ActorRef = { @@ -155,6 +158,7 @@ object SparkEnv extends Logging { new SparkEnv( executorId, actorSystem, + serializerManager, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 9e996e9958..7599ba1a02 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -8,7 +8,6 @@ import scala.collection.mutable.ArrayBuffer import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} -import spark.serializer.Serializer private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -114,7 +113,7 @@ class CoGroupedRDD[K]( } } - val ser = Serializer.get(serializerClass) + val ser = SparkEnv.get.serializerManager.get(serializerClass) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 8175e23eff..c7d1926b83 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -2,7 +2,6 @@ package spark.rdd import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} import spark.SparkContext._ -import spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -32,7 +31,7 @@ class ShuffledRDD[K, V]( override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[K, V]( - shuffledId, split.index, context.taskMetrics, Serializer.get(serializerClass)) + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics, + SparkEnv.get.serializerManager.get(serializerClass)) } } diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala index f60c35c38e..8a9efc5da2 100644 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala @@ -11,7 +11,7 @@ import spark.Partition import spark.SparkEnv import spark.ShuffleDependency import spark.OneToOneDependency -import spark.serializer.Serializer + /** * An optimized version of cogroup for set difference/subtraction. @@ -68,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val serializer = Serializer.get(serializerClass) + val serializer = SparkEnv.get.serializerManager.get(serializerClass) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 124d2d7e26..f097213ab5 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,7 +14,6 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.executor.ShuffleWriteMetrics -import spark.serializer.Serializer import spark.storage._ import spark.util.{TimeStampedHashMap, MetadataCleaner} @@ -139,12 +138,12 @@ private[spark] class ShuffleMapTask( metrics = Some(taskContext.taskMetrics) val blockManager = SparkEnv.get.blockManager - var shuffle: ShuffleBlockManager#Shuffle = null + var shuffle: ShuffleBlocks = null var buckets: ShuffleWriterGroup = null try { // Obtain all the block writers for shuffle blocks. - val ser = Serializer.get(dep.serializerClass) + val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) buckets = shuffle.acquireWriters(partition) diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala index 77b1a1a434..2ad73b711d 100644 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ b/core/src/main/scala/spark/serializer/Serializer.scala @@ -2,7 +2,6 @@ package spark.serializer import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream @@ -19,47 +18,6 @@ trait Serializer { } -/** - * A singleton object that can be used to fetch serializer objects based on the serializer - * class name. If a previous instance of the serializer object has been created, the get - * method returns that instead of creating a new one. - */ -object Serializer { - - private val serializers = new ConcurrentHashMap[String, Serializer] - private var _default: Serializer = _ - - def default = _default - - def setDefault(clsName: String): Serializer = { - _default = get(clsName) - _default - } - - def get(clsName: String): Serializer = { - if (clsName == null) { - default - } else { - var serializer = serializers.get(clsName) - if (serializer != null) { - // If the serializer has been created previously, reuse that. - serializer - } else this.synchronized { - // Otherwise, create a new one. But make sure no other thread has attempted - // to create another new one at the same time. - serializer = serializers.get(clsName) - if (serializer == null) { - val clsLoader = Thread.currentThread.getContextClassLoader - serializer = Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] - serializers.put(clsName, serializer) - } - serializer - } - } - } -} - - /** * An instance of a serializer, for use by one thread at a time. */ diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala new file mode 100644 index 0000000000..60b2aac797 --- /dev/null +++ b/core/src/main/scala/spark/serializer/SerializerManager.scala @@ -0,0 +1,45 @@ +package spark.serializer + +import java.util.concurrent.ConcurrentHashMap + + +/** + * A service that returns a serializer object given the serializer's class name. If a previous + * instance of the serializer object has been created, the get method returns that instead of + * creating a new one. + */ +private[spark] class SerializerManager { + + private val serializers = new ConcurrentHashMap[String, Serializer] + private var _default: Serializer = _ + + def default = _default + + def setDefault(clsName: String): Serializer = { + _default = get(clsName) + _default + } + + def get(clsName: String): Serializer = { + if (clsName == null) { + default + } else { + var serializer = serializers.get(clsName) + if (serializer != null) { + // If the serializer has been created previously, reuse that. + serializer + } else this.synchronized { + // Otherwise, create a new one. But make sure no other thread has attempted + // to create another new one at the same time. + serializer = serializers.get(clsName) + if (serializer == null) { + val clsLoader = Thread.currentThread.getContextClassLoader + serializer = + Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + serializers.put(clsName, serializer) + } + serializer + } + } + } +} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 4cddcc86fc..498bc9eeb6 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -2,6 +2,7 @@ package spark.storage import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} import java.nio.ByteBuffer +import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode import java.util.{Random, Date} import java.text.SimpleDateFormat @@ -26,14 +27,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private val f: File = createFile(blockId /*, allowAppendExisting */) - private var repositionableStream: FastBufferedOutputStream = null + // The file channel, used for repositioning / truncating the file. + private var channel: FileChannel = null private var bs: OutputStream = null private var objOut: SerializationStream = null - private var validLength = 0L + private var lastValidPosition = 0L override def open(): DiskBlockObjectWriter = { - repositionableStream = new FastBufferedOutputStream(new FileOutputStream(f)) - bs = blockManager.wrapForCompression(blockId, repositionableStream) + val fos = new FileOutputStream(f, true) + channel = fos.getChannel() + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) objOut = serializer.newInstance().serializeStream(bs) this } @@ -41,9 +44,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def close() { objOut.close() bs.close() - objOut = null + channel = null bs = null - repositionableStream = null + objOut = null // Invoke the close callback handler. super.close() } @@ -54,25 +57,23 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Return the number of bytes written for this commit. override def commit(): Long = { bs.flush() - validLength = repositionableStream.position() - validLength + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos } override def revertPartialWrites() { - // Flush the outstanding writes and delete the file. - objOut.close() - bs.close() - objOut = null - bs = null - repositionableStream = null - f.delete() + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + bs.flush() + channel.truncate(lastValidPosition) } override def write(value: Any) { objOut.writeObject(value) } - override def size(): Long = validLength + override def size(): Long = lastValidPosition } val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @@ -86,7 +87,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) addShutdownHook() - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int): BlockObjectWriter = { + def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + : BlockObjectWriter = { new DiskBlockObjectWriter(blockId, serializer, bufferSize) } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 1903df0817..49eabfb0d2 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -8,26 +8,30 @@ class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) private[spark] -class ShuffleBlockManager(blockManager: BlockManager) { +trait ShuffleBlocks { + def acquireWriters(mapId: Int): ShuffleWriterGroup + def releaseWriters(group: ShuffleWriterGroup) +} - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): Shuffle = { - new Shuffle(shuffleId, numBuckets, serializer) - } - class Shuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) { +private[spark] +class ShuffleBlockManager(blockManager: BlockManager) { - // Get a group of writers for a map task. - def acquireWriters(mapId: Int): ShuffleWriterGroup = { - val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 - val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { + new ShuffleBlocks { + // Get a group of writers for a map task. + override def acquireWriters(mapId: Int): ShuffleWriterGroup = { + val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + } + new ShuffleWriterGroup(mapId, writers) } - new ShuffleWriterGroup(mapId, writers) - } - def releaseWriters(group: ShuffleWriterGroup) = { - // Nothing really to release here. + override def releaseWriters(group: ShuffleWriterGroup) = { + // Nothing really to release here. + } } } } -- cgit v1.2.3 From 6fae936088d2a50606ba5082cee4a3c3a98a2b01 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 2 May 2013 22:30:06 -0700 Subject: applications (aka drivers) send their webUI address to master when registering so it can be displayed in the master web ui --- .../scala/spark/deploy/ApplicationDescription.scala | 4 +++- .../src/main/scala/spark/deploy/client/TestClient.scala | 2 +- .../scala/spark/deploy/master/ApplicationInfo.scala | 8 +++++++- core/src/main/scala/spark/deploy/master/Master.scala | 2 +- .../scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- core/src/main/scala/spark/storage/BlockManagerUI.scala | 17 +++++++++-------- .../twirl/spark/deploy/master/app_details.scala.html | 1 + 7 files changed, 24 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala index 6659e53b25..bb9e7b3bba 100644 --- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala @@ -5,7 +5,9 @@ private[spark] class ApplicationDescription( val cores: Int, val memoryPerSlave: Int, val command: Command, - val sparkHome: String) + val sparkHome: String, + val appUIHost: String, + val appUIPort: Int) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index ad92532b58..e4ab01dd2a 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -25,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "localhost", 0) val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala index 3591a94072..3ee1b60351 100644 --- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala @@ -10,7 +10,9 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef) + val driver: ActorRef, + val appUIHost: String, + val appUIPort: Int) { var state = ApplicationState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] @@ -60,4 +62,8 @@ private[spark] class ApplicationInfo( System.currentTimeMillis() - startTime } } + + + def appUIAddress = "http://" + this.appUIHost + ":" + this.appUIPort + } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 160afe5239..9f2d3da495 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -244,7 +244,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver) + val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUIHost, desc.appUIPort) apps += app idToApp(app.id) = app actorToApp(driver) = app diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 0b8922d139..5d7d1feb74 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -31,7 +31,8 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) - val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome) + val appDesc = + new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, sc.ui.host, sc.ui.port) client = new Client(sc.env.actorSystem, master, appDesc, this) client.start() diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 07da572044..13158e4262 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -20,19 +20,20 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, val STATIC_RESOURCE_DIR = "spark/deploy/static" implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val host = Utils.localHostName() + val port = if (System.getProperty("spark.ui.port") != null) { + System.getProperty("spark.ui.port").toInt + } else { + // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which + // random port it bound to, so we have to try to find a local one by creating a socket. + Utils.findFreePort() + } /** Start a HTTP server to run the Web interface */ def start() { try { - val port = if (System.getProperty("spark.ui.port") != null) { - System.getProperty("spark.ui.port").toInt - } else { - // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which - // random port it bound to, so we have to try to find a local one by creating a socket. - Utils.findFreePort() - } AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer") - logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port)) + logInfo("Started BlockManager web UI at http://%s:%d".format(host, port)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html index 301a7e2124..02086b476f 100644 --- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html @@ -22,6 +22,7 @@
  • Memory per Slave: @app.desc.memoryPerSlave
  • Submit Date: @app.submitDate
  • State: @app.state
  • +
  • Application Detail UI
  • -- cgit v1.2.3 From bb8a434f9db177d764d169a7273c66ed01c066c1 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 3 May 2013 15:14:02 -0700 Subject: Add zipPartitions to Java API. --- .../main/scala/spark/api/java/JavaRDDLike.scala | 15 +++++++++++++ .../spark/api/java/function/FlatMapFunction2.scala | 11 +++++++++ core/src/test/scala/spark/JavaAPISuite.java | 26 ++++++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index d884529d7a..9b74d1226f 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -182,6 +182,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest) } + /** + * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by + * applying a function to the zipped partitions. Assumes that all the RDDs have the + * *same number of partitions*, but does *not* require them to have the same number + * of elements in each partition. + */ + def zipPartitions[U, V]( + f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V], + other: JavaRDDLike[U, _]): JavaRDD[V] = { + def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( + f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) + JavaRDD.fromRDD( + rdd.zipPartitions(fn, other.rdd)(other.classManifest, f.elementType()))(f.elementType()) + } + // Actions (launch a job to return a value to the user program) /** diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala new file mode 100644 index 0000000000..6044043add --- /dev/null +++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala @@ -0,0 +1,11 @@ +package spark.api.java.function + +/** + * A function that takes two inputs and returns zero or more output records. + */ +abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { + @throws(classOf[Exception]) + def call(a: A, b:B) : java.lang.Iterable[C] + + def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] +} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index d3dcd3bbeb..93bb69b41c 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -632,6 +632,32 @@ public class JavaAPISuite implements Serializable { zipped.count(); } + @Test + public void zipPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); + FlatMapFunction2, Iterator, Integer> sizesFn = + new FlatMapFunction2, Iterator, Integer>() { + @Override + public Iterable call(Iterator i, Iterator s) { + int sizeI = 0; + int sizeS = 0; + while (i.hasNext()) { + sizeI += 1; + i.next(); + } + while (s.hasNext()) { + sizeS += 1; + s.next(); + } + return Arrays.asList(sizeI, sizeS); + } + }; + + JavaRDD sizes = rdd1.zipPartitions(sizesFn, rdd2); + Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + } + @Test public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); -- cgit v1.2.3 From 2274ad0786b758e3170e96815e9693ea27635f06 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 3 May 2013 16:29:36 -0700 Subject: Fix flaky test by changing catch and adding sleep --- core/src/test/scala/spark/DistributedSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index ab3e197035..a13c88cfb4 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -261,9 +261,9 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) data.count - assert(sc.persistentRdds.isEmpty == false) + assert(sc.persistentRdds.isEmpty === false) data.unpersist() - assert(sc.persistentRdds.isEmpty == true) + assert(sc.persistentRdds.isEmpty === true) failAfter(Span(3000, Millis)) { try { @@ -271,12 +271,12 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter Thread.sleep(200) } } catch { - case e: Exception => + case _ => { Thread.sleep(10) } // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } } - assert(sc.getRDDStorageInfo.isEmpty == true) + assert(sc.getRDDStorageInfo.isEmpty === true) } } -- cgit v1.2.3 From edb57c8331738403d66c15ed99996e8bfb0488f7 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sat, 4 May 2013 19:47:45 +0530 Subject: Add support for instance local in getPreferredLocations of ZippedPartitionsBaseRDD. Add comments to both ZippedPartitionsBaseRDD and ZippedRDD to better describe the potential problem with the approach --- .../main/scala/spark/rdd/ZippedPartitionsRDD.scala | 28 +++++++++++++++++++--- core/src/main/scala/spark/rdd/ZippedRDD.scala | 20 ++++++++++------ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index fc3f29ffcd..dd9f3c2680 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} +import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} private[spark] class ZippedPartitionsPartition( @@ -38,9 +38,31 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( } override def getPreferredLocations(s: Partition): Seq[String] = { + // Note that as number of rdd's increase and/or number of slaves in cluster increase, the computed preferredLocations below + // become diminishingly small : so we might need to look at alternate strategies to alleviate this. + // If there are no (or very small number of preferred locations), we will end up transferred the blocks to 'any' node in the + // cluster - paying with n/w and cache cost. + // Maybe pick a node which figures max amount of time ? + // Choose node which is hosting 'larger' of some subset of blocks ? + // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible) val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions - val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2)) - preferredLocations.reduce((x, y) => x.intersect(y)) + val rddSplitZip = rdds.zip(splits) + + // exact match. + val exactMatchPreferredLocations = rddSplitZip.map(x => x._1.preferredLocations(x._2)) + val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y)) + + // Remove exact match and then do host local match. + val otherNodePreferredLocations = rddSplitZip.map(x => { + x._1.preferredLocations(x._2).map(hostPort => { + val host = Utils.parseHostPort(hostPort)._1 + + if (exactMatchLocations.contains(host)) null else host + }).filter(_ != null) + }) + val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y)) + + otherNodeLocalLocations ++ exactMatchLocations } override def clearDependencies() { diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 51573fe68a..f728e93d24 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -48,21 +48,27 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( } override def getPreferredLocations(s: Partition): Seq[String] = { + // Note that as number of slaves in cluster increase, the computed preferredLocations can become small : so we might need + // to look at alternate strategies to alleviate this. (If there are no (or very small number of preferred locations), we + // will end up transferred the blocks to 'any' node in the cluster - paying with n/w and cache cost. + // Maybe pick one or the other ? (so that atleast one block is local ?). + // Choose node which is hosting 'larger' of the blocks ? + // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible) val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions val pref1 = rdd1.preferredLocations(partition1) val pref2 = rdd2.preferredLocations(partition2) - // both partitions are instance local. - val instanceLocalLocations = pref1.intersect(pref2) + // exact match - instance local and host local. + val exactMatchLocations = pref1.intersect(pref2) - // remove locations which are already handled via instanceLocalLocations, and intersect where both partitions are node local. - val nodeLocalPref1 = pref1.filter(loc => ! instanceLocalLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) - val nodeLocalPref2 = pref2.filter(loc => ! instanceLocalLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) - val nodeLocalLocations = nodeLocalPref1.intersect(nodeLocalPref2) + // remove locations which are already handled via exactMatchLocations, and intersect where both partitions are node local. + val otherNodeLocalPref1 = pref1.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) + val otherNodeLocalPref2 = pref2.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) + val otherNodeLocalLocations = otherNodeLocalPref1.intersect(otherNodeLocalPref2) // Can have mix of instance local (hostPort) and node local (host) locations as preference ! - instanceLocalLocations ++ nodeLocalLocations + exactMatchLocations ++ otherNodeLocalLocations } override def clearDependencies() { -- cgit v1.2.3 From 02e8cfa61792f296555c7ed16613a91d895181a1 Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Sat, 4 May 2013 12:31:30 -0500 Subject: HBase example --- .../src/main/scala/spark/examples/HBaseTest.scala | 34 ++++++++++++++++++++++ project/SparkBuild.scala | 6 +++- 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/spark/examples/HBaseTest.scala diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala new file mode 100644 index 0000000000..90ff64b483 --- /dev/null +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -0,0 +1,34 @@ +package spark.examples + +import spark._ +import spark.rdd.NewHadoopRDD +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, HColumnDescriptor} +import org.apache.hadoop.hbase.client.HBaseAdmin +import org.apache.hadoop.hbase.mapreduce.TableInputFormat + +object HBaseTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HBaseTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val conf = HBaseConfiguration.create() + conf.set(TableInputFormat.INPUT_TABLE, args(1)) + + // Initialize hBase tables if necessary + val admin = new HBaseAdmin(conf) + if(!admin.isTableAvailable(args(1))) { + val colDesc = new HColumnDescriptor(args(2)) + val tableDesc = new HTableDescriptor(args(1)) + tableDesc.addFamily(colDesc) + admin.createTable(tableDesc) + } + + val hBaseRDD = new NewHadoopRDD(sc, classOf[TableInputFormat], + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result], conf) + + hBaseRDD.count() + + System.exit(0) + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 190d723435..6f5607d31c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -200,7 +200,11 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") + resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"), + libraryDependencies ++= Seq( + "com.twitter" % "algebird-core_2.9.2" % "0.1.11", + "org.apache.hbase" % "hbase" % "0.94.6" + ) ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From 9290f16430f92c66d4ec3b1ec76e491ae7cf26dc Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Sat, 4 May 2013 12:39:14 -0500 Subject: Remove unnecessary column family config --- examples/src/main/scala/spark/examples/HBaseTest.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala index 90ff64b483..37aedde302 100644 --- a/examples/src/main/scala/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -2,7 +2,7 @@ package spark.examples import spark._ import spark.rdd.NewHadoopRDD -import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, HColumnDescriptor} +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.mapreduce.TableInputFormat @@ -14,12 +14,10 @@ object HBaseTest { val conf = HBaseConfiguration.create() conf.set(TableInputFormat.INPUT_TABLE, args(1)) - // Initialize hBase tables if necessary + // Initialize hBase table if necessary val admin = new HBaseAdmin(conf) if(!admin.isTableAvailable(args(1))) { - val colDesc = new HColumnDescriptor(args(2)) val tableDesc = new HTableDescriptor(args(1)) - tableDesc.addFamily(colDesc) admin.createTable(tableDesc) } -- cgit v1.2.3 From 7cff7e789723b646d1692fc71ef99f89a862bdc6 Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Sat, 4 May 2013 14:56:55 -0500 Subject: Fix indents and mention other configuration options --- examples/src/main/scala/spark/examples/HBaseTest.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala index 37aedde302..d94b25828d 100644 --- a/examples/src/main/scala/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -12,6 +12,9 @@ object HBaseTest { System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) val conf = HBaseConfiguration.create() + + // Other options for configuring scan behavior are available. More information available at + // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html conf.set(TableInputFormat.INPUT_TABLE, args(1)) // Initialize hBase table if necessary @@ -22,8 +25,8 @@ object HBaseTest { } val hBaseRDD = new NewHadoopRDD(sc, classOf[TableInputFormat], - classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], - classOf[org.apache.hadoop.hbase.client.Result], conf) + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result], conf) hBaseRDD.count() -- cgit v1.2.3 From d48e9fde01cec2a7db794edf4cbe66c2228531aa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 4 May 2013 12:36:47 -0700 Subject: Fix SPARK-629: weird number of cores in job details page. --- core/src/main/scala/spark/deploy/ApplicationDescription.scala | 2 +- core/src/main/scala/spark/deploy/JsonProtocol.scala | 4 ++-- core/src/main/scala/spark/deploy/master/ApplicationInfo.scala | 2 +- .../src/main/twirl/spark/deploy/master/app_details.scala.html | 11 ++++------- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala index 6659e53b25..b6b9f9bf9d 100644 --- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala @@ -2,7 +2,7 @@ package spark.deploy private[spark] class ApplicationDescription( val name: String, - val cores: Int, + val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */ val memoryPerSlave: Int, val command: Command, val sparkHome: String) diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index 71a641a9ef..ea832101d2 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -26,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "starttime" -> JsNumber(obj.startTime), "id" -> JsString(obj.id), "name" -> JsString(obj.desc.name), - "cores" -> JsNumber(obj.desc.cores), + "cores" -> JsNumber(obj.desc.maxCores), "user" -> JsString(obj.desc.user), "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave), "submitdate" -> JsString(obj.submitDate.toString)) @@ -35,7 +35,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] { def write(obj: ApplicationDescription) = JsObject( "name" -> JsString(obj.name), - "cores" -> JsNumber(obj.cores), + "cores" -> JsNumber(obj.maxCores), "memoryperslave" -> JsNumber(obj.memoryPerSlave), "user" -> JsString(obj.user) ) diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala index 3591a94072..70e5caab66 100644 --- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala @@ -37,7 +37,7 @@ private[spark] class ApplicationInfo( coresGranted -= exec.cores } - def coresLeft: Int = desc.cores - coresGranted + def coresLeft: Int = desc.maxCores - coresGranted private var _retryCount = 0 diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html index 301a7e2124..66147e213f 100644 --- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html @@ -9,15 +9,12 @@
  • ID: @app.id
  • Description: @app.desc.name
  • User: @app.desc.user
  • -
  • Cores: - @app.desc.cores - (@app.coresGranted Granted - @if(app.desc.cores == Integer.MAX_VALUE) { - +
  • Cores: + @if(app.desc.maxCores == Integer.MAX_VALUE) { + Unlimited (@app.coresGranted granted) } else { - , @app.coresLeft + @app.desc.maxCores (@app.coresGranted granted, @app.coresLeft left) } - )
  • Memory per Slave: @app.desc.memoryPerSlave
  • Submit Date: @app.submitDate
  • -- cgit v1.2.3 From c0688451a6a91f596d9c596383026ddbdcbb8bb0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 4 May 2013 12:37:04 -0700 Subject: Fix wrong closing tags in web UI HTML. --- core/src/main/twirl/spark/deploy/master/executor_row.scala.html | 2 +- core/src/main/twirl/spark/deploy/master/worker_row.scala.html | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html index d2d80fad48..21e72c7aab 100644 --- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html @@ -3,7 +3,7 @@ @executor.id - @executor.worker.id + @executor.worker.id @executor.cores @executor.memory diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html index be69e9bf02..46277ca421 100644 --- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html @@ -4,7 +4,7 @@ - @worker.id + @worker.id @{worker.host}:@{worker.port} @worker.state -- cgit v1.2.3 From 42b1953c5385135467b87e79266a0f6c23d63e7b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 4 May 2013 13:23:31 -0700 Subject: Fix SPARK-630: app details page shows finished executors as running. --- core/src/main/scala/spark/deploy/master/Master.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 160afe5239..707fe57983 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -275,6 +275,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) exec.worker.actor ! KillExecutor(exec.application.id, exec.id) + exec.state = ExecutorState.KILLED } app.markFinished(state) app.driver ! ApplicationRemoved(state.toString) -- cgit v1.2.3 From 0a2bed356b9ea604d317be97a6747588c5af29e4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 4 May 2013 21:50:08 -0700 Subject: Fixed flaky unpersist test in DistributedSuite. --- core/src/test/scala/spark/DistributedSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index a13c88cfb4..4df3bb5b67 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -276,7 +276,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter // is racing this thread to remove entries from the driver. } } - assert(sc.getRDDStorageInfo.isEmpty === true) } } -- cgit v1.2.3 From e014c1d1cb4bd1037dc674ef474d0197267b399b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 5 May 2013 11:30:36 -0700 Subject: Fix SPARK-670: EC2 start command should require -i option. --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 9f2daad2b6..7affe6fffc 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -103,7 +103,7 @@ def parse_args(): parser.print_help() sys.exit(1) (action, cluster_name) = args - if opts.identity_file == None and action in ['launch', 'login']: + if opts.identity_file == None and action in ['launch', 'login', 'start']: print >> stderr, ("ERROR: The -i or --identity-file argument is " + "required for " + action) sys.exit(1) -- cgit v1.2.3 From 22a5063ae45300cdd759f51877dac850008e16ee Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 5 May 2013 12:19:11 -0700 Subject: switch from separating appUI host & port to combining into just appUiUrl --- core/src/main/scala/spark/deploy/ApplicationDescription.scala | 3 +-- core/src/main/scala/spark/deploy/client/TestClient.scala | 2 +- core/src/main/scala/spark/deploy/master/ApplicationInfo.scala | 6 +----- core/src/main/scala/spark/deploy/master/Master.scala | 2 +- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- core/src/main/scala/spark/storage/BlockManagerUI.scala | 2 ++ core/src/main/twirl/spark/deploy/master/app_details.scala.html | 2 +- 7 files changed, 8 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala index bb9e7b3bba..4aff0aedc1 100644 --- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala @@ -6,8 +6,7 @@ private[spark] class ApplicationDescription( val memoryPerSlave: Int, val command: Command, val sparkHome: String, - val appUIHost: String, - val appUIPort: Int) + val appUiUrl: String) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index e4ab01dd2a..f195082808 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -25,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "localhost", 0) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala index 3ee1b60351..e28b007e30 100644 --- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala @@ -11,8 +11,7 @@ private[spark] class ApplicationInfo( val desc: ApplicationDescription, val submitDate: Date, val driver: ActorRef, - val appUIHost: String, - val appUIPort: Int) + val appUiUrl: String) { var state = ApplicationState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] @@ -63,7 +62,4 @@ private[spark] class ApplicationInfo( } } - - def appUIAddress = "http://" + this.appUIHost + ":" + this.appUIPort - } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 9f2d3da495..6f58ad16af 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -244,7 +244,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUIHost, desc.appUIPort) + val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) apps += app idToApp(app.id) = app actorToApp(driver) = app diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 5d7d1feb74..955ee5d806 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -32,7 +32,7 @@ private[spark] class SparkDeploySchedulerBackend( val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) val appDesc = - new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, sc.ui.host, sc.ui.port) + new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, sc.ui.appUIAddress) client = new Client(sc.env.actorSystem, master, appDesc, this) client.start() diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 13158e4262..e02281344a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -74,4 +74,6 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, } } } + + private[spark] def appUIAddress = "http://" + host + ":" + port } diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html index 02086b476f..15eabc9834 100644 --- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html @@ -22,7 +22,7 @@
  • Memory per Slave: @app.desc.memoryPerSlave
  • Submit Date: @app.submitDate
  • State: @app.state
  • -
  • Application Detail UI
  • +
  • Application Detail UI
  • -- cgit v1.2.3 From cbf6a5ee1e7d290d04a0c5dac78d360266d415a4 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 6 May 2013 08:05:45 -0600 Subject: Removed unused code, clarified intent of the program, batch size to 1 second --- .../scala/spark/streaming/examples/StatefulNetworkWordCount.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala index b662cb1162..51c3c9f9b4 100644 --- a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala @@ -4,7 +4,7 @@ import spark.streaming._ import spark.streaming.StreamingContext._ /** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second. * Usage: StatefulNetworkWordCount * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. * and describe the TCP server that Spark Streaming would connect to receive data. @@ -15,8 +15,6 @@ import spark.streaming.StreamingContext._ * `$ ./run spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999` */ object StatefulNetworkWordCount { - private def className[A](a: A)(implicit m: Manifest[A]) = m.toString - def main(args: Array[String]) { if (args.length < 3) { System.err.println("Usage: StatefulNetworkWordCount \n" + @@ -32,8 +30,8 @@ object StatefulNetworkWordCount { Some(currentCount + previousCount) } - // Create the context with a 10 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(10), + // Create the context with a 1 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) ssc.checkpoint(".") -- cgit v1.2.3 From 0fd84965f66aa37d2ae14da799b86a5c8ed1cb32 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 6 May 2013 15:35:18 -0700 Subject: Added EmptyRDD. --- core/src/main/scala/spark/rdd/EmptyRDD.scala | 16 ++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 14 +++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/spark/rdd/EmptyRDD.scala diff --git a/core/src/main/scala/spark/rdd/EmptyRDD.scala b/core/src/main/scala/spark/rdd/EmptyRDD.scala new file mode 100644 index 0000000000..e4dd3a7fa7 --- /dev/null +++ b/core/src/main/scala/spark/rdd/EmptyRDD.scala @@ -0,0 +1,16 @@ +package spark.rdd + +import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} + + +/** + * An RDD that is empty, i.e. has no element in it. + */ +class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { + + override def getPartitions: Array[Partition] = Array.empty + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + throw new UnsupportedOperationException("empty RDD") + } +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index cee6312572..2ce757b13c 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -5,7 +5,7 @@ import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.{Span, Millis} import spark.SparkContext._ -import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD} +import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD} class RDDSuite extends FunSuite with LocalSparkContext { @@ -147,6 +147,18 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(rdd.collect().toList === List(1, 2, 3, 4)) } + test("empty RDD") { + sc = new SparkContext("local", "test") + val empty = new EmptyRDD[Int](sc) + assert(empty.count === 0) + assert(empty.collect().size === 0) + + val thrown = intercept[UnsupportedOperationException]{ + empty.reduce(_+_) + } + assert(thrown.getMessage.contains("empty")) + } + test("cogrouped RDDs") { sc = new SparkContext("local", "test") val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2) -- cgit v1.2.3 From 64d4d2b036447f42bfcd3bac5687c79a3b0661ca Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 6 May 2013 16:30:46 -0700 Subject: Added tests for joins, cogroups, and unions for EmptyRDD. --- core/src/test/scala/spark/RDDSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 2ce757b13c..a761dd77c5 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -157,6 +157,14 @@ class RDDSuite extends FunSuite with LocalSparkContext { empty.reduce(_+_) } assert(thrown.getMessage.contains("empty")) + + val emptyKv = new EmptyRDD[(Int, Int)](sc) + val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) + assert(rdd.join(emptyKv).collect().size === 0) + assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) + assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.cogroup(emptyKv).collect().size === 2) + assert(rdd.union(emptyKv).collect().size === 2) } test("cogrouped RDDs") { -- cgit v1.2.3 From 4d8919d33056a006ebf6b1ddb0509aeccaa828d7 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 18 Apr 2013 14:58:38 -0700 Subject: Update Maven build to Scala 2.9.3 --- core/pom.xml | 4 ++-- pom.xml | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index da26d674ec..9a019b5a42 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -73,7 +73,7 @@ cc.spray - spray-json_${scala.version} + spray-json_2.9.2 org.tomdz.twirl @@ -81,7 +81,7 @@ com.github.scala-incubator.io - scala-io-file_${scala.version} + scala-io-file_2.9.2 org.apache.mesos diff --git a/pom.xml b/pom.xml index c3323ffad0..3936165d78 100644 --- a/pom.xml +++ b/pom.xml @@ -51,7 +51,7 @@ UTF-8 1.5 - 2.9.2 + 2.9.3 0.9.0-incubating 2.0.3 1.0-M2.1 @@ -238,7 +238,7 @@ cc.spray - spray-json_${scala.version} + spray-json_2.9.2 ${spray.json.version} @@ -248,7 +248,7 @@ com.github.scala-incubator.io - scala-io-file_${scala.version} + scala-io-file_2.9.2 0.4.1 @@ -277,7 +277,7 @@ org.scalatest scalatest_${scala.version} - 1.8 + 1.9.1 test @@ -289,7 +289,7 @@ org.scalacheck scalacheck_${scala.version} - 1.9 + 1.10.0 test @@ -513,7 +513,6 @@ hadoop1 - 1 -- cgit v1.2.3 From a3d5f922109caa878f8350fe0634514b8af55cbc Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Tue, 7 May 2013 11:43:06 -0500 Subject: Switch to using SparkContext method to create RDD --- examples/src/main/scala/spark/examples/HBaseTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala index d94b25828d..9bad876860 100644 --- a/examples/src/main/scala/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -24,9 +24,9 @@ object HBaseTest { admin.createTable(tableDesc) } - val hBaseRDD = new NewHadoopRDD(sc, classOf[TableInputFormat], + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], - classOf[org.apache.hadoop.hbase.client.Result], conf) + classOf[org.apache.hadoop.hbase.client.Result]) hBaseRDD.count() -- cgit v1.2.3 From aacca1b8a85bd073ce185a06d6470b070761b2f4 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 18 Apr 2013 14:58:38 -0700 Subject: Update Maven build to Scala 2.9.3 --- core/pom.xml | 4 ++-- pom.xml | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index da26d674ec..9a019b5a42 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -73,7 +73,7 @@ cc.spray - spray-json_${scala.version} + spray-json_2.9.2 org.tomdz.twirl @@ -81,7 +81,7 @@ com.github.scala-incubator.io - scala-io-file_${scala.version} + scala-io-file_2.9.2 org.apache.mesos diff --git a/pom.xml b/pom.xml index c3323ffad0..3936165d78 100644 --- a/pom.xml +++ b/pom.xml @@ -51,7 +51,7 @@ UTF-8 1.5 - 2.9.2 + 2.9.3 0.9.0-incubating 2.0.3 1.0-M2.1 @@ -238,7 +238,7 @@ cc.spray - spray-json_${scala.version} + spray-json_2.9.2 ${spray.json.version} @@ -248,7 +248,7 @@ com.github.scala-incubator.io - scala-io-file_${scala.version} + scala-io-file_2.9.2 0.4.1 @@ -277,7 +277,7 @@ org.scalatest scalatest_${scala.version} - 1.8 + 1.9.1 test @@ -289,7 +289,7 @@ org.scalacheck scalacheck_${scala.version} - 1.9 + 1.10.0 test @@ -513,7 +513,6 @@ hadoop1 - 1 -- cgit v1.2.3 From 8b7948517182ced5b3681dfc668732364ebccc38 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 17:02:32 -0700 Subject: Moved BlockFetcherIterator to its own file. --- .../scala/spark/storage/BlockFetcherIterator.scala | 361 +++++++++++++++++++++ .../main/scala/spark/storage/BlockManager.scala | 340 +------------------ 2 files changed, 362 insertions(+), 339 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockFetcherIterator.scala diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala new file mode 100644 index 0000000000..30990d9a38 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -0,0 +1,361 @@ +package spark.storage + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import spark.Logging +import spark.Utils +import spark.SparkException + +import spark.network.BufferMessage +import spark.network.ConnectionManagerId +import spark.network.netty.ShuffleCopier + +import spark.serializer.Serializer +import io.netty.buffer.ByteBuf + + +trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] + with Logging with BlockFetchTracker { + + def initialize() + +} + + + +object BlockFetcherIterator { + + // A request to fetch one or more blocks, complete with their sizes + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } + + class BasicBlockFetcherIterator( + private val blockManager: BlockManager, + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer + ) extends BlockFetcherIterator { + + import blockManager._ + + private var _remoteBytesRead = 0l + private var _remoteFetchTime = 0l + private var _fetchWaitTime = 0l + + if (blocksByAddress == null) { + throw new IllegalArgumentException("BlocksByAddress is null") + } + var totalBlocks = blocksByAddress.map(_._2.size).sum + logDebug("Getting " + totalBlocks + " blocks") + var startTime = System.currentTimeMillis + val localBlockIds = new ArrayBuffer[String]() + val remoteBlockIds = new HashSet[String]() + + // A queue to hold our results. + val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + var bytesInFlight = 0L + + def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) + val blockMessageArray = new BlockMessageArray(req.blocks.map { + case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) + }) + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + val fetchStart = System.currentTimeMillis() + val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + future.onSuccess { + case Some(message) => { + val fetchDone = System.currentTimeMillis() + _remoteFetchTime += fetchDone - fetchStart + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) + } + val blockId = blockMessage.getId + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += req.size + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + } + case None => { + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + } + } + + def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + curBlocks += ((blockId, size)) + curRequestSize += size + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + remoteRequests + } + + def getLocalBlocks(){ + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlockIds) { + getLocal(id) match { + case Some(iter) => { + results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + } + } + + override def initialize(){ + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numGets = remoteBlockIds.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + + } + + //an iterator that will read fetched blocks off the queue as they arrive. + @volatile protected var resultsGotten = 0 + + def hasNext: Boolean = resultsGotten < totalBlocks + + def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + _fetchWaitTime += (stopFetchWait - startFetchWait) + if (! result.failed) bytesInFlight -= result.size + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + + + //methods to profile the block fetching + def numLocalBlocks = localBlockIds.size + def numRemoteBlocks = remoteBlockIds.size + + def remoteFetchTime = _remoteFetchTime + def fetchWaitTime = _fetchWaitTime + + def remoteBytesRead = _remoteBytesRead + + } + + + class NettyBlockFetcherIterator( + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer + ) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) { + + import blockManager._ + + val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + + def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, + results : LinkedBlockingQueue[FetchResult]){ + results.put(new FetchResult( + blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) )) + } + + def startCopiers (numCopiers: Int): List [ _ <: Thread]= { + (for ( i <- Range(0,numCopiers) ) yield { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { + sendRequest(fetchRequestsSync.take()) + } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + //case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } + } + } + copier.start + copier + }).toList + } + + //keep this to interrupt the threads when necessary + def stopCopiers(copiers : List[_ <: Thread]) { + for (copier <- copiers) { + copier.interrupt() + } + } + + override def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host)) + val cmId = new ConnectionManagerId(req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cpier = new ShuffleCopier + cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) + logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) + } + + override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val originalTotalBlocks = totalBlocks; + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + if (size > 0) { + curBlocks += ((blockId, size)) + curRequestSize += size + } else if (size == 0){ + //here we changes the totalBlocks + totalBlocks -= 1 + } else { + throw new SparkException("Negative block size "+blockId) + } + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] + } + } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") + remoteRequests + } + + var copiers : List[_ <: Thread] = null + + override def initialize(){ + // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + for (request <- Utils.randomize(remoteRequests)) { + fetchRequestsSync.put(request) + } + + copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + startTime = System.currentTimeMillis + getLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def next(): (String, Option[Iterator[Any]]) = { + resultsGotten += 1 + val result = results.take() + // if all the results has been retrieved + // shutdown the copiers + if (resultsGotten == totalBlocks) { + if( copiers != null ) + stopCopiers(copiers) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } + } + + def apply(t: String, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + serializer: Serializer): BlockFetcherIterator = { + val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) } + else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) } + iter.initialize() + iter + } +} + diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 433e939656..a189c1a025 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -2,9 +2,8 @@ package spark.storage import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import akka.actor.{ActorSystem, Cancellable, Props} @@ -23,8 +22,6 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer -import spark.network.netty.ShuffleCopier -import io.netty.buffer.ByteBuf private[spark] class BlockManager( @@ -977,338 +974,3 @@ object BlockManager extends Logging { } } - -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { - def initialize -} - -object BlockFetcherIterator { - - // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { - val size = blocks.map(_._2).sum - } - - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - -class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer -) extends BlockFetcherIterator { - - import blockManager._ - - private var _remoteBytesRead = 0l - private var _remoteFetchTime = 0l - private var _fetchWaitTime = 0l - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - var totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + totalBlocks + " blocks") - var startTime = System.currentTimeMillis - val localBlockIds = new ArrayBuffer[String]() - val remoteBlockIds = new HashSet[String]() - - // A queue to hold our results. - val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - var bytesInFlight = 0L - - def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val fetchStart = System.currentTimeMillis() - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { - val fetchDone = System.currentTimeMillis() - _remoteFetchTime += fetchDone - fetchStart - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += req.size - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case None => { - logError("Could not get block(s) from " + cmId) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - curBlocks += ((blockId, size)) - curRequestSize += size - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - remoteRequests - } - - def getLocalBlocks(){ - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlockIds) { - getLocal(id) match { - case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } - } - } - } - - def initialize(){ - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numGets = remoteBlockIds.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - - } - - //an iterator that will read fetched blocks off the queue as they arrive. - @volatile private var resultsGotten = 0 - - def hasNext: Boolean = resultsGotten < totalBlocks - - def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) - if (! result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - - - //methods to profile the block fetching - def numLocalBlocks = localBlockIds.size - def numRemoteBlocks = remoteBlockIds.size - - def remoteFetchTime = _remoteFetchTime - def fetchWaitTime = _fetchWaitTime - - def remoteBytesRead = _remoteBytesRead - -} - -class NettyBlockFetcherIterator( - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer -) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) { - - import blockManager._ - - val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] - - def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, - results : LinkedBlockingQueue[FetchResult]){ - results.put(new FetchResult( - blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) )) - } - - def startCopiers (numCopiers: Int): List [ _ <: Thread]= { - (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { - sendRequest(fetchRequestsSync.take()) - } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - //case _ => throw new SparkException("Exception Throw in Shuffle Copier") - } - } - } - copier.start - copier - }).toList - } - - //keep this to interrupt the threads when necessary - def stopCopiers(copiers : List[_ <: Thread]) { - for (copier <- copiers) { - copier.interrupt() - } - } - - override def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) - val cmId = new ConnectionManagerId(req.address.ip, System.getProperty("spark.shuffle.sender.port", "6653").toInt) - val cpier = new ShuffleCopier - cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) - logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.ip ) - } - - override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = totalBlocks; - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - if (size > 0) { - curBlocks += ((blockId, size)) - curRequestSize += size - } else if (size == 0){ - //here we changes the totalBlocks - totalBlocks -= 1 - } else { - throw new SparkException("Negative block size "+blockId) - } - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") - remoteRequests - } - - var copiers : List[_ <: Thread] = null - - override def initialize(){ - // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - for (request <- Utils.randomize(remoteRequests)) { - fetchRequestsSync.put(request) - } - - copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) - logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - override def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val result = results.take() - // if all the results has been retrieved - // shutdown the copiers - if (resultsGotten == totalBlocks) { - if( copiers != null ) - stopCopiers(copiers) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - - def apply(t: String, - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer): BlockFetcherIterator = { - val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) } - else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) } - iter.initialize - iter - } - -} -- cgit v1.2.3 From 0e5cc30868bcf933f2980c4cfe29abc3d8fe5887 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 18:18:24 -0700 Subject: Cleaned up BlockManager and BlockFetcherIterator from Shane's PR. --- .../scala/spark/storage/BlockFetchTracker.scala | 12 +- .../scala/spark/storage/BlockFetcherIterator.scala | 167 ++++++++++----------- .../main/scala/spark/storage/BlockManager.scala | 22 +-- 3 files changed, 102 insertions(+), 99 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala index 993aece1f7..0718156b1b 100644 --- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala +++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala @@ -1,10 +1,10 @@ package spark.storage private[spark] trait BlockFetchTracker { - def totalBlocks : Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime : Long - def fetchWaitTime: Long - def remoteBytesRead : Long + def totalBlocks : Int + def numLocalBlocks: Int + def numRemoteBlocks: Int + def remoteFetchTime : Long + def fetchWaitTime: Long + def remoteBytesRead : Long } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 30990d9a38..43f835237c 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -7,27 +7,36 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue +import io.netty.buffer.ByteBuf + import spark.Logging import spark.Utils import spark.SparkException - import spark.network.BufferMessage import spark.network.ConnectionManagerId import spark.network.netty.ShuffleCopier - import spark.serializer.Serializer -import io.netty.buffer.ByteBuf +/** + * A block fetcher iterator interface. There are two implementations: + * + * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. + * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. + * + * Eventually we would like the two to converge and use a single NIO-based communication layer, + * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), + * NIO would perform poorly and thus the need for the Netty OIO one. + */ + +private[storage] trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker { - def initialize() - } - +private[storage] object BlockFetcherIterator { // A request to fetch one or more blocks, complete with their sizes @@ -45,8 +54,8 @@ object BlockFetcherIterator { class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer - ) extends BlockFetcherIterator { + serializer: Serializer) + extends BlockFetcherIterator { import blockManager._ @@ -57,23 +66,24 @@ object BlockFetcherIterator { if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } - var totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + totalBlocks + " blocks") - var startTime = System.currentTimeMillis - val localBlockIds = new ArrayBuffer[String]() - val remoteBlockIds = new HashSet[String]() + + protected var _totalBlocks = blocksByAddress.map(_._2.size).sum + logDebug("Getting " + _totalBlocks + " blocks") + protected var startTime = System.currentTimeMillis + protected val localBlockIds = new ArrayBuffer[String]() + protected val remoteBlockIds = new HashSet[String]() // A queue to hold our results. - val results = new LinkedBlockingQueue[FetchResult] + protected val results = new LinkedBlockingQueue[FetchResult] // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight - val fetchRequests = new Queue[FetchRequest] + private val fetchRequests = new Queue[FetchRequest] // Current bytes in flight from our requests - var bytesInFlight = 0L + private var bytesInFlight = 0L - def sendRequest(req: FetchRequest) { + protected def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) val cmId = new ConnectionManagerId(req.address.host, req.address.port) @@ -111,7 +121,7 @@ object BlockFetcherIterator { } } - def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = { + protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] @@ -148,14 +158,15 @@ object BlockFetcherIterator { remoteRequests } - def getLocalBlocks(){ + protected def getLocalBlocks() { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlockIds) { getLocal(id) match { case Some(iter) => { - results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight + // Pass 0 as size since it's not in flight + results.put(new FetchResult(id, 0, () => iter)) logDebug("Got local block " + id) } case None => { @@ -165,7 +176,7 @@ object BlockFetcherIterator { } } - override def initialize(){ + override def initialize() { // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -184,15 +195,14 @@ object BlockFetcherIterator { startTime = System.currentTimeMillis getLocalBlocks() logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } //an iterator that will read fetched blocks off the queue as they arrive. @volatile protected var resultsGotten = 0 - def hasNext: Boolean = resultsGotten < totalBlocks + override def hasNext: Boolean = resultsGotten < _totalBlocks - def next(): (String, Option[Iterator[Any]]) = { + override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val startFetchWait = System.currentTimeMillis() val result = results.take() @@ -206,74 +216,73 @@ object BlockFetcherIterator { (result.blockId, if (result.failed) None else Some(result.deserialize())) } - - //methods to profile the block fetching - def numLocalBlocks = localBlockIds.size - def numRemoteBlocks = remoteBlockIds.size - - def remoteFetchTime = _remoteFetchTime - def fetchWaitTime = _fetchWaitTime - - def remoteBytesRead = _remoteBytesRead - + // Implementing BlockFetchTracker trait. + override def totalBlocks: Int = _totalBlocks + override def numLocalBlocks: Int = localBlockIds.size + override def numRemoteBlocks: Int = remoteBlockIds.size + override def remoteFetchTime: Long = _remoteFetchTime + override def fetchWaitTime: Long = _fetchWaitTime + override def remoteBytesRead: Long = _remoteBytesRead } - + // End of BasicBlockFetcherIterator class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer - ) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) { + serializer: Serializer) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { import blockManager._ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] - def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer, - results : LinkedBlockingQueue[FetchResult]){ - results.put(new FetchResult( - blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) )) - } - - def startCopiers (numCopiers: Int): List [ _ <: Thread]= { + private def startCopiers(numCopiers: Int): List[_ <: Thread] = { (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { + val copier = new Thread { + override def run(){ + try { + while(!isInterrupted && !fetchRequestsSync.isEmpty) { sendRequest(fetchRequestsSync.take()) - } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - //case _ => throw new SparkException("Exception Throw in Shuffle Copier") } - } + } catch { + case x: InterruptedException => logInfo("Copier Interrupted") + //case _ => throw new SparkException("Exception Throw in Shuffle Copier") + } } - copier.start - copier + } + copier.start + copier }).toList } //keep this to interrupt the threads when necessary - def stopCopiers(copiers : List[_ <: Thread]) { + private def stopCopiers() { for (copier <- copiers) { copier.interrupt() } } - override def sendRequest(req: FetchRequest) { + override protected def sendRequest(req: FetchRequest) { + + def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + val fetchResult = new FetchResult(blockId, blockSize, + () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) + results.put(fetchResult) + } + logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cmId = new ConnectionManagerId( + req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt) val cpier = new ShuffleCopier - cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results)) + cpier.getBlocks(cmId, req.blocks, putResult) logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) } - override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = { + override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = totalBlocks; + val originalTotalBlocks = _totalBlocks; val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { @@ -293,11 +302,11 @@ object BlockFetcherIterator { if (size > 0) { curBlocks += ((blockId, size)) curRequestSize += size - } else if (size == 0){ + } else if (size == 0) { //here we changes the totalBlocks - totalBlocks -= 1 + _totalBlocks -= 1 } else { - throw new SparkException("Negative block size "+blockId) + throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= minRequestSize) { // Add this FetchRequest @@ -312,13 +321,14 @@ object BlockFetcherIterator { } } } - logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks") + logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + + originalTotalBlocks + " blocks") remoteRequests } - var copiers : List[_ <: Thread] = null + private var copiers: List[_ <: Thread] = null - override def initialize(){ + override def initialize() { // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -327,7 +337,8 @@ object BlockFetcherIterator { } copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) - logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) + logInfo("Started " + fetchRequestsSync.size + " remote gets in " + + Utils.getUsedTimeMs(startTime)) // Get Local Blocks startTime = System.currentTimeMillis @@ -338,24 +349,12 @@ object BlockFetcherIterator { override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() - // if all the results has been retrieved - // shutdown the copiers - if (resultsGotten == totalBlocks) { - if( copiers != null ) - stopCopiers(copiers) + // if all the results has been retrieved, shutdown the copiers + if (resultsGotten == _totalBlocks && copiers != null) { + stopCopiers() } (result.blockId, if (result.failed) None else Some(result.deserialize())) } } - - def apply(t: String, - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer): BlockFetcherIterator = { - val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) } - else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) } - iter.initialize() - iter - } + // End of NettyBlockFetcherIterator } - diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index a189c1a025..e0dec3a8bb 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -23,8 +23,7 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam import sun.nio.ch.DirectBuffer -private[spark] -class BlockManager( +private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, @@ -494,11 +493,16 @@ class BlockManager( def getMultiple( blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) : BlockFetcherIterator = { - if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){ - return BlockFetcherIterator("netty",this, blocksByAddress, serializer) - } else { - return BlockFetcherIterator("", this, blocksByAddress, serializer) - } + + val iter = + if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + } else { + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + } + + iter.initialize() + iter } def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) @@ -942,8 +946,8 @@ class BlockManager( } } -private[spark] -object BlockManager extends Logging { + +private[spark] object BlockManager extends Logging { val ID_GENERATOR = new IdGenerator -- cgit v1.2.3 From 9e64396ca4c24804d5fd4e96212eed54530ca409 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 18:30:54 -0700 Subject: Cleaned up the Java files from Shane's PR. --- .../main/java/spark/network/netty/FileClient.java | 45 ++++++------- .../netty/FileClientChannelInitializer.java | 11 +--- .../spark/network/netty/FileClientHandler.java | 11 ++-- .../main/java/spark/network/netty/FileServer.java | 73 ++++++++++------------ .../netty/FileServerChannelInitializer.java | 22 +++---- .../spark/network/netty/FileServerHandler.java | 33 +++++----- .../java/spark/network/netty/PathResolver.java | 4 +- core/src/main/scala/spark/storage/DiskStore.scala | 4 +- 8 files changed, 85 insertions(+), 118 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index d0c5081dd2..3a62dacbc8 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -1,42 +1,40 @@ package spark.network.netty; import io.netty.bootstrap.Bootstrap; -import io.netty.channel.AbstractChannel; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelOption; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; -import java.util.Arrays; -public class FileClient { +class FileClient { private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; - public FileClient(FileClientHandler handler){ + public FileClient(FileClientHandler handler) { this.handler = handler; } - - public void init(){ - bootstrap = new Bootstrap(); - bootstrap.group(new OioEventLoopGroup()) + + public void init() { + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) .handler(new FileClientChannelInitializer(handler)); - } + } public static final class ChannelCloseListener implements ChannelFutureListener { private FileClient fc = null; + public ChannelCloseListener(FileClient fc){ this.fc = fc; } + @Override public void operationComplete(ChannelFuture future) { if (fc.bootstrap!=null){ @@ -46,44 +44,39 @@ public class FileClient { } } - public void connect(String host, int port){ + public void connect(String host, int port) { try { - // Start the connection attempt. channel = bootstrap.connect(host, port).sync().channel(); // ChannelFuture cf = channel.closeFuture(); //cf.addListener(new ChannelCloseListener(this)); } catch (InterruptedException e) { close(); - } + } } - - public void waitForClose(){ + + public void waitForClose() { try { channel.closeFuture().sync(); } catch (InterruptedException e){ e.printStackTrace(); } - } + } - public void sendRequest(String file){ + public void sendRequest(String file) { //assert(file == null); //assert(channel == null); - channel.write(file+"\r\n"); + channel.write(file + "\r\n"); } - public void close(){ + public void close() { if(channel != null) { - channel.close(); - channel = null; + channel.close(); + channel = null; } if ( bootstrap!=null) { bootstrap.shutdown(); bootstrap = null; } } - - } - - diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java index 50e5704619..af25baf641 100644 --- a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java @@ -3,15 +3,10 @@ package spark.network.netty; import io.netty.buffer.BufType; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.codec.string.StringEncoder; -import io.netty.util.CharsetUtil; -import io.netty.handler.logging.LoggingHandler; -import io.netty.handler.logging.LogLevel; -public class FileClientChannelInitializer extends - ChannelInitializer { +class FileClientChannelInitializer extends ChannelInitializer { private FileClientHandler fhandler; @@ -23,7 +18,7 @@ public class FileClientChannelInitializer extends public void initChannel(SocketChannel channel) { // file no more than 2G channel.pipeline() - .addLast("encoder", new StringEncoder(BufType.BYTE)) - .addLast("handler", fhandler); + .addLast("encoder", new StringEncoder(BufType.BYTE)) + .addLast("handler", fhandler); } } diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java index 911c8b32b5..2069dee5ca 100644 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -3,12 +3,9 @@ package spark.network.netty; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundByteHandlerAdapter; -import io.netty.util.CharsetUtil; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.logging.Logger; -public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { +abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { private FileHeader currentHeader = null; @@ -19,7 +16,7 @@ public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter // Use direct buffer if possible. return ctx.alloc().ioBuffer(); } - + @Override public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { // get header @@ -27,8 +24,8 @@ public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); } // get file - if(in.readableBytes() >= currentHeader.fileLen()){ - handle(ctx,in,currentHeader); + if(in.readableBytes() >= currentHeader.fileLen()) { + handle(ctx, in, currentHeader); currentHeader = null; ctx.close(); } diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java index 38af305096..647b26bf8a 100644 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -1,58 +1,51 @@ package spark.network.netty; -import java.io.File; import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.Channel; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioServerSocketChannel; -import io.netty.handler.logging.LogLevel; -import io.netty.handler.logging.LoggingHandler; + /** * Server that accept the path of a file an echo back its content. */ -public class FileServer { +class FileServer { + + private ServerBootstrap bootstrap = null; + private Channel channel = null; + private PathResolver pResolver; - private ServerBootstrap bootstrap = null; - private Channel channel = null; - private PathResolver pResolver; + public FileServer(PathResolver pResolver) { + this.pResolver = pResolver; + } - public FileServer(PathResolver pResolver){ - this.pResolver = pResolver; + public void run(int port) { + // Configure the server. + bootstrap = new ServerBootstrap(); + try { + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. + channel = bootstrap.bind(port).sync().channel(); + channel.closeFuture().sync(); + } catch (InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } finally{ + bootstrap.shutdown(); } + } - public void run(int port) { - // Configure the server. - bootstrap = new ServerBootstrap(); - try { - bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channel = bootstrap.bind(port).sync().channel(); - channel.closeFuture().sync(); - } catch (InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } finally{ - bootstrap.shutdown(); - } + public void stop() { + if (channel!=null) { + channel.close(); } - - public void stop(){ - if (channel!=null){ - channel.close(); - } - if (bootstrap != null){ - bootstrap.shutdown(); - } + if (bootstrap != null) { + bootstrap.shutdown(); } + } } - - diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java index 9d0618ff1c..8f1f5c65cd 100644 --- a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java @@ -1,21 +1,15 @@ package spark.network.netty; -import java.io.File; -import io.netty.buffer.BufType; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringDecoder; -import io.netty.handler.codec.string.StringEncoder; import io.netty.handler.codec.DelimiterBasedFrameDecoder; import io.netty.handler.codec.Delimiters; -import io.netty.util.CharsetUtil; -import io.netty.handler.logging.LoggingHandler; -import io.netty.handler.logging.LogLevel; +import io.netty.handler.codec.string.StringDecoder; + -public class FileServerChannelInitializer extends - ChannelInitializer { +class FileServerChannelInitializer extends ChannelInitializer { - PathResolver pResolver; + PathResolver pResolver; public FileServerChannelInitializer(PathResolver pResolver) { this.pResolver = pResolver; @@ -24,10 +18,8 @@ public class FileServerChannelInitializer extends @Override public void initChannel(SocketChannel channel) { channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder( - 8192, Delimiters.lineDelimiter())) - .addLast("strDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); - + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); } } diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java index e1083e87a2..a78eddb1b5 100644 --- a/core/src/main/java/spark/network/netty/FileServerHandler.java +++ b/core/src/main/java/spark/network/netty/FileServerHandler.java @@ -1,17 +1,17 @@ package spark.network.netty; +import java.io.File; +import java.io.FileInputStream; + import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.DefaultFileRegion; -import io.netty.handler.stream.ChunkedFile; -import java.io.File; -import java.io.FileInputStream; -public class FileServerHandler extends - ChannelInboundMessageHandlerAdapter { - PathResolver pResolver; - +class FileServerHandler extends ChannelInboundMessageHandlerAdapter { + + PathResolver pResolver; + public FileServerHandler(PathResolver pResolver){ this.pResolver = pResolver; } @@ -21,8 +21,8 @@ public class FileServerHandler extends String path = pResolver.getAbsolutePath(blockId); // if getFilePath returns null, close the channel if (path == null) { - //ctx.close(); - return; + //ctx.close(); + return; } File file = new File(path); if (file.exists()) { @@ -33,23 +33,21 @@ public class FileServerHandler extends return; } long length = file.length(); - if (length > Integer.MAX_VALUE || length <= 0 ) { + if (length > Integer.MAX_VALUE || length <= 0) { //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); ctx.write(new FileHeader(0, blockId).buffer()); ctx.flush(); - return; + return; } int len = new Long(length).intValue(); //logger.info("Sending block "+blockId+" filelen = "+len); //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); ctx.write((new FileHeader(len, blockId)).buffer()); try { - ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), 0, file.length())); + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); } catch (Exception e) { - // TODO Auto-generated catch block - //logger.warning("Exception when sending file : " - //+ file.getAbsolutePath()); + //logger.warning("Exception when sending file : " + file.getAbsolutePath()); e.printStackTrace(); } } else { @@ -58,8 +56,7 @@ public class FileServerHandler extends } ctx.flush(); } - - + @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { cause.printStackTrace(); diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java index 5d5eda006e..302411672c 100755 --- a/core/src/main/java/spark/network/netty/PathResolver.java +++ b/core/src/main/java/spark/network/netty/PathResolver.java @@ -1,12 +1,12 @@ package spark.network.netty; + public interface PathResolver { /** * Get the absolute path of the file - * + * * @param fileId * @return the absolute path of file */ public String getAbsolutePath(String fileId); - } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 82bcbd5bc2..be33d4260e 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -288,7 +288,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt val pResolver = new PathResolver { - def getAbsolutePath(blockId:String):String = { + override def getAbsolutePath(blockId: String): String = { if (!blockId.startsWith("shuffle_")) { return null } @@ -298,7 +298,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) shuffleSender = new Thread { override def run() = { val sender = new ShuffleSender(port,pResolver) - logInfo("created ShuffleSender binding to port : "+ port) + logInfo("Created ShuffleSender binding to port : "+ port) sender.start } } -- cgit v1.2.3 From 547dcbe494ce7a888f636cf2596243be37b567b1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 18:39:33 -0700 Subject: Cleaned up Scala files in network/netty from Shane's PR. --- .../scala/spark/network/netty/ShuffleCopier.scala | 50 ++++++++++------------ .../scala/spark/network/netty/ShuffleSender.scala | 18 +++++--- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index d8d35bfeec..a91f5a886d 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -1,23 +1,21 @@ package spark.network.netty +import java.util.concurrent.Executors + import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext -import io.netty.channel.ChannelInboundByteHandlerAdapter import io.netty.util.CharsetUtil -import java.util.concurrent.atomic.AtomicInteger -import java.util.logging.Logger import spark.Logging import spark.network.ConnectionManagerId -import java.util.concurrent.Executors + private[spark] class ShuffleCopier extends Logging { - def getBlock(cmId: ConnectionManagerId, - blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { - val handler = new ShuffleClientHandler(resultCollectCallback) + val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val fc = new FileClient(handler) fc.init() fc.connect(cmId.host, cmId.port) @@ -28,29 +26,28 @@ private[spark] class ShuffleCopier extends Logging { def getBlocks(cmId: ConnectionManagerId, blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + resultCollectCallback: (String, Long, ByteBuf) => Unit) { - blocks.map { - case(blockId,size) => { - getBlock(cmId,blockId,resultCollectCallback) - } + for ((blockId, size) <- blocks) { + getBlock(cmId, blockId, resultCollectCallback) } } } -private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit ) extends FileClientHandler with Logging { - - def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } -} private[spark] object ShuffleCopier extends Logging { - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = { - logInfo("File: " + blockId + " content is : \" " - + content.toString(CharsetUtil.UTF_8) + "\"") + private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + extends FileClientHandler with Logging { + + override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } + } + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") } def runGetBlock(host:String, port:Int, file:String){ @@ -71,18 +68,17 @@ private[spark] object ShuffleCopier extends Logging { val host = args(0) val port = args(1).toInt val file = args(2) - val threads = if (args.length>3) args(3).toInt else 10 + val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - for (i <- Range(0,threads)){ + for (i <- Range(0, threads)) { val runnable = new Runnable() { def run() { - runGetBlock(host,port,file) + runGetBlock(host, port, file) } } copiers.execute(runnable) } copiers.shutdown } - } diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala index c1986812e9..dc87fefc56 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -1,12 +1,13 @@ package spark.network.netty -import spark.Logging import java.io.File +import spark.Logging + -private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) extends Logging { +private[spark] class ShuffleSender(val port: Int, val pResolver: PathResolver) extends Logging { val server = new FileServer(pResolver) - + Runtime.getRuntime().addShutdownHook( new Thread() { override def run() { @@ -20,17 +21,22 @@ private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) ex } } + private[spark] object ShuffleSender { + def main(args: Array[String]) { if (args.length < 3) { - System.err.println("Usage: ShuffleSender ") + System.err.println( + "Usage: ShuffleSender ") System.exit(1) } + val port = args(0).toInt val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2) map {new File(_)} + val localDirs = args.drop(2).map(new File(_)) + val pResovler = new PathResolver { - def getAbsolutePath(blockId:String):String = { + override def getAbsolutePath(blockId: String): String = { if (!blockId.startsWith("shuffle_")) { throw new Exception("Block " + blockId + " is not a shuffle block") } -- cgit v1.2.3 From 8388e8dd7ab5e55ea67b329d9359ba2147d796b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 18:40:35 -0700 Subject: Minor style fix in DiskStore... --- core/src/main/scala/spark/storage/DiskStore.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index be33d4260e..933eeaa216 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -283,7 +283,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) }) } - private def startShuffleBlockSender (){ + private def startShuffleBlockSender() { try { val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt @@ -297,7 +297,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } shuffleSender = new Thread { override def run() = { - val sender = new ShuffleSender(port,pResolver) + val sender = new ShuffleSender(port, pResolver) logInfo("Created ShuffleSender binding to port : "+ port) sender.start } -- cgit v1.2.3 From 5d70ee4663b5589611aef107f914a94f301c7d2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 22:42:15 -0700 Subject: Cleaned up connection manager (moved many classes to their own files). --- .../main/scala/spark/network/BufferMessage.scala | 94 +++++++++++ core/src/main/scala/spark/network/Connection.scala | 137 +++++++++------- .../scala/spark/network/ConnectionManager.scala | 53 +++--- .../scala/spark/network/ConnectionManagerId.scala | 21 +++ core/src/main/scala/spark/network/Message.scala | 179 ++------------------- .../main/scala/spark/network/MessageChunk.scala | 25 +++ .../scala/spark/network/MessageChunkHeader.scala | 58 +++++++ 7 files changed, 315 insertions(+), 252 deletions(-) create mode 100644 core/src/main/scala/spark/network/BufferMessage.scala create mode 100644 core/src/main/scala/spark/network/ConnectionManagerId.scala create mode 100644 core/src/main/scala/spark/network/MessageChunk.scala create mode 100644 core/src/main/scala/spark/network/MessageChunkHeader.scala diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala new file mode 100644 index 0000000000..7b0e489a6c --- /dev/null +++ b/core/src/main/scala/spark/network/BufferMessage.scala @@ -0,0 +1,94 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import spark.storage.BlockManager + + +private[spark] +class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) + extends Message(Message.BUFFER_MESSAGE, id_) { + + val initialSize = currentSize() + var gotChunkForSendingOnce = false + + def size = initialSize + + def currentSize() = { + if (buffers == null || buffers.isEmpty) { + 0 + } else { + buffers.map(_.remaining).reduceLeft(_ + _) + } + } + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { + if (maxChunkSize <= 0) { + throw new Exception("Max chunk size is " + maxChunkSize) + } + + if (size == 0 && gotChunkForSendingOnce == false) { + val newChunk = new MessageChunk( + new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + gotChunkForSendingOnce = true + return Some(newChunk) + } + + while(!buffers.isEmpty) { + val buffer = buffers(0) + if (buffer.remaining == 0) { + BlockManager.dispose(buffer) + buffers -= buffer + } else { + val newBuffer = if (buffer.remaining <= maxChunkSize) { + buffer.duplicate() + } else { + buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] + } + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + gotChunkForSendingOnce = true + return Some(newChunk) + } + } + None + } + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { + // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer + if (buffers.size > 1) { + throw new Exception("Attempting to get chunk from message with multiple data buffers") + } + val buffer = buffers(0) + if (buffer.remaining > 0) { + if (buffer.remaining < chunkSize) { + throw new Exception("Not enough space in data buffer for receiving chunk") + } + val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + return Some(newChunk) + } + None + } + + def flip() { + buffers.foreach(_.flip) + } + + def hasAckId() = (ackId != 0) + + def isCompletelyReceived() = !buffers(0).hasRemaining + + override def toString = { + if (hasAckId) { + "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" + } else { + "BufferMessage(id = " + id + ", size = " + size + ")" + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 00a0433a44..6e28f677a3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,12 +13,13 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) + extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - )) + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) } channel.configureBlocking(false) @@ -32,17 +33,19 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() - + // Read channels typically do not register for write and write does not for read // Now, we do have write registering for read too (temporarily), but this is to detect // channel close NOT to actually read/consume data on it ! // How does this work if/when we move to SSL ? - + // What is the interest to register with selector for when we want this connection to be selected def registerInterest() - // What is the interest to register with selector for when we want this connection to be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be - // SelectionKey.OP_READ (until we fix it properly) + + // What is the interest to register with selector for when we want this connection to + // be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, + // it will be SelectionKey.OP_READ (until we fix it properly) def unregisterInterest() // On receiving a read event, should we change the interest for this channel or not ? @@ -64,12 +67,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Returns whether we have to register for further reads or not. def read(): Boolean = { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + throw new UnsupportedOperationException( + "Cannot read on connection of type " + this.getClass.toString) } // Returns whether we have to register for further writes or not. def write(): Boolean = { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + throw new UnsupportedOperationException( + "Cannot write on connection of type " + this.getClass.toString) } def close() { @@ -81,11 +86,17 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, callOnCloseCallback() } - def onClose(callback: Connection => Unit) {onCloseCallback = callback} + def onClose(callback: Connection => Unit) { + onCloseCallback = callback + } - def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback} + def onException(callback: (Connection, Exception) => Unit) { + onExceptionCallback = callback + } - def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback} + def onKeyInterestChange(callback: (Connection, Int) => Unit) { + onKeyInterestChangeCallback = callback + } def callOnExceptionCallback(e: Exception) { if (onExceptionCallback != null) { @@ -95,7 +106,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, " and OnExceptionCallback not registered", e) } } - + def callOnCloseCallback() { if (onCloseCallback != null) { onCloseCallback(this) @@ -132,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, print(" (" + position + ", " + length + ")") buffer.position(curPosition) } - } -private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) -extends Connection(SocketChannel.open, selector_, remoteId_) { +private[spark] +class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) + extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 + val defaultChunkSize = 65536 //32768 //16384 var nextMessageToBeUsed = 0 def addMessage(message: Message) { - messages.synchronized{ + messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]") + logDebug("Added [" + message + "] to outbox for sending to " + + "[" + getRemoteConnectionManagerId() + "]") } } @@ -174,7 +186,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { message.started = true message.startTime = System.currentTimeMillis } - return chunk + return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis @@ -185,7 +197,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } None } - + private def getChunkRR(): Option[MessageChunk] = { messages.synchronized { while (!messages.isEmpty) { @@ -197,12 +209,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") + logDebug( + "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") - return chunk + logTrace( + "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") + return chunk } else { message.finishTime = System.currentTimeMillis logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + @@ -213,7 +227,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { None } } - + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() @@ -228,11 +242,11 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // it does - so let us keep it for now. changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) } - + override def unregisterInterest() { changeConnectionKeyInterest(DEFAULT_INTEREST) } - + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) @@ -262,12 +276,14 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // selection - though need not necessarily always complete successfully. val connected = channel.finishConnect if (!force && !connected) { - logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + logInfo( + "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") return false } // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times (10 or so) + // This will happen only when finishConnect failed for some repeated number of times + // (10 or so) // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") @@ -283,13 +299,13 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } override def write(): Boolean = { - try{ - while(true) { + try { + while (true) { if (currentBuffers.size == 0) { outbox.synchronized { outbox.getChunk() match { case Some(chunk) => { - currentBuffers ++= chunk.buffers + currentBuffers ++= chunk.buffers } case None => { // changeConnectionKeyInterest(0) @@ -299,7 +315,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } - + if (currentBuffers.size > 0) { val buffer = currentBuffers(0) val remainingBytes = buffer.remaining @@ -314,7 +330,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() @@ -336,7 +352,8 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) + logWarning( + "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => @@ -355,30 +372,32 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) -extends Connection(channel_, selector_) { - +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) + extends Connection(channel_, selector_) { + class Inbox() { val messages = new HashMap[Int, BufferMessage]() - + def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - + def createNewMessage: BufferMessage = { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") + logDebug( + "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } - + val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") + logTrace( + "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } - + def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) + messages.get(chunk.header.id) } def removeMessage(message: Message) { @@ -387,12 +406,14 @@ extends Connection(channel_, selector_) { } @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + override def getRemoteConnectionManagerId(): ConnectionManagerId = { val currId = inferredRemoteManagerId if (currId != null) currId else super.getRemoteConnectionManagerId() } - // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver. + // The reciever's remote address is the local socket on remote side : which is NOT + // the connection manager id of the receiver. // We infer that from the messages we receive on the receiver socket. private def processConnectionManagerId(header: MessageChunkHeader) { val currId = inferredRemoteManagerId @@ -428,7 +449,8 @@ extends Connection(channel_, selector_) { } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") + throw new Exception( + "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() @@ -451,9 +473,9 @@ extends Connection(channel_, selector_) { case _ => throw new Exception("Message of unknown type received") } } - + if (currentChunk == null) throw new Exception("No message chunk to receive data") - + val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { // re-register for read event ... @@ -464,14 +486,15 @@ extends Connection(channel_, selector_) { } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - + if (currentChunk.buffer.remaining == 0) { /*println("Filled buffer at " + System.currentTimeMillis)*/ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from " + + "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -481,7 +504,7 @@ extends Connection(channel_, selector_) { } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() @@ -491,7 +514,7 @@ extends Connection(channel_, selector_) { // should not happen - to keep scala compiler happy return true } - + def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} override def changeInterestForRead(): Boolean = true @@ -505,7 +528,7 @@ extends Connection(channel_, selector_) { // it does - so let us keep it for now. changeConnectionKeyInterest(SelectionKey.OP_READ) } - + override def unregisterInterest() { changeConnectionKeyInterest(0) } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 0eb03630d0..624a094856 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -18,20 +18,7 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future} import akka.util.Duration import akka.util.duration._ -private[spark] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - def toSocketAddress() = new InetSocketAddress(host, port) -} - -private[spark] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} - private[spark] class ConnectionManager(port: Int) extends Logging { class MessageStatus( @@ -45,7 +32,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - + private val selector = SelectorProvider.provider.openSelector() private val handleMessageExecutor = new ThreadPoolExecutor( @@ -80,7 +67,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) + serverChannel.socket.setReceiveBufferSize(256 * 1024) serverChannel.socket.bind(new InetSocketAddress(port)) serverChannel.register(selector, SelectionKey.OP_ACCEPT) @@ -351,7 +338,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { case e: Exception => logError("Error in select loop", e) } } - + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] @@ -463,7 +450,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") + logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { @@ -483,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId + case Some(status) => { + messageStatuses -= bufferMessage.ackId status } - case None => { + case None => { throw new Exception("Could not find reference for received ack message " + message.id) null } @@ -507,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logDebug("Not calling back as callback is null") None } - + if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) @@ -517,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - sendMessage(connectionManagerId, ackMessage.getOrElse { + sendMessage(connectionManagerId, ackMessage.getOrElse { Message.createBufferMessage(bufferMessage.id) }) } @@ -588,17 +575,17 @@ private[spark] object ConnectionManager { def main(args: Array[String]) { val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None }) - + /*testSequentialSending(manager)*/ /*System.gc()*/ /*testParallelSending(manager)*/ /*System.gc()*/ - + /*testParallelDecreasingSending(manager)*/ /*System.gc()*/ @@ -610,9 +597,9 @@ private[spark] object ConnectionManager { println("--------------------------") println("Sequential Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 - + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -628,7 +615,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) @@ -643,12 +630,12 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) + println("Started at " + startTime + ", finished at " + finishTime) println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() @@ -658,7 +645,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Parallel Decreasing Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) buffers.foreach(_.flip) @@ -673,7 +660,7 @@ private[spark] object ConnectionManager { if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis - + val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") @@ -687,7 +674,7 @@ private[spark] object ConnectionManager { println("--------------------------") println("Continuous Sending") println("--------------------------") - val size = 10 * 1024 * 1024 + val size = 10 * 1024 * 1024 val count = 10 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala new file mode 100644 index 0000000000..b554e84251 --- /dev/null +++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala @@ -0,0 +1,21 @@ +package spark.network + +import java.net.InetSocketAddress + +import spark.Utils + + +private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + def toSocketAddress() = new InetSocketAddress(host, port) +} + + +private[spark] object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + } +} diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 34fac9e776..d4f03610eb 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -1,56 +1,10 @@ package spark.network -import spark._ - -import scala.collection.mutable.ArrayBuffer - import java.nio.ByteBuffer -import java.net.InetAddress import java.net.InetSocketAddress -import storage.BlockManager - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). - flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} -private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } +import scala.collection.mutable.ArrayBuffer - override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" -} private[spark] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -59,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var finishTime = -1L def size: Int - + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - + def timeTaken(): String = (finishTime - startTime).toString + " ms" override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" } -private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) -extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} private[spark] object Message { val BUFFER_MESSAGE = 1111111111L @@ -181,14 +31,16 @@ private[spark] object Message { def getNewId() = synchronized { lastId += 1 - if (lastId == 0) lastId += 1 + if (lastId == 0) { + lastId += 1 + } lastId } def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { if (dataBuffers == null) { return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } + } if (dataBuffers.exists(_ == null)) { throw new Exception("Attempting to create buffer message with null buffer") } @@ -197,7 +49,7 @@ private[spark] object Message { def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = createBufferMessage(dataBuffers, 0) - + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { if (dataBuffer == null) { return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) @@ -205,15 +57,18 @@ private[spark] object Message { return createBufferMessage(Array(dataBuffer), ackId) } } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId) + + def createBufferMessage(ackId: Int): BufferMessage = { + createBufferMessage(new Array[ByteBuffer](0), ackId) + } def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + case BUFFER_MESSAGE => new BufferMessage(header.id, + ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } newMessage.senderAddress = header.address newMessage diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala new file mode 100644 index 0000000000..aaf9204d0e --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunk.scala @@ -0,0 +1,25 @@ +package spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + + +private[network] +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + + val size = if (buffer == null) 0 else buffer.remaining + + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = { + "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" + } +} diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala new file mode 100644 index 0000000000..3693d509d6 --- /dev/null +++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala @@ -0,0 +1,58 @@ +package spark.network + +import java.net.InetAddress +import java.net.InetSocketAddress +import java.nio.ByteBuffer + + +private[spark] class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val address: InetSocketAddress) { + lazy val buffer = { + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. + // Refer to network.Connection + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). + flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes" +} + + +private[spark] object MessageChunkHeader { + val HEADER_SIZE = 40 + + def create(buffer: ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + } +} -- cgit v1.2.3 From 9cafacf32ddb9a3f6c5cb774e4fe527225273f16 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 7 May 2013 22:42:37 -0700 Subject: Added test for Netty suite. --- core/src/test/scala/spark/DistributedSuite.scala | 3 ++- core/src/test/scala/spark/ShuffleNettySuite.scala | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/spark/ShuffleNettySuite.scala diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 4df3bb5b67..488c70c414 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -18,7 +18,8 @@ import scala.collection.mutable.ArrayBuffer import SparkContext._ import storage.{GetBlock, BlockManagerWorker, StorageLevel} -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter + with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala new file mode 100644 index 0000000000..bfaffa953e --- /dev/null +++ b/core/src/test/scala/spark/ShuffleNettySuite.scala @@ -0,0 +1,17 @@ +package spark + +import org.scalatest.BeforeAndAfterAll + + +class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. + + override def beforeAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "true") + } + + override def afterAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "false") + } +} -- cgit v1.2.3 From 0ab818d50812f312596170b5e42aa76d2ff59d15 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Thu, 9 May 2013 00:38:59 -0700 Subject: fix linebreak --- .../scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 955ee5d806..170ede0f44 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -31,8 +31,8 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse( throw new IllegalArgumentException("must supply spark home for spark standalone")) - val appDesc = - new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, sc.ui.appUIAddress) + val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, + sc.ui.appUIAddress) client = new Client(sc.env.actorSystem, master, appDesc, this) client.start() -- cgit v1.2.3 From b05c9d22d70333924b988b2dfa359ce3e11f7c9d Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 9 May 2013 18:49:12 +0530 Subject: Remove explicit hardcoding of yarn-standalone as args(0) if it is missing. --- .../scala/spark/deploy/yarn/ApplicationMaster.scala | 19 +++---------------- .../deploy/yarn/ApplicationMasterArguments.scala | 1 - .../scala/spark/deploy/yarn/ClientArguments.scala | 1 - 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala index ae719267e8..aa72c1e5fe 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala @@ -148,22 +148,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e .getMethod("main", classOf[Array[String]]) val t = new Thread { override def run() { - var mainArgs: Array[String] = null - var startIndex = 0 - - // I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now ! - if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") { - // ensure that first param is ALWAYS "yarn-standalone" - mainArgs = new Array[String](args.userArgs.size() + 1) - mainArgs.update(0, "yarn-standalone") - startIndex = 1 - } - else { - mainArgs = new Array[String](args.userArgs.size()) - } - - args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size()) - + // Copy + var mainArgs: Array[String] = new Array[String](args.userArgs.size()) + args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size()) mainMethod.invoke(null, mainArgs) } } diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala index dc89125d81..1b00208511 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -69,7 +69,6 @@ class ApplicationMasterArguments(val args: Array[String]) { " --class CLASS_NAME Name of your application's main class (required)\n" + " --args ARGS Arguments to be passed to your application's main class.\n" + " Mutliple invocations are possible, each will be passed in order.\n" + - " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + " --num-workers NUM Number of workers to start (Default: 2)\n" + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n") diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala index 2e69fe3fb0..24110558e7 100644 --- a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala @@ -92,7 +92,6 @@ class ClientArguments(val args: Array[String]) { " --class CLASS_NAME Name of your application's main class (required)\n" + " --args ARGS Arguments to be passed to your application's main class.\n" + " Mutliple invocations are possible, each will be passed in order.\n" + - " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" + " --num-workers NUM Number of workers to start (Default: 2)\n" + " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + -- cgit v1.2.3 From 012c9e5ab072239e07202abe4775b434be6e32b9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 9 May 2013 14:20:01 -0700 Subject: Revert "Merge pull request #596 from esjewett/master" because the dependency on hbase introduces netty-3.2.2 which conflicts with netty-3.5.3 already in Spark. This caused multiple test failures. This reverts commit 0f1b7a06e1f6782711170234f105f1b277e3b04c, reversing changes made to aacca1b8a85bd073ce185a06d6470b070761b2f4. --- .../src/main/scala/spark/examples/HBaseTest.scala | 35 ---------------------- project/SparkBuild.scala | 6 +--- 2 files changed, 1 insertion(+), 40 deletions(-) delete mode 100644 examples/src/main/scala/spark/examples/HBaseTest.scala diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala deleted file mode 100644 index 9bad876860..0000000000 --- a/examples/src/main/scala/spark/examples/HBaseTest.scala +++ /dev/null @@ -1,35 +0,0 @@ -package spark.examples - -import spark._ -import spark.rdd.NewHadoopRDD -import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} -import org.apache.hadoop.hbase.client.HBaseAdmin -import org.apache.hadoop.hbase.mapreduce.TableInputFormat - -object HBaseTest { - def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HBaseTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val conf = HBaseConfiguration.create() - - // Other options for configuring scan behavior are available. More information available at - // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html - conf.set(TableInputFormat.INPUT_TABLE, args(1)) - - // Initialize hBase table if necessary - val admin = new HBaseAdmin(conf) - if(!admin.isTableAvailable(args(1))) { - val tableDesc = new HTableDescriptor(args(1)) - admin.createTable(tableDesc) - } - - val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], - classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], - classOf[org.apache.hadoop.hbase.client.Result]) - - hBaseRDD.count() - - System.exit(0) - } -} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6f5607d31c..190d723435 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -200,11 +200,7 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"), - libraryDependencies ++= Seq( - "com.twitter" % "algebird-core_2.9.2" % "0.1.11", - "org.apache.hbase" % "hbase" % "0.94.6" - ) + libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From ee6f6aa6cd028e6a3938dcd5334661c27f493bc6 Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Thu, 9 May 2013 18:33:38 -0500 Subject: Add hBase example --- .../src/main/scala/spark/examples/HBaseTest.scala | 35 ++++++++++++++++++++++ project/SparkBuild.scala | 6 +++- 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/scala/spark/examples/HBaseTest.scala diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala new file mode 100644 index 0000000000..6e910154d4 --- /dev/null +++ b/examples/src/main/scala/spark/examples/HBaseTest.scala @@ -0,0 +1,35 @@ +package spark.examples + +import spark._ +import spark.rdd.NewHadoopRDD +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.client.HBaseAdmin +import org.apache.hadoop.hbase.mapreduce.TableInputFormat + +object HBaseTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HBaseTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val conf = HBaseConfiguration.create() + + // Other options for configuring scan behavior are available. More information available at + // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html + conf.set(TableInputFormat.INPUT_TABLE, args(1)) + + // Initialize hBase table if necessary + val admin = new HBaseAdmin(conf) + if(!admin.isTableAvailable(args(1))) { + val tableDesc = new HTableDescriptor(args(1)) + admin.createTable(tableDesc) + } + + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result]) + + hBaseRDD.count() + + System.exit(0) + } +} \ No newline at end of file diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 190d723435..57fe04ea2d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -200,7 +200,11 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") + resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"), + libraryDependencies ++= Seq( + "com.twitter" % "algebird-core_2.9.2" % "0.1.11", + "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty) + ) ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From d761e7359deb7ca864d33b8f2e4380b57448630b Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 12:05:10 -0600 Subject: adding kafkaStream API tests --- streaming/src/test/java/spark/streaming/JavaAPISuite.java | 4 ++-- .../src/test/scala/spark/streaming/InputStreamsSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 61e4c0a207..350d0888a3 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -4,6 +4,7 @@ import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.io.Files; +import kafka.serializer.StringDecoder; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.junit.After; import org.junit.Assert; @@ -1203,8 +1204,7 @@ public class JavaAPISuite implements Serializable { public void testKafkaStream() { HashMap topics = Maps.newHashMap(); JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics); - JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 1024d3ac97..595c766a21 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -240,6 +240,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i) === expectedOutput(i)) } } + + test("kafka input stream") { + val ssc = new StreamingContext(master, framework, batchDuration) + val topics = Map("my-topic" -> 1) + val test1 = ssc.kafkaStream("localhost:12345", "group", topics) + val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK) + + // Test specifying decoder + val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group") + val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) + } } -- cgit v1.2.3 From b95c1bdbbaeea86152e24b394a03bbbad95989d5 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 12:47:24 -0600 Subject: count() now uses a transform instead of ConstantInputDStream --- streaming/src/main/scala/spark/streaming/DStream.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e3a9247924..e125310861 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,10 +441,7 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = { - val zero = new ConstantInputDStream(context, context.sparkContext.makeRDD(Seq((null, 0L)), 1)) - this.map(_ => (null, 1L)).union(zero).reduceByKey(_ + _).map(_._2) - } + def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2) /** * Return a new DStream in which each RDD contains the counts of each distinct value in -- cgit v1.2.3 From 6e6b3e0d7eadab97d45e975452c7e0c18246686e Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Fri, 10 May 2013 13:02:34 -0700 Subject: Actually use the cleaned closure in foreachPartition --- core/src/main/scala/spark/RDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index fd14ef17f1..dde131696f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -489,7 +489,7 @@ abstract class RDD[T: ClassManifest]( */ def foreachPartition(f: Iterator[T] => Unit) { val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => f(iter)) + sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) } /** -- cgit v1.2.3 From 3632980b1b61dbb9ab9a3ab3d92fb415cb7173b9 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 15:54:26 -0600 Subject: fixing indentation --- .../src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 13427873ff..4ad2bdf8a8 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -105,7 +105,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( - kafkaParams: JMap[String, String], + kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { -- cgit v1.2.3 From f25282def5826fab6caabff28c82c57a7f3fdcb8 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 10 May 2013 17:34:28 -0600 Subject: fixing kafkaStream Java API and adding test --- .../scala/spark/streaming/api/java/JavaStreamingContext.scala | 10 +++++++--- streaming/src/test/java/spark/streaming/JavaAPISuite.java | 6 ++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 4ad2bdf8a8..b35d9032f1 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -99,18 +99,22 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. + * @param typeClass Type of RDD + * @param decoderClass Type of kafka decoder * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only */ - def kafkaStream[T, D <: kafka.serializer.Decoder[_]: Manifest]( + def kafkaStream[T, D <: kafka.serializer.Decoder[_]]( + typeClass: Class[T], + decoderClass: Class[D], kafkaParams: JMap[String, String], topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] ssc.kafkaStream[T, D]( kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 350d0888a3..e5fdbe1b7a 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -1206,6 +1206,12 @@ public class JavaAPISuite implements Serializable { JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK()); + + HashMap kafkaParams = Maps.newHashMap(); + kafkaParams.put("zk.connect","localhost:12345"); + kafkaParams.put("groupid","consumer-group"); + JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics, + StorageLevel.MEMORY_AND_DISK()); } @Test -- cgit v1.2.3 From ee37612bc95e8486fa328908005293585912db71 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Sat, 11 May 2013 11:12:22 +0530 Subject: 1) Add support for HADOOP_CONF_DIR (and/or YARN_CONF_DIR - use either) : which is used to specify the client side configuration directory : which needs to be part of the CLASSPATH. 2) Move from var+=".." to var="$var.." : the former does not work on older bash shells unfortunately. --- docs/running-on-yarn.md | 3 +++ run | 65 +++++++++++++++++++++++++++++-------------------- run2.cmd | 13 ++++++++++ 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 26424bbe52..c8cf8ffc35 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -30,6 +30,9 @@ If you want to test out the YARN deployment mode, you can use the current Spark # Launching Spark on YARN +Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster. +This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager. + The command to launch the YARN Client is as follows: SPARK_JAR= ./run spark.deploy.yarn.Client \ diff --git a/run b/run index 0a58ac4a36..c744bbd3dc 100755 --- a/run +++ b/run @@ -22,7 +22,7 @@ fi # values for that; it doesn't need a lot if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} - SPARK_DAEMON_JAVA_OPTS+=" -Dspark.akka.logLifecycleEvents=true" + SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default fi @@ -30,19 +30,19 @@ fi # Add java opts for master, worker, executor. The opts maybe null case "$1" in 'spark.deploy.master.Master') - SPARK_JAVA_OPTS+=" $SPARK_MASTER_OPTS" + SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_MASTER_OPTS" ;; 'spark.deploy.worker.Worker') - SPARK_JAVA_OPTS+=" $SPARK_WORKER_OPTS" + SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_WORKER_OPTS" ;; 'spark.executor.StandaloneExecutorBackend') - SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS" + SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" ;; 'spark.executor.MesosExecutorBackend') - SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS" + SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" ;; 'spark.repl.Main') - SPARK_JAVA_OPTS+=" $SPARK_REPL_OPTS" + SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS" ;; esac @@ -85,11 +85,11 @@ export SPARK_MEM # Set JAVA_OPTS to be able to load native libraries and to set heap size JAVA_OPTS="$SPARK_JAVA_OPTS" -JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH" -JAVA_OPTS+=" -Xms$SPARK_MEM -Xmx$SPARK_MEM" +JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH" +JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e $FWDIR/conf/java-opts ] ; then - JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`" + JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" fi export JAVA_OPTS @@ -110,30 +110,30 @@ fi # Build up classpath CLASSPATH="$SPARK_CLASSPATH" -CLASSPATH+=":$FWDIR/conf" -CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$FWDIR/conf" +CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes" if [ -n "$SPARK_TESTING" ] ; then - CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes" fi -CLASSPATH+=":$CORE_DIR/src/main/resources" -CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" -CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" -CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" -CLASSPATH+=":$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar +CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources" +CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar if [ -e "$FWDIR/lib_managed" ]; then - CLASSPATH+=":$FWDIR/lib_managed/jars/*" - CLASSPATH+=":$FWDIR/lib_managed/bundles/*" + CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*" + CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" fi -CLASSPATH+=":$REPL_DIR/lib/*" +CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" if [ -e $REPL_BIN_DIR/target ]; then for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do - CLASSPATH+=":$jar" + CLASSPATH="$CLASSPATH:$jar" done fi -CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do - CLASSPATH+=":$jar" + CLASSPATH="$CLASSPATH:$jar" done # Figure out the JAR file that our examples were packaged into. This includes a bit of a hack @@ -147,6 +147,17 @@ if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar` fi +# Add hadoop conf dir - else FileSystem.*, etc fail ! +# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +# the configurtion files. +if [ "x" != "x$HADOOP_CONF_DIR" ]; then + CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR" +fi +if [ "x" != "x$YARN_CONF_DIR" ]; then + CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" +fi + + # Figure out whether to run our class with java or with the scala launcher. # In most cases, we'd prefer to execute our process with java because scala # creates a shell script as the parent of its Java process, which makes it @@ -156,9 +167,9 @@ fi if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS else - CLASSPATH+=":$SCALA_LIBRARY_PATH/scala-library.jar" - CLASSPATH+=":$SCALA_LIBRARY_PATH/scala-compiler.jar" - CLASSPATH+=":$SCALA_LIBRARY_PATH/jline.jar" + CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar" + CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar" + CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar" # The JVM doesn't read JAVA_OPTS by default so we need to pass it in EXTRA_ARGS="$JAVA_OPTS" fi diff --git a/run2.cmd b/run2.cmd index d2d4807971..c6f43dde5b 100644 --- a/run2.cmd +++ b/run2.cmd @@ -63,6 +63,19 @@ set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\* set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\* set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes +rem Add hadoop conf dir - else FileSystem.*, etc fail +rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +rem the configurtion files. +if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir + set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% +:no_hadoop_conf_dir + +if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir + set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% +:no_yarn_conf_dir + + + rem Figure out the JAR file that our examples were packaged into. rem First search in the build path from SBT: for %%d in ("examples/target/scala-%SCALA_VERSION%/spark-examples*.jar") do ( -- cgit v1.2.3 From 0345954530a445b275595962c9f949cad55a01f6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 11 May 2013 14:17:09 -0700 Subject: SPARK-738: Spark should detect and squash nonserializable exceptions --- core/src/main/scala/spark/executor/Executor.scala | 16 ++++++++++++++-- core/src/test/scala/spark/DistributedSuite.scala | 21 +++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 344face5e6..f9061b1c71 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -1,6 +1,6 @@ package spark.executor -import java.io.{File, FileOutputStream} +import java.io.{NotSerializableException, File, FileOutputStream} import java.net.{URI, URL, URLClassLoader} import java.util.concurrent._ @@ -123,7 +123,19 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert case t: Throwable => { val reason = ExceptionFailure(t) - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + val serReason = + try { + ser.serialize(reason) + } + catch { + case e: NotSerializableException => { + val message = "Spark caught unserializable exn: " + t.toString + val throwable = new Exception(message) + throwable.setStackTrace(t.getStackTrace) + ser.serialize(new ExceptionFailure(throwable)) + } + } + context.statusUpdate(taskId, TaskState.FAILED, serReason) // TODO: Should we exit the whole executor here? On the one hand, the failed task may // have left some weird state around depending on when the exception was thrown, but on diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 4df3bb5b67..8ab0f2cfa2 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -18,6 +18,9 @@ import scala.collection.mutable.ArrayBuffer import SparkContext._ import storage.{GetBlock, BlockManagerWorker, StorageLevel} +class NotSerializableClass +class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} + class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" @@ -27,6 +30,24 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter System.clearProperty("spark.storage.memoryFraction") } + test("task throws not serializable exception") { + // Ensures that executors do not crash when an exn is not serializable. If executors crash, + // this test will hang. Correct behavior is that executors don't crash but fail tasks + // and the scheduler throws a SparkException. + + // numSlaves must be less than numPartitions + val numSlaves = 3 + val numPartitions = 10 + + sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + val data = sc.parallelize(1 to 100, numPartitions).map(x => (x, x)). + map(x => throw new NotSerializableExn(new NotSerializableClass)) + intercept[SparkException] { + data.count() + } + resetSparkContext() + } + test("local-cluster format") { sc = new SparkContext("local-cluster[2,1,512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) -- cgit v1.2.3 From a5c28bb888f74d27893c198865f588ca0334a8a6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 11 May 2013 14:20:39 -0700 Subject: Removing unnecessary map --- core/src/test/scala/spark/DistributedSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 8ab0f2cfa2..33c99471c6 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -40,7 +40,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val numPartitions = 10 sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") - val data = sc.parallelize(1 to 100, numPartitions).map(x => (x, x)). + val data = sc.parallelize(1 to 100, numPartitions). map(x => throw new NotSerializableExn(new NotSerializableClass)) intercept[SparkException] { data.count() -- cgit v1.2.3 From 440719109e10ea1cc6149a8f61d42ea7cc443352 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 May 2013 18:27:26 -0700 Subject: Throw exception if task result exceeds Akka frame size. This partially addresses SPARK-747. --- core/src/main/scala/spark/executor/Executor.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 344face5e6..718f0ff5bc 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -72,6 +72,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert // Initialize Spark environment (using system properties read above) val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) SparkEnv.set(env) + private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") // Start worker thread pool val threadPool = new ThreadPoolExecutor( @@ -113,6 +114,9 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null)) val serializedResult = ser.serialize(result) logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit) + if (serializedResult.limit >= (akkaFrameSize - 1024)) { + throw new SparkException("Result for " + taskId + " exceeded Akka frame size") + } context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) logInfo("Finished task ID " + taskId) } catch { -- cgit v1.2.3 From 3da2305ed0d4add7127953e5240632f86053b4aa Mon Sep 17 00:00:00 2001 From: Cody Koeninger Date: Sat, 11 May 2013 23:59:07 -0500 Subject: code cleanup per rxin comments --- core/src/main/scala/spark/rdd/JdbcRDD.scala | 67 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala index 4c3054465c..b0f7054233 100644 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -5,23 +5,27 @@ import java.sql.{Connection, ResultSet} import spark.{Logging, Partition, RDD, SparkContext, TaskContext} import spark.util.NextIterator +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + /** - An RDD that executes an SQL query on a JDBC connection and reads results. - @param getConnection a function that returns an open Connection. - The RDD takes care of closing the connection. - @param sql the text of the query. - The query must contain two ? placeholders for parameters used to partition the results. - E.g. "select title, author from books where ? <= id and id <= ?" - @param lowerBound the minimum value of the first placeholder - @param upperBound the maximum value of the second placeholder - The lower and upper bounds are inclusive. - @param numPartitions the number of partitions. - Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, - the query would be executed twice, once with (1, 10) and once with (11, 20) - @param mapRow a function from a ResultSet to a single row of the desired result type(s). - This should only call getInt, getString, etc; the RDD takes care of calling next. - The default maps a ResultSet to an array of Object. -*/ + * An RDD that executes an SQL query on a JDBC connection and reads results. + * @param getConnection a function that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. + */ class JdbcRDD[T: ClassManifest]( sc: SparkContext, getConnection: () => Connection, @@ -29,26 +33,33 @@ class JdbcRDD[T: ClassManifest]( lowerBound: Long, upperBound: Long, numPartitions: Int, - mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray) + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) extends RDD[T](sc, Nil) with Logging { - override def getPartitions: Array[Partition] = - ParallelCollectionRDD.slice(lowerBound to upperBound, numPartitions). - filter(! _.isEmpty). - zipWithIndex. - map(x => new JdbcPartition(x._2, x._1.head, x._1.last)). - toArray + override def getPartitions: Array[Partition] = { + // bounds are inclusive, hence the + 1 here and - 1 on end + val length = 1 + upperBound - lowerBound + (0 until numPartitions).map(i => { + val start = lowerBound + ((i * length) / numPartitions).toLong + val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 + new JdbcPartition(i, start, end) + }).toArray + } override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { context.addOnCompleteCallback{ () => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - // force mysql driver to stream rather than pull entire resultset into memory + + // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, + // rather than pulling entire resultset into memory. + // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { stmt.setFetchSize(Integer.MIN_VALUE) logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") } + stmt.setLong(1, part.lower) stmt.setLong(2, part.upper) val rs = stmt.executeQuery() @@ -81,14 +92,10 @@ class JdbcRDD[T: ClassManifest]( } } } - -} - -private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx } object JdbcRDD { - val resultSetToObjectArray = (rs: ResultSet) => + def resultSetToObjectArray(rs: ResultSet) = { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) + } } -- cgit v1.2.3 From 059ab8875463ab22fe329fb6a627cac0a7d8158c Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 11 May 2013 23:39:14 -0700 Subject: Changing technique to use same code path in all cases --- core/src/main/scala/spark/TaskEndReason.scala | 13 ++++++++++--- core/src/main/scala/spark/executor/Executor.scala | 16 ++-------------- .../scala/spark/scheduler/cluster/TaskSetManager.scala | 8 ++++---- .../scala/spark/scheduler/local/LocalScheduler.scala | 3 ++- 4 files changed, 18 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala index 420c54bc9a..ce9bb49897 100644 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ b/core/src/main/scala/spark/TaskEndReason.scala @@ -14,9 +14,16 @@ private[spark] case object Success extends TaskEndReason private[spark] case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it -private[spark] -case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason +private[spark] case class FetchFailed( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int) + extends TaskEndReason -private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason +private[spark] case class ExceptionFailure( + description: String, + stackTrace: Array[StackTraceElement]) + extends TaskEndReason private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index f9061b1c71..9084def9b2 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -122,20 +122,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert } case t: Throwable => { - val reason = ExceptionFailure(t) - val serReason = - try { - ser.serialize(reason) - } - catch { - case e: NotSerializableException => { - val message = "Spark caught unserializable exn: " + t.toString - val throwable = new Exception(message) - throwable.setStackTrace(t.getStackTrace) - ser.serialize(new ExceptionFailure(throwable)) - } - } - context.statusUpdate(taskId, TaskState.FAILED, serReason) + val reason = ExceptionFailure(t.toString, t.getStackTrace) + context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // TODO: Should we exit the whole executor here? On the one hand, the failed task may // have left some weird state around depending on when the exception was thrown, but on diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 27e713e2c4..6d663de2f8 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -493,7 +493,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe return case ef: ExceptionFailure => - val key = ef.exception.toString + val key = ef.description val now = System.currentTimeMillis val (printFull, dupCount) = { if (recentExceptions.contains(key)) { @@ -511,10 +511,10 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } } if (printFull) { - val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s".format(ef.description, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) } case _ => {} diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index f060a940a9..42d5bc4813 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -102,7 +102,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } else { // TODO: Do something nicer here to return all the way to the user if (!Thread.currentThread().isInterrupted) - listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null) + listener.taskEnded( + task, new ExceptionFailure(t.getMessage, t.getStackTrace), null, null, info, null) } } } -- cgit v1.2.3 From 1c15b8505124c157449b6d41e1127f3eb4081a23 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 11 May 2013 23:52:53 -0700 Subject: Removing import --- core/src/main/scala/spark/executor/Executor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 9084def9b2..1d5516966d 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -1,6 +1,6 @@ package spark.executor -import java.io.{NotSerializableException, File, FileOutputStream} +import java.io.{File, FileOutputStream} import java.net.{URI, URL, URLClassLoader} import java.util.concurrent._ -- cgit v1.2.3 From 72b9c4cb6ec4080eb8751e5e040f180272ac82a6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 11 May 2013 23:53:50 -0700 Subject: Small fix --- core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 42d5bc4813..a357422466 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -103,7 +103,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // TODO: Do something nicer here to return all the way to the user if (!Thread.currentThread().isInterrupted) listener.taskEnded( - task, new ExceptionFailure(t.getMessage, t.getStackTrace), null, null, info, null) + task, new ExceptionFailure(t.toString, t.getStackTrace), null, null, info, null) } } } -- cgit v1.2.3 From 7f0833647b784c4ec7cd2f2e8e4fcd5ed6f673cd Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 12 May 2013 07:54:03 -0700 Subject: Capturing class name --- core/src/main/scala/spark/TaskEndReason.scala | 1 + core/src/main/scala/spark/executor/Executor.scala | 2 +- core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 3 ++- core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 7 ++++--- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala index ce9bb49897..ca793eb402 100644 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ b/core/src/main/scala/spark/TaskEndReason.scala @@ -22,6 +22,7 @@ private[spark] case class FetchFailed( extends TaskEndReason private[spark] case class ExceptionFailure( + className: String, description: String, stackTrace: Array[StackTraceElement]) extends TaskEndReason diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 1d5516966d..da20b84544 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -122,7 +122,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert } case t: Throwable => { - val reason = ExceptionFailure(t.toString, t.getStackTrace) + val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // TODO: Should we exit the whole executor here? On the one hand, the failed task may diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 6d663de2f8..06de3c755e 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -512,7 +512,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } if (printFull) { val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s".format(ef.description, locs.mkString("\n"))) + logInfo("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) } else { logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index a357422466..ebe42685ad 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -101,9 +101,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - if (!Thread.currentThread().isInterrupted) - listener.taskEnded( - task, new ExceptionFailure(t.toString, t.getStackTrace), null, null, info, null) + if (!Thread.currentThread().isInterrupted) { + val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) + listener.taskEnded(task, failure, null, null, info, null) + } } } } -- cgit v1.2.3 From b16c4896f617f352bb230908b7c08c7c5b028434 Mon Sep 17 00:00:00 2001 From: Cody Koeninger Date: Tue, 14 May 2013 23:44:04 -0500 Subject: add test for JdbcRDD using embedded derby, per rxin suggestion --- .gitignore | 1 + core/src/test/scala/spark/rdd/JdbcRDDSuite.scala | 56 ++++++++++++++++++++++++ project/SparkBuild.scala | 1 + 3 files changed, 58 insertions(+) create mode 100644 core/src/test/scala/spark/rdd/JdbcRDDSuite.scala diff --git a/.gitignore b/.gitignore index 155e785b01..b87fc1ee79 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ streaming-tests.log dependency-reduced-pom.xml .ensime .ensime_lucene +derby.log diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala new file mode 100644 index 0000000000..6afb0fa9bc --- /dev/null +++ b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala @@ -0,0 +1,56 @@ +package spark + +import org.scalatest.{ BeforeAndAfter, FunSuite } +import spark.SparkContext._ +import spark.rdd.JdbcRDD +import java.sql._ + +class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + before { + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists + } finally { + conn.close + } + } + + test("basic functionality") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", + 1, 100, 3, + (r: ResultSet) => { r.getInt(1) } ).cache + + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 10100) + } + + after { + try { + DriverManager.getConnection("jdbc:derby:;shutdown=true") + } catch { + case se: SQLException if se.getSQLState == "XJ015" => + // normal shutdown + } + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f0b371b2cf..b11893590e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -147,6 +147,7 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" % "spray-json_2.9.2" % "1.1.1", + "org.apache.derby" % "derby" % "10.4.2.0" % "test", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } -- cgit v1.2.3 From d7d1da79d30961e461115a73bbfc9e4c4448e533 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 13 May 2013 16:37:49 -0700 Subject: when akka starts, use akkas default classloader (current thread) --- core/src/main/scala/spark/util/AkkaUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 9fb7e001ba..cd79bd2bda 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -52,7 +52,7 @@ private[spark] object AkkaUtils { """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, lifecycleEvents, akkaWriteTimeout)) - val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) + val actorSystem = ActorSystem(name, akkaConf) // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a // hack because Akka doesn't let you figure out the port through the public API yet. -- cgit v1.2.3 From 38d4b97c6d47df4e1f1a3279ff786509f60e0eaf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 13 May 2013 16:50:43 -0700 Subject: use threads classloader when deserializing task results; classnotfoundexception includes classloader --- .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index c69f3bdb7f..b348092d89 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -503,9 +503,16 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( tid, info.duration, tasksFinished, numTasks)) // Deserialize task result and pass it to the scheduler - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + try { + val result = ser.deserialize[TaskResult[_]](serializedData) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + } catch { + case cnf: ClassNotFoundException => + val loader = Thread.currentThread().getContextClassLoader + throw new SparkException("ClassNotFound with classloader: " + loader, cnf) + case ex => throw ex + } // Mark finished and stop if we've finished all the tasks finished(index) = true if (tasksFinished == numTasks) { -- cgit v1.2.3 From 404f9ff617401a2f8d12845861ce8f02cfe6442c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 May 2013 23:28:34 -0700 Subject: Added derby dependency to Maven pom files for the JDBC Java test. --- core/pom.xml | 5 +++++ pom.xml | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/pom.xml b/core/pom.xml index 9a019b5a42..57a95328c3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -92,6 +92,11 @@ log4j + + org.apache.derby + derby + test + org.scalatest scalatest_${scala.version} diff --git a/pom.xml b/pom.xml index 3936165d78..d7cdc591cf 100644 --- a/pom.xml +++ b/pom.xml @@ -256,6 +256,12 @@ mesos ${mesos.version} + + org.apache.derby + derby + 10.4.2.0 + test + org.scala-lang @@ -565,7 +571,7 @@ 2 - 2.0.2-alpha + 2.0.2-alpha -- cgit v1.2.3 From f9d40a5848a2e1eef31ac63cd9221d5b77c1c5a7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 14 May 2013 23:29:57 -0700 Subject: Added a comment in JdbcRDD for example usage. --- core/src/main/scala/spark/rdd/JdbcRDD.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala index b0f7054233..a50f407737 100644 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -11,11 +11,13 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e /** * An RDD that executes an SQL query on a JDBC connection and reads results. + * For usage example, see test case JdbcRDDSuite. + * * @param getConnection a function that returns an open Connection. * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * E.g. "select title, author from books where ? <= id and id <= ?" * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. -- cgit v1.2.3 From afcad7b3aa8736231a526417ede47ce6d353a70c Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 15 May 2013 14:45:14 -0300 Subject: Docs: Mention spark shell's default for MASTER --- docs/scala-programming-guide.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 2315aadbdf..b0da130fcb 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -67,6 +67,8 @@ The master URL passed to Spark can be in one of the following formats: +If no master URL is specified, the spark shell defaults to "local". + For running on YARN, Spark launches an instance of the standalone deploy cluster within YARN; see [running on YARN](running-on-yarn.html) for details. ### Deploying Code on a Cluster -- cgit v1.2.3 From b8e46b6074e5ecc1ae4ed22ea32983597c14b683 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 16 May 2013 01:52:40 -0700 Subject: Abort job if result exceeds Akka frame size; add test. --- core/src/main/scala/spark/TaskEndReason.scala | 2 ++ core/src/main/scala/spark/executor/Executor.scala | 3 ++- .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 6 ++++++ core/src/test/scala/spark/DistributedSuite.scala | 13 +++++++++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala index 420c54bc9a..c5da453562 100644 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ b/core/src/main/scala/spark/TaskEndReason.scala @@ -20,3 +20,5 @@ case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, re private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason private[spark] case class OtherFailure(message: String) extends TaskEndReason + +private[spark] case class TaskResultTooBigFailure() extends TaskEndReason \ No newline at end of file diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 718f0ff5bc..9ec4eb6e88 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -115,7 +115,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val serializedResult = ser.serialize(result) logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit) if (serializedResult.limit >= (akkaFrameSize - 1024)) { - throw new SparkException("Result for " + taskId + " exceeded Akka frame size") + context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure())) + return } context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) logInfo("Finished task ID " + taskId) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 27e713e2c4..df7f0eafff 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -492,6 +492,12 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe sched.taskSetFinished(this) return + case taskResultTooBig: TaskResultTooBigFailure => + logInfo("Loss was due to task %s result exceeding Akka frame size;" + + "aborting job".format(tid)) + abort("Task %s result exceeded Akka frame size".format(tid)) + return + case ef: ExceptionFailure => val key = ef.exception.toString val now = System.currentTimeMillis diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 4df3bb5b67..9f58999cbe 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -277,6 +277,19 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter } } } + + test("job should fail if TaskResult exceeds Akka frame size") { + // We must use local-cluster mode since results are returned differently + // when running under LocalScheduler: + sc = new SparkContext("local-cluster[1,1,512]", "test") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)} + val exception = intercept[SparkException] { + rdd.reduce((x, y) => x) + } + exception.getMessage should endWith("result exceeded Akka frame size") + } } object DistributedSuite { -- cgit v1.2.3 From 87540a7b386837d177a6d356ad1f5ef2c1ad6ea5 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 16 May 2013 15:27:58 +0530 Subject: Fix running on yarn documentation --- docs/running-on-yarn.md | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c8cf8ffc35..41c0b235dd 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -11,14 +11,32 @@ Ex: mvn -Phadoop2-yarn clean install # Building spark core consolidated jar. -Currently, only sbt can buid a consolidated jar which contains the entire spark code - which is required for launching jars on yarn. -To do this via sbt - though (right now) is a manual process of enabling it in project/SparkBuild.scala. +We need a consolidated spark core jar (which bundles all the required dependencies) to run Spark jobs on a yarn cluster. +This can be built either through sbt or via maven. + +- Building spark assembled jar via sbt. +It is a manual process of enabling it in project/SparkBuild.scala. Please comment out the HADOOP_VERSION, HADOOP_MAJOR_VERSION and HADOOP_YARN variables before the line 'For Hadoop 2 YARN support' Next, uncomment the subsequent 3 variable declaration lines (for these three variables) which enable hadoop yarn support. -Currnetly, it is a TODO to add support for maven assembly. +Assembly of the jar Ex: +./sbt/sbt clean assembly + +The assembled jar would typically be something like : +./streaming/target/spark-streaming-.jar + + +- Building spark assembled jar via sbt. +Use the hadoop2-yarn profile and execute the package target. + +Something like this. Ex: +$ mvn -Phadoop2-yarn clean package -DskipTests=true + + +This will build the shaded (consolidated) jar. Typically something like : +./repl-bin/target/spark-repl-bin--shaded-hadoop2-yarn.jar # Preparations @@ -62,6 +80,6 @@ The above starts a YARN Client programs which periodically polls the Application # Important Notes - When your application instantiates a Spark context it must use a special "standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "standalone" as an argument to your program, as shown in the example above. -- YARN does not support requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. +- We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. - Currently, we have not yet integrated with hadoop security. If --user is present, the hadoop_user specified will be used to run the tasks on the cluster. If unspecified, current user will be used (which should be valid in cluster). Once hadoop security support is added, and if hadoop cluster is enabled with security, additional restrictions would apply via delegation tokens passed. -- cgit v1.2.3 From feddd2530ddfac7a01b03c9113b29945ec0e9a82 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 16 May 2013 17:49:14 +0530 Subject: Filter out nulls - prevent NPE --- core/src/main/scala/spark/SparkContext.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 2ae4ad8659..15a75c7e93 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -537,6 +537,8 @@ class SparkContext( * filesystems), or an HTTP, HTTPS or FTP URI. */ def addJar(path: String) { + // weird - debug why this is happening. + if (null == path) return val uri = new URI(path) val key = uri.getScheme match { case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) -- cgit v1.2.3 From f16c781709f9e108d9fe8ac052fb55146ce8a14f Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 16 May 2013 17:50:22 +0530 Subject: Fix documentation to use yarn-standalone as master --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 41c0b235dd..2e46ff0ed1 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -69,7 +69,7 @@ For example: SPARK_JAR=./core/target/spark-core-assembly-{{site.SPARK_VERSION}}.jar ./run spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_VERSION}}/spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}.jar \ --class spark.examples.SparkPi \ - --args standalone \ + --args yarn-standalone \ --num-workers 3 \ --master-memory 4g \ --worker-memory 2g \ @@ -79,7 +79,7 @@ The above starts a YARN Client programs which periodically polls the Application # Important Notes -- When your application instantiates a Spark context it must use a special "standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "standalone" as an argument to your program, as shown in the example above. +- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above. - We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. - Currently, we have not yet integrated with hadoop security. If --user is present, the hadoop_user specified will be used to run the tasks on the cluster. If unspecified, current user will be used (which should be valid in cluster). Once hadoop security support is added, and if hadoop cluster is enabled with security, additional restrictions would apply via delegation tokens passed. -- cgit v1.2.3 From f0881f8d4812dcee955aa303d7a4b76c58b75a61 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Fri, 17 May 2013 01:58:50 +0530 Subject: Hope this does not turn into a bike shed change --- core/src/main/scala/spark/SparkContext.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 15a75c7e93..736b5485b7 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -537,8 +537,10 @@ class SparkContext( * filesystems), or an HTTP, HTTPS or FTP URI. */ def addJar(path: String) { - // weird - debug why this is happening. - if (null == path) return + if (null == path) { + logInfo("null specified as parameter to addJar") + return + } val uri = new URI(path) val key = uri.getScheme match { case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) -- cgit v1.2.3 From c6e2770bfe940a4f4f26f75c9ba228faea7316f0 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Fri, 17 May 2013 05:10:38 +0800 Subject: Fix ClusterScheduler bug to avoid allocating tasks to same slave --- .../spark/scheduler/cluster/ClusterScheduler.scala | 48 +++++++++++++--------- .../main/scala/spark/scheduler/cluster/Pool.scala | 20 ++++----- .../spark/scheduler/cluster/Schedulable.scala | 3 +- .../spark/scheduler/cluster/TaskSetManager.scala | 8 +++- .../spark/scheduler/ClusterSchedulerSuite.scala | 46 +++++++++++++-------- 5 files changed, 75 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 1a300c9e8c..4caafcc1d3 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -164,27 +164,35 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray - for (i <- 0 until offers.size) { - var launchedTask = true - val execId = offers(i).executorId - val host = offers(i).hostname - while (availableCpus(i) > 0 && launchedTask) { + var launchedTask = false + val sortedLeafSchedulable = rootPool.getSortedLeafSchedulable() + for (schedulable <- sortedLeafSchedulable) + { + logDebug("parentName:%s,name:%s,runningTasks:%s".format(schedulable.parent.name,schedulable.name,schedulable.runningTasks)) + } + for (schedulable <- sortedLeafSchedulable) { + do { launchedTask = false - rootPool.receiveOffer(execId,host,availableCpus(i)) match { - case Some(task) => - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = task.taskSetId - taskSetTaskIds(task.taskSetId) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - - case None => {} - } - } + for (i <- 0 until offers.size) { + var launchedTask = true + val execId = offers(i).executorId + val host = offers(i).hostname + schedulable.slaveOffer(execId,host,availableCpus(i)) match { + case Some(task) => + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = task.taskSetId + taskSetTaskIds(task.taskSetId) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + launchedTask = true + + case None => {} + } + } + } while(launchedTask) } if (tasks.size > 0) { hasLaunchedTask = true diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala index d5482f71ad..ae603e7dd9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -75,19 +75,17 @@ private[spark] class Pool( return shouldRevive } - override def receiveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + override def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + return None + } + + override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { + var leafSchedulableQueue = new ArrayBuffer[Schedulable] val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) - for (manager <- sortedSchedulableQueue) { - logInfo("parentName:%s,schedulableName:%s,minShares:%d,weight:%d,runningTasks:%d".format( - manager.parent.name, manager.name, manager.minShare, manager.weight, manager.runningTasks)) + for (schedulable <- sortedSchedulableQueue) { + leafSchedulableQueue ++= schedulable.getSortedLeafSchedulable() } - for (manager <- sortedSchedulableQueue) { - val task = manager.receiveOffer(execId, host, availableCpus) - if (task != None) { - return task - } - } - return None + return leafSchedulableQueue } override def increaseRunningTasks(taskNum: Int) { diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala index 54e8ae95f9..c620588e14 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -21,6 +21,7 @@ private[spark] trait Schedulable { def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable def executorLost(executorId: String, host: String): Unit - def receiveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] + def slaveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] def checkSpeculatableTasks(): Boolean + def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index baaaa41a37..80edbe77a1 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -198,7 +198,7 @@ private[spark] class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - override def receiveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + override def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) @@ -398,6 +398,12 @@ private[spark] class TaskSetManager( //nothing } + override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { + var leafSchedulableQueue = new ArrayBuffer[Schedulable] + leafSchedulableQueue += this + return leafSchedulableQueue + } + override def executorLost(execId: String, hostname: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) val newHostsAlive = sched.hostsAlive diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index 2eda48196b..8426be7575 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -6,6 +6,7 @@ import org.scalatest.BeforeAndAfter import spark._ import spark.scheduler._ import spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer import java.util.Properties @@ -25,34 +26,34 @@ class DummyTaskSetManager( var numTasks = initNumTasks var tasksFinished = 0 - def increaseRunningTasks(taskNum: Int) { + override def increaseRunningTasks(taskNum: Int) { runningTasks += taskNum if (parent != null) { parent.increaseRunningTasks(taskNum) } } - def decreaseRunningTasks(taskNum: Int) { + override def decreaseRunningTasks(taskNum: Int) { runningTasks -= taskNum if (parent != null) { parent.decreaseRunningTasks(taskNum) } } - def addSchedulable(schedulable: Schedulable) { + override def addSchedulable(schedulable: Schedulable) { } - def removeSchedulable(schedulable: Schedulable) { + override def removeSchedulable(schedulable: Schedulable) { } - def getSchedulableByName(name: String): Schedulable = { + override def getSchedulableByName(name: String): Schedulable = { return null } - def executorLost(executorId: String, host: String): Unit = { + override def executorLost(executorId: String, host: String): Unit = { } - def receiveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] = { + override def slaveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] = { if (tasksFinished + runningTasks < numTasks) { increaseRunningTasks(1) return Some(new TaskDescription(0, stageId.toString, execId, "task 0:0", null)) @@ -60,10 +61,16 @@ class DummyTaskSetManager( return None } - def checkSpeculatableTasks(): Boolean = { + override def checkSpeculatableTasks(): Boolean = { return true } + override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { + var leafSchedulableQueue = new ArrayBuffer[Schedulable] + leafSchedulableQueue += this + return leafSchedulableQueue + } + def taskFinished() { decreaseRunningTasks(1) tasksFinished +=1 @@ -80,16 +87,21 @@ class DummyTaskSetManager( class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { - def receiveOffer(rootPool: Pool) : Option[TaskDescription] = { - rootPool.receiveOffer("execId_1", "hostname_1", 1) + def resourceOffer(rootPool: Pool): Int = { + val taskSetQueue = rootPool.getSortedLeafSchedulable() + for (taskSet <- taskSetQueue) + { + taskSet.slaveOffer("execId_1", "hostname_1", 1) match { + case Some(task) => + return task.taskSetId.toInt + case None => {} + } + } + -1 } def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { - receiveOffer(rootPool) match { - case Some(task) => - assert(task.taskSetId.toInt === expectedTaskSetId) - case _ => - } + assert(resourceOffer(rootPool) === expectedTaskSetId) } test("FIFO Scheduler Test") { @@ -105,9 +117,9 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { schedulableBuilder.addTaskSetManager(taskSetManager2, null) checkTaskSetId(rootPool, 0) - receiveOffer(rootPool) + resourceOffer(rootPool) checkTaskSetId(rootPool, 1) - receiveOffer(rootPool) + resourceOffer(rootPool) taskSetManager1.abort() checkTaskSetId(rootPool, 2) } -- cgit v1.2.3 From f742435f18f65e1dbf6235dd49f93b10f22cfe4b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 May 2013 14:31:03 -0700 Subject: Removed the duplicated netty dependency in SBT build file. --- project/SparkBuild.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 234b021c93..0ea23b446f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -142,7 +142,6 @@ object SparkBuild extends Build { ), libraryDependencies ++= Seq( - "io.netty" % "netty" % "3.5.3.Final", "com.google.guava" % "guava" % "11.0.1", "log4j" % "log4j" % "1.2.16", "org.slf4j" % "slf4j-api" % slf4jVersion, -- cgit v1.2.3 From 61cf17623835007114ee69394999faaba8a46206 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 May 2013 14:31:26 -0700 Subject: Added dependency on netty-all in Maven. --- core/pom.xml | 4 ++++ pom.xml | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index 57a95328c3..d8687bf991 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -87,6 +87,10 @@ org.apache.mesos mesos + + io.netty + netty-all + log4j log4j diff --git a/pom.xml b/pom.xml index d7cdc591cf..eda18fdd12 100644 --- a/pom.xml +++ b/pom.xml @@ -256,6 +256,11 @@ mesos ${mesos.version} + + io.netty + netty-all + 4.0.0.Beta2 + org.apache.derby derby -- cgit v1.2.3 From 43644a293f5faec088530cf3a84d3680f2a103af Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 May 2013 14:31:38 -0700 Subject: Only check for repl classes if the user is running the repl. Otherwise, check for core classes in run. This fixed the problem that core tests depend on whether repl module is compiled or not. --- run | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/run b/run index c744bbd3dc..c0065c53f1 100755 --- a/run +++ b/run @@ -102,12 +102,18 @@ STREAMING_DIR="$FWDIR/streaming" PYSPARK_DIR="$FWDIR/python" # Exit if the user hasn't compiled Spark -if [ ! -e "$REPL_DIR/target" ]; then - echo "Failed to find Spark classes in $REPL_DIR/target" >&2 +if [ ! -e "$CORE_DIR/target" ]; then + echo "Failed to find Spark classes in $CORE_DIR/target" >&2 echo "You need to compile Spark before running this program" >&2 exit 1 fi +if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then + echo "Failed to find Spark classes in $REPL_DIR/target" >&2 + echo "You need to compile Spark repl module before running this program" >&2 + exit 1 +fi + # Build up classpath CLASSPATH="$SPARK_CLASSPATH" CLASSPATH="$CLASSPATH:$FWDIR/conf" -- cgit v1.2.3 From 3b3300383a6ccb9d8b62243c7814eb6c2e1ab313 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 May 2013 16:51:28 -0700 Subject: Updated Scala version in docs generation ruby script. --- docs/_plugins/copy_api_dirs.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index d77e53963c..c10ae595de 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -18,7 +18,7 @@ if ENV['SKIP_API'] != '1' # Copy over the scaladoc from each project into the docs directory. # This directory will be copied over to _site when `jekyll` command is run. projects.each do |project_name| - source = "../" + project_name + "/target/scala-2.9.2/api" + source = "../" + project_name + "/target/scala-2.9.3/api" dest = "api/" + project_name puts "echo making directory " + dest -- cgit v1.2.3 From dc146406aefa4285d2a2a5d9d45f2ef883e9ef73 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 16 May 2013 17:07:14 -0700 Subject: Updated Scala version in docs generation ruby script. --- docs/_plugins/copy_api_dirs.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index d77e53963c..c10ae595de 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -18,7 +18,7 @@ if ENV['SKIP_API'] != '1' # Copy over the scaladoc from each project into the docs directory. # This directory will be copied over to _site when `jekyll` command is run. projects.each do |project_name| - source = "../" + project_name + "/target/scala-2.9.2/api" + source = "../" + project_name + "/target/scala-2.9.3/api" dest = "api/" + project_name puts "echo making directory " + dest -- cgit v1.2.3 From da2642bead2eaf15bfcc28520858cf212d5975a4 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Fri, 17 May 2013 06:58:46 +0530 Subject: Fix example jar name --- docs/running-on-yarn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2e46ff0ed1..3946100247 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -25,7 +25,7 @@ Assembly of the jar Ex: ./sbt/sbt clean assembly The assembled jar would typically be something like : -./streaming/target/spark-streaming-.jar +./core/target/spark-core-assembly-0.8.0-SNAPSHOT.jar - Building spark assembled jar via sbt. -- cgit v1.2.3 From d19753b9c78857acae441dce3133fbb6c5855f95 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Sat, 18 May 2013 06:45:19 +0800 Subject: expose TaskSetManager type to resourceOffer function in ClusterScheduler --- .../spark/scheduler/cluster/ClusterScheduler.scala | 14 +-- .../main/scala/spark/scheduler/cluster/Pool.scala | 12 +-- .../spark/scheduler/cluster/Schedulable.scala | 3 +- .../spark/scheduler/cluster/TaskDescription.scala | 1 - .../spark/scheduler/cluster/TaskSetManager.scala | 12 +-- .../spark/scheduler/ClusterSchedulerSuite.scala | 112 ++++++++++++--------- 6 files changed, 84 insertions(+), 70 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 4caafcc1d3..e6399a3547 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -165,24 +165,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray var launchedTask = false - val sortedLeafSchedulable = rootPool.getSortedLeafSchedulable() - for (schedulable <- sortedLeafSchedulable) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { - logDebug("parentName:%s,name:%s,runningTasks:%s".format(schedulable.parent.name,schedulable.name,schedulable.runningTasks)) + logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) } - for (schedulable <- sortedLeafSchedulable) { + for (manager <- sortedTaskSetQueue) { do { launchedTask = false for (i <- 0 until offers.size) { var launchedTask = true val execId = offers(i).executorId val host = offers(i).hostname - schedulable.slaveOffer(execId,host,availableCpus(i)) match { + manager.slaveOffer(execId,host,availableCpus(i)) match { case Some(task) => tasks(i) += task val tid = task.taskId - taskIdToTaskSetId(tid) = task.taskSetId - taskSetTaskIds(task.taskSetId) += tid + taskIdToTaskSetId(tid) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId executorsByHost(host) += execId diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala index ae603e7dd9..4dc15f413c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala @@ -75,17 +75,13 @@ private[spark] class Pool( return shouldRevive } - override def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { - return None - } - - override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { - var leafSchedulableQueue = new ArrayBuffer[Schedulable] + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { - leafSchedulableQueue ++= schedulable.getSortedLeafSchedulable() + sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() } - return leafSchedulableQueue + return sortedTaskSetQueue } override def increaseRunningTasks(taskNum: Int) { diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala index c620588e14..6bb7525b49 100644 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala @@ -21,7 +21,6 @@ private[spark] trait Schedulable { def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable def executorLost(executorId: String, host: String): Unit - def slaveOffer(execId: String, host: String, avaiableCpus: Double): Option[TaskDescription] def checkSpeculatableTasks(): Boolean - def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index cdd004c94b..b41e951be9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -5,7 +5,6 @@ import spark.util.SerializableBuffer private[spark] class TaskDescription( val taskId: Long, - val taskSetId: String, val executorId: String, val name: String, _serializedTask: ByteBuffer) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 80edbe77a1..b9d2dbf487 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -198,7 +198,7 @@ private[spark] class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - override def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) @@ -234,7 +234,7 @@ private[spark] class TaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, taskSet.id, execId, taskName, serializedTask)) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) } case _ => } @@ -398,10 +398,10 @@ private[spark] class TaskSetManager( //nothing } - override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { - var leafSchedulableQueue = new ArrayBuffer[Schedulable] - leafSchedulableQueue += this - return leafSchedulableQueue + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue } override def executorLost(execId: String, hostname: String) { diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index 8426be7575..956cc7421c 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -13,18 +13,20 @@ import java.util.Properties class DummyTaskSetManager( initPriority: Int, initStageId: Int, - initNumTasks: Int) - extends Schedulable { - - var parent: Schedulable = null - var weight = 1 - var minShare = 2 - var runningTasks = 0 - var priority = initPriority - var stageId = initStageId - var name = "TaskSet_"+stageId - var numTasks = initNumTasks - var tasksFinished = 0 + initNumTasks: Int, + clusterScheduler: ClusterScheduler, + taskSet: TaskSet) + extends TaskSetManager(clusterScheduler,taskSet) { + + parent = null + weight = 1 + minShare = 2 + runningTasks = 0 + priority = initPriority + stageId = initStageId + name = "TaskSet_"+stageId + override val numTasks = initNumTasks + tasksFinished = 0 override def increaseRunningTasks(taskNum: Int) { runningTasks += taskNum @@ -41,11 +43,11 @@ class DummyTaskSetManager( } override def addSchedulable(schedulable: Schedulable) { - } - + } + override def removeSchedulable(schedulable: Schedulable) { } - + override def getSchedulableByName(name: String): Schedulable = { return null } @@ -65,11 +67,11 @@ class DummyTaskSetManager( return true } - override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { - var leafSchedulableQueue = new ArrayBuffer[Schedulable] - leafSchedulableQueue += this - return leafSchedulableQueue - } +// override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { +// var leafSchedulableQueue = new ArrayBuffer[Schedulable] +// leafSchedulableQueue += this +// return leafSchedulableQueue +// } def taskFinished() { decreaseRunningTasks(1) @@ -85,10 +87,28 @@ class DummyTaskSetManager( } } +class DummyTask(stageId: Int) extends Task[Int](stageId) +{ + def run(attemptId: Long): Int = { + return 0 + } +} + class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { - + + val sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + tasks += task + + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int): DummyTaskSetManager = { + new DummyTaskSetManager(priority, stage, numTasks, clusterScheduler, taskSet) + } + def resourceOffer(rootPool: Pool): Int = { - val taskSetQueue = rootPool.getSortedLeafSchedulable() + val taskSetQueue = rootPool.getSortedTaskSetQueue() for (taskSet <- taskSetQueue) { taskSet.slaveOffer("execId_1", "hostname_1", 1) match { @@ -109,13 +129,13 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) schedulableBuilder.buildPools() - val taskSetManager0 = new DummyTaskSetManager(0, 0, 2) - val taskSetManager1 = new DummyTaskSetManager(0, 1, 2) - val taskSetManager2 = new DummyTaskSetManager(0, 2, 2) + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2) schedulableBuilder.addTaskSetManager(taskSetManager0, null) schedulableBuilder.addTaskSetManager(taskSetManager1, null) schedulableBuilder.addTaskSetManager(taskSetManager2, null) - + checkTaskSetId(rootPool, 0) resourceOffer(rootPool) checkTaskSetId(rootPool, 1) @@ -130,7 +150,7 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) val schedulableBuilder = new FairSchedulableBuilder(rootPool) schedulableBuilder.buildPools() - + assert(rootPool.getSchedulableByName("default") != null) assert(rootPool.getSchedulableByName("1") != null) assert(rootPool.getSchedulableByName("2") != null) @@ -146,16 +166,16 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { properties1.setProperty("spark.scheduler.cluster.fair.pool","1") val properties2 = new Properties() properties2.setProperty("spark.scheduler.cluster.fair.pool","2") - - val taskSetManager10 = new DummyTaskSetManager(1, 0, 1) - val taskSetManager11 = new DummyTaskSetManager(1, 1, 1) - val taskSetManager12 = new DummyTaskSetManager(1, 2, 2) + + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2) schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - - val taskSetManager23 = new DummyTaskSetManager(2, 3, 2) - val taskSetManager24 = new DummyTaskSetManager(2, 4, 2) + + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2) schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) @@ -190,27 +210,27 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) pool1.addSchedulable(pool10) pool1.addSchedulable(pool11) - - val taskSetManager000 = new DummyTaskSetManager(0, 0, 5) - val taskSetManager001 = new DummyTaskSetManager(0, 1, 5) + + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5) pool00.addSchedulable(taskSetManager000) pool00.addSchedulable(taskSetManager001) - - val taskSetManager010 = new DummyTaskSetManager(1, 2, 5) - val taskSetManager011 = new DummyTaskSetManager(1, 3, 5) + + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5) pool01.addSchedulable(taskSetManager010) pool01.addSchedulable(taskSetManager011) - - val taskSetManager100 = new DummyTaskSetManager(2, 4, 5) - val taskSetManager101 = new DummyTaskSetManager(2, 5, 5) + + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5) pool10.addSchedulable(taskSetManager100) pool10.addSchedulable(taskSetManager101) - val taskSetManager110 = new DummyTaskSetManager(3, 6, 5) - val taskSetManager111 = new DummyTaskSetManager(3, 7, 5) + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5) pool11.addSchedulable(taskSetManager110) pool11.addSchedulable(taskSetManager111) - + checkTaskSetId(rootPool, 0) checkTaskSetId(rootPool, 4) checkTaskSetId(rootPool, 6) -- cgit v1.2.3 From 0eab7a78b90e2593075c479282f631a5a20e77a9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 17 May 2013 18:05:46 -0700 Subject: Fixed a couple typos and formating problems in the YARN documentation. --- docs/running-on-yarn.md | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3946100247..66fb8d73e8 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -14,29 +14,31 @@ Ex: mvn -Phadoop2-yarn clean install We need a consolidated spark core jar (which bundles all the required dependencies) to run Spark jobs on a yarn cluster. This can be built either through sbt or via maven. -- Building spark assembled jar via sbt. -It is a manual process of enabling it in project/SparkBuild.scala. +- Building spark assembled jar via sbt. + It is a manual process of enabling it in project/SparkBuild.scala. Please comment out the HADOOP_VERSION, HADOOP_MAJOR_VERSION and HADOOP_YARN variables before the line 'For Hadoop 2 YARN support' Next, uncomment the subsequent 3 variable declaration lines (for these three variables) which enable hadoop yarn support. -Assembly of the jar Ex: -./sbt/sbt clean assembly +Assembly of the jar Ex: + + ./sbt/sbt clean assembly The assembled jar would typically be something like : -./core/target/spark-core-assembly-0.8.0-SNAPSHOT.jar +`./core/target/spark-core-assembly-0.8.0-SNAPSHOT.jar` -- Building spark assembled jar via sbt. -Use the hadoop2-yarn profile and execute the package target. +- Building spark assembled jar via Maven. + Use the hadoop2-yarn profile and execute the package target. Something like this. Ex: -$ mvn -Phadoop2-yarn clean package -DskipTests=true + + mvn -Phadoop2-yarn clean package -DskipTests=true This will build the shaded (consolidated) jar. Typically something like : -./repl-bin/target/spark-repl-bin--shaded-hadoop2-yarn.jar +`./repl-bin/target/spark-repl-bin--shaded-hadoop2-yarn.jar` # Preparations -- cgit v1.2.3 From 8d78c5f89f25d013c997c03587193f3d87a268b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 17 May 2013 18:51:35 -0700 Subject: Changed the logging level from info to warning when addJar(null) is called. --- core/src/main/scala/spark/SparkContext.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 736b5485b7..69b4c5d20d 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -538,16 +538,17 @@ class SparkContext( */ def addJar(path: String) { if (null == path) { - logInfo("null specified as parameter to addJar") - return - } - val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) - case _ => path + logWarning("null specified as parameter to addJar", + new SparkException("null specified as parameter to addJar")) + } else { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) + case _ => path + } + addedJars(key) = System.currentTimeMillis + logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) } - addedJars(key) = System.currentTimeMillis - logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) } /** -- cgit v1.2.3 From e7982c798efccd523165d0e347c7912ba14fcdd7 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Sat, 18 May 2013 16:11:29 -0700 Subject: Exclude old versions of Netty from Maven-based build --- pom.xml | 6 ++++++ streaming/pom.xml | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/pom.xml b/pom.xml index eda18fdd12..6ee64d07c2 100644 --- a/pom.xml +++ b/pom.xml @@ -565,6 +565,12 @@ org.apache.avro avro-ipc 1.7.1.cloudera.2 + + + org.jboss.netty + netty + + diff --git a/streaming/pom.xml b/streaming/pom.xml index 08ff3e2ae1..4dc9a19d51 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -41,6 +41,12 @@ org.apache.flume flume-ng-sdk 1.2.0 + + + org.jboss.netty + netty + + com.github.sgroschupf -- cgit v1.2.3 From ecd6d75c6a88232c40070baed3dd67bdf77f0c69 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Tue, 21 May 2013 06:49:23 +0800 Subject: fix bug of unit tests --- .../spark/scheduler/cluster/ClusterScheduler.scala | 6 +- .../spark/scheduler/ClusterSchedulerSuite.scala | 72 ++++++++++++---------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 9547f4f6dd..053d4b8e4a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -352,7 +352,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) executorsByHostPort(hostPort) += execId availableCpus(i) -= 1 launchedTask = true - + case None => {} } } @@ -373,7 +373,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } while (launchedTask) } - + if (tasks.size > 0) { hasLaunchedTask = true } @@ -522,7 +522,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) hostPortsAlive -= hostPort hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort) } - + val execs = executorsByHostPort.getOrElse(hostPort, new HashSet) execs -= executorId if (execs.isEmpty) { diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index 7af749fb5c..a39418b716 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -67,12 +67,6 @@ class DummyTaskSetManager( return true } -// override def getSortedLeafSchedulable(): ArrayBuffer[Schedulable] = { -// var leafSchedulableQueue = new ArrayBuffer[Schedulable] -// leafSchedulableQueue += this -// return leafSchedulableQueue -// } - def taskFinished() { decreaseRunningTasks(1) tasksFinished +=1 @@ -94,17 +88,10 @@ class DummyTask(stageId: Int) extends Task[Int](stageId) } } -class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { - - val sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new DummyTask(0) - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - tasks += task +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int): DummyTaskSetManager = { - new DummyTaskSetManager(priority, stage, numTasks, clusterScheduler, taskSet) + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = { + new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet) } def resourceOffer(rootPool: Pool): Int = { @@ -125,13 +112,20 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { } test("FIFO Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) schedulableBuilder.buildPools() - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2) - val taskSetManager2 = createDummyTaskSetManager(0, 2, 2) + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager0, null) schedulableBuilder.addTaskSetManager(taskSetManager1, null) schedulableBuilder.addTaskSetManager(taskSetManager2, null) @@ -145,6 +139,13 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { } test("Fair Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() System.setProperty("spark.fairscheduler.allocation.file", xmlPath) val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) @@ -167,15 +168,15 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { val properties2 = new Properties() properties2.setProperty("spark.scheduler.cluster.fair.pool","2") - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2) + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2) + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) @@ -195,6 +196,13 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { } test("Nested Pool Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new DummyTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) @@ -211,23 +219,23 @@ class ClusterSchedulerSuite extends FunSuite with BeforeAndAfter { pool1.addSchedulable(pool10) pool1.addSchedulable(pool11) - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5) + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) pool00.addSchedulable(taskSetManager000) pool00.addSchedulable(taskSetManager001) - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5) + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) pool01.addSchedulable(taskSetManager010) pool01.addSchedulable(taskSetManager011) - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5) + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) pool10.addSchedulable(taskSetManager100) pool10.addSchedulable(taskSetManager101) - val taskSetManager110 = createDummyTaskSetManager(3, 6, 5) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5) + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) pool11.addSchedulable(taskSetManager110) pool11.addSchedulable(taskSetManager111) -- cgit v1.2.3 From 3217d486f7fdd590250f2efee567e4779e130d34 Mon Sep 17 00:00:00 2001 From: Ethan Jewett Date: Mon, 20 May 2013 19:41:38 -0500 Subject: Add hBase dependency to examples POM --- examples/pom.xml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/pom.xml b/examples/pom.xml index c42d2bcdb9..0fbb5a3d5d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -67,6 +67,11 @@ hadoop-core provided + + org.apache.hbase + hbase + 0.94.6 + @@ -105,6 +110,11 @@ hadoop-client provided + + org.apache.hbase + hbase + 0.94.6 + -- cgit v1.2.3 From 786c97b87c9d3074796e2d931635d9c6f72b9704 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 27 Feb 2013 10:06:05 -0800 Subject: DistributedSuite: remove dead test code --- core/src/test/scala/spark/DistributedSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 068bb6ca4f..0866fb47b3 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -222,7 +222,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) assert(data.count === 2) // force executors to start - val masterId = SparkEnv.get.blockManager.blockManagerId assert(data.map(markNodeIfIdentity).collect.size === 2) assert(data.map(failOnMarkedIdentity).collect.size === 2) } -- cgit v1.2.3 From f350f14084dd04a2ea77e92e35b3cf415ef72202 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 5 Feb 2013 17:08:55 -0800 Subject: Use ARRAY_SAMPLE_SIZE constant instead of 100.0 --- core/src/main/scala/spark/SizeEstimator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index d4e1157250..f8a4c4e489 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -198,7 +198,7 @@ private[spark] object SizeEstimator extends Logging { val elem = JArray.get(array, index) size += SizeEstimator.estimate(elem, state.visited) } - state.size += ((length / 100.0) * size).toLong + state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong } } } -- cgit v1.2.3 From bd3ea8f2a66de5ddc12dc1b2273e675d0abb8393 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 24 May 2013 14:26:19 +0800 Subject: fix CheckpointRDD getPreferredLocations java.io.FileNotFoundException --- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 43ee39c993..377b1bdbe0 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -43,7 +43,7 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri checkpointData.get.cpFile = Some(checkpointPath) override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath)) + val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) val locations = fs.getFileBlockLocations(status, 0, status.getLen) locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") } -- cgit v1.2.3 From cda2b150412314c47c2c24883111bfc441c3a3a2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 May 2013 13:05:06 -0700 Subject: Use ec2-metadata in start-slave.sh. PR #419 applied the same change, but only to start-master.sh, so some workers were still starting their web UI's using internal addresses. This should finally fix SPARK-613. --- bin/start-slave.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bin/start-slave.sh b/bin/start-slave.sh index 616c76e4ee..26b5b9d462 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -6,7 +6,8 @@ bin=`cd "$bin"; pwd` # Set SPARK_PUBLIC_DNS so slaves can be linked in master web UI if [ "$SPARK_PUBLIC_DNS" = "" ]; then # If we appear to be running on EC2, use the public address by default: - if [[ `hostname` == *ec2.internal ]]; then + # NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname + if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname` fi fi -- cgit v1.2.3 From 6ea085169d8ba2d09ca9236273d65238b8411f04 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 May 2013 14:08:37 -0700 Subject: Fixed the bug that shuffle serializer is ignored by the new shuffle block iterators for local blocks. Also added a unit test for that. --- .../scala/spark/storage/BlockFetcherIterator.scala | 2 +- core/src/test/scala/spark/ShuffleSuite.scala | 21 ++++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 43f835237c..88eed0d8c8 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -163,7 +163,7 @@ object BlockFetcherIterator { // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight for (id <- localBlockIds) { - getLocal(id) match { + getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight results.put(new FetchResult(id, 0, () => iter)) diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 2b2a90defa..fdee7ca384 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -99,7 +99,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val sums = pairs.reduceByKey(_+_, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } - + test("reduceByKey with partitioner") { sc = new SparkContext("local", "test") val p = new Partitioner() { @@ -272,7 +272,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { } // partitionBy so we have a narrow dependency val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency + // more partitions/no partitioner so a shuffle dependency val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) val c = a.subtract(b) assert(c.collect().toSet === Set((1, "a"), (3, "c"))) @@ -298,18 +298,33 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { } // partitionBy so we have a narrow dependency val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency + // more partitions/no partitioner so a shuffle dependency val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) val c = a.subtractByKey(b) assert(c.collect().toSet === Set((1, "a"), (1, "a"))) assert(c.partitioner.get === p) } + test("shuffle serializer") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[1,2,512]", "test") + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new ShuffleSuite.NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName) + assert(c.count === 10) + } } object ShuffleSuite { + def mergeCombineException(x: Int, y: Int): Int = { throw new SparkException("Exception for map-side combine.") x + y } + + class NonJavaSerializableClass(val value: Int) } -- cgit v1.2.3 From 26962c9340ac92b11d43e87200e699471d0b6330 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 May 2013 16:39:33 -0700 Subject: Automatically configure Netty port. This makes unit tests using local-cluster pass. Previously they were failing because Netty was trying to bind to the same port for all processes. Pair programmed with @shivaram. --- .../main/java/spark/network/netty/FileServer.java | 68 ++++++++++++++++------ .../scala/spark/network/netty/ShuffleSender.scala | 23 ++++---- .../scala/spark/storage/BlockFetcherIterator.scala | 3 +- .../main/scala/spark/storage/BlockManager.scala | 13 +++-- .../main/scala/spark/storage/BlockManagerId.scala | 32 +++++++--- core/src/main/scala/spark/storage/DiskStore.scala | 52 ++++------------- .../test/scala/spark/MapOutputTrackerSuite.scala | 28 ++++----- core/src/test/scala/spark/ShuffleSuite.scala | 2 +- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 14 ++--- .../scala/spark/storage/BlockManagerSuite.scala | 6 +- 10 files changed, 129 insertions(+), 112 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java index 647b26bf8a..dd3f12561c 100644 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -1,51 +1,83 @@ package spark.network.netty; +import java.net.InetSocketAddress; + import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelOption; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioServerSocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * Server that accept the path of a file an echo back its content. */ class FileServer { + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + private ServerBootstrap bootstrap = null; - private Channel channel = null; - private PathResolver pResolver; + private ChannelFuture channelFuture = null; + private int port = 0; + private Thread blockingThread = null; - public FileServer(PathResolver pResolver) { - this.pResolver = pResolver; - } + public FileServer(PathResolver pResolver, int port) { + InetSocketAddress addr = new InetSocketAddress(port); - public void run(int port) { // Configure the server. bootstrap = new ServerBootstrap(); - try { - bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) .channel(OioServerSocketChannel.class) .option(ChannelOption.SO_BACKLOG, 100) .option(ChannelOption.SO_RCVBUF, 1500) .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channel = bootstrap.bind(port).sync().channel(); - channel.closeFuture().sync(); - } catch (InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } finally{ + // Start the server. + channelFuture = bootstrap.bind(addr); + this.port = addr.getPort(); + } + + /** + * Start the file server asynchronously in a new thread. + */ + public void start() { + try { + blockingThread = new Thread() { + public void run() { + try { + Channel channel = channelFuture.sync().channel(); + channel.closeFuture().sync(); + } catch (InterruptedException e) { + LOG.error("File server start got interrupted", e); + } + } + }; + blockingThread.setDaemon(true); + blockingThread.start(); + } finally { bootstrap.shutdown(); } } + public int getPort() { + return port; + } + public void stop() { - if (channel!=null) { - channel.close(); + if (blockingThread != null) { + blockingThread.stop(); + blockingThread = null; + } + if (channelFuture != null) { + channelFuture.channel().closeFuture(); + channelFuture = null; } if (bootstrap != null) { bootstrap.shutdown(); + bootstrap = null; } } } diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala index dc87fefc56..d6fa4b1e80 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -5,23 +5,22 @@ import java.io.File import spark.Logging -private[spark] class ShuffleSender(val port: Int, val pResolver: PathResolver) extends Logging { - val server = new FileServer(pResolver) +private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { - Runtime.getRuntime().addShutdownHook( - new Thread() { - override def run() { - server.stop() - } - } - ) + val server = new FileServer(pResolver, portIn) + server.start() - def start() { - server.run(port) + def stop() { + server.stop() } + + def port: Int = server.getPort() } +/** + * An application for testing the shuffle sender as a standalone program. + */ private[spark] object ShuffleSender { def main(args: Array[String]) { @@ -50,7 +49,5 @@ private[spark] object ShuffleSender { } } val sender = new ShuffleSender(port, pResovler) - - sender.start() } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 88eed0d8c8..95308c7282 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -272,8 +272,7 @@ object BlockFetcherIterator { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId( - req.address.host, System.getProperty("spark.shuffle.sender.port", "6653").toInt) + val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) val cpier = new ShuffleCopier cpier.getBlocks(cmId, req.blocks, putResult) logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 40d608628e..d35c43f194 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -94,11 +94,16 @@ private[spark] class BlockManager( private[storage] val diskStore: DiskStore = new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + // If we use Netty for shuffle, start a new Netty-based shuffle sender service. + private val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + private val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt + private val nettyPort = if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port) + executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) @@ -266,7 +271,6 @@ private[spark] class BlockManager( } } - /** * Get locations of an array of blocks. */ @@ -274,7 +278,7 @@ private[spark] class BlockManager( val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) - return locations + locations } /** @@ -971,8 +975,7 @@ private[spark] object BlockManager extends Logging { assert (env != null || blockManagerMaster != null) val locationBlockIds: Seq[Seq[BlockManagerId]] = if (env != null) { - val blockManager = env.blockManager - blockManager.getLocationBlockIds(blockIds) + env.blockManager.getLocationBlockIds(blockIds) } else { blockManagerMaster.getLocations(blockIds) } diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index f4a2181490..1e557d6148 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -7,18 +7,19 @@ import spark.Utils /** * This class represent an unique identifier for a BlockManager. * The first 2 constructors of this class is made private to ensure that - * BlockManagerId objects can be created only using the factory method in - * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects. + * BlockManagerId objects can be created only using the apply method in + * the companion object. This allows de-duplication of ID objects. * Also, constructor parameters are private to ensure that parameters cannot * be modified from outside this class. */ private[spark] class BlockManagerId private ( private var executorId_ : String, private var host_ : String, - private var port_ : Int + private var port_ : Int, + private var nettyPort_ : Int ) extends Externalizable { - private def this() = this(null, null, 0) // For deserialization only + private def this() = this(null, null, 0, 0) // For deserialization only def executorId: String = executorId_ @@ -39,28 +40,32 @@ private[spark] class BlockManagerId private ( def port: Int = port_ + def nettyPort: Int = nettyPort_ + override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) + out.writeInt(nettyPort_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() + nettyPort_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port) + override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host + executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort case _ => false } @@ -69,8 +74,17 @@ private[spark] class BlockManagerId private ( private[spark] object BlockManagerId { - def apply(execId: String, host: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, host, port)) + /** + * Returns a [[spark.storage.BlockManagerId]] for the given configuraiton. + * + * @param execId ID of the executor. + * @param host Host name of the block manager. + * @param port Port of the block manager. + * @param nettyPort Optional port for the Netty-based shuffle sender. + * @return A new [[spark.storage.BlockManagerId]]. + */ + def apply(execId: String, host: String, port: Int, nettyPort: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 933eeaa216..57d4dafefc 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -82,22 +82,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - var shuffleSender : Thread = null - val thisInstance = this + var shuffleSender : ShuffleSender = null // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. val localDirs = createLocalDirs() val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean - addShutdownHook() - if(useNetty){ - startShuffleBlockSender() - } - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { new DiskBlockObjectWriter(blockId, serializer, bufferSize) @@ -274,8 +267,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) localDirs.foreach { localDir => if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) } - if (useNetty && shuffleSender != null) + if (shuffleSender != null) { shuffleSender.stop + } } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } @@ -283,39 +277,17 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) }) } - private def startShuffleBlockSender() { - try { - val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt - - val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - return null - } - thisInstance.getFile(blockId).getAbsolutePath() - } - } - shuffleSender = new Thread { - override def run() = { - val sender = new ShuffleSender(port, pResolver) - logInfo("Created ShuffleSender binding to port : "+ port) - sender.start - } - } - shuffleSender.setDaemon(true) - shuffleSender.start - - } catch { - case interrupted: InterruptedException => - logInfo("Runner thread for ShuffleBlockSender interrupted") - - case e: Exception => { - logError("Error running ShuffleBlockSender ", e) - if (shuffleSender != null) { - shuffleSender.stop - shuffleSender = null + private[storage] def startShuffleBlockSender(port: Int): Int = { + val pResolver = new PathResolver { + override def getAbsolutePath(blockId: String): String = { + if (!blockId.startsWith("shuffle_")) { + return null } + DiskStore.this.getFile(blockId).getAbsolutePath() } } + shuffleSender = new ShuffleSender(port, pResolver) + logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) + shuffleSender.port } } diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index b5cedc0b68..6e585e1c3a 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -8,7 +8,7 @@ import spark.storage.BlockManagerId import spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { - + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -45,13 +45,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), - (BlockManagerId("b", "hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), + (BlockManagerId("b", "hostB", 1000, 0), size10000))) tracker.stop() } @@ -64,14 +64,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -88,12 +88,12 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val masterTracker = new MapOutputTracker() masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") - + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) val slaveTracker = new MapOutputTracker() slaveTracker.trackerActor = slaveSystem.actorFor( "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") - + masterTracker.registerShuffle(10, 1) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) @@ -102,13 +102,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index fdee7ca384..58c834c735 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -326,5 +326,5 @@ object ShuffleSuite { x + y } - class NonJavaSerializableClass(val value: Int) + class NonJavaSerializableClass(val value: Int) extends Serializable } diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 16554eac6e..30e6fef950 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -44,7 +44,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration) - taskSets += taskSet + taskSets += taskSet } override def setListener(listener: TaskSchedulerListener) = {} override def defaultParallelism() = 2 @@ -164,7 +164,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont } } } - + /** Sends the rdd to the scheduler for scheduling. */ private def submit( rdd: RDD[_], @@ -174,7 +174,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont listener: JobListener = listener) { runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) } - + /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { runEvent(TaskSetFailed(taskSet, message)) @@ -209,11 +209,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) } - + test("run trivial job w/ dependency") { val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - submit(finalRdd, Array(0)) + submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) } @@ -250,7 +250,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) } - + test("run trivial shuffle with fetch failure") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -398,6 +398,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) private def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345) + BlockManagerId("exec-" + host, host, 12345, 0) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 71d1f0bcc8..bff2475686 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -99,9 +99,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = BlockManagerId("e1", "XXX", 1) - val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1 - val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object + val id1 = BlockManagerId("e1", "XXX", 1, 0) + val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") assert(id3 != id1, "id3 is same as id1") -- cgit v1.2.3 From 6bbbe012877115eab084fea09baf677abaf52f2b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 May 2013 16:51:45 -0700 Subject: Fixed a stupid mistake that NonJavaSerializableClass was made Java serializable. --- core/src/test/scala/spark/ShuffleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 58c834c735..fdee7ca384 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -326,5 +326,5 @@ object ShuffleSuite { x + y } - class NonJavaSerializableClass(val value: Int) extends Serializable + class NonJavaSerializableClass(val value: Int) } -- cgit v1.2.3 From a674d67c0aebb940e3b816e2307206115baec175 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 28 May 2013 16:24:05 -0500 Subject: Fix start-slave not passing instance number to spark-daemon. --- bin/start-slave.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/start-slave.sh b/bin/start-slave.sh index 26b5b9d462..dfcbc6981b 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -12,4 +12,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@" +"$bin"/spark-daemon.sh start spark.deploy.worker.Worker 1 "$@" -- cgit v1.2.3 From 4fe1fbdd51f781157138ffd35da5834366379688 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 28 May 2013 16:26:32 -0500 Subject: Remove unused addIfNoPort. --- core/src/main/scala/spark/Utils.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index c1495d5317..84626df553 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -403,17 +403,6 @@ private object Utils extends Logging { hostPortParseResults.get(hostPort) } - def addIfNoPort(hostPort: String, port: Int): String = { - if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port) - - // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. - // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 - val indx: Int = hostPort.lastIndexOf(':') - if (-1 != indx) return hostPort - - hostPort + ":" + port - } - private[spark] val daemonThreadFactory: ThreadFactory = new ThreadFactoryBuilder().setDaemon(true).build() -- cgit v1.2.3 From fbc1ab346867d5c81dc59e4c8d85aeda2f516ce2 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Tue, 28 May 2013 16:27:16 -0700 Subject: Couple of Netty fixes a. Fix the port number by reading it from the bound channel b. Fix the shutdown sequence to make sure we actually block on the channel c. Fix the unit test to use two JVMs. --- .../main/java/spark/network/netty/FileServer.java | 45 ++++++++++++---------- core/src/test/scala/spark/ShuffleSuite.scala | 14 ++++++- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java index dd3f12561c..dd3a557ae5 100644 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -37,29 +37,33 @@ class FileServer { .childHandler(new FileServerChannelInitializer(pResolver)); // Start the server. channelFuture = bootstrap.bind(addr); - this.port = addr.getPort(); + try { + // Get the address we bound to. + InetSocketAddress boundAddress = + ((InetSocketAddress) channelFuture.sync().channel().localAddress()); + this.port = boundAddress.getPort(); + } catch (InterruptedException ie) { + this.port = 0; + } } /** * Start the file server asynchronously in a new thread. */ public void start() { - try { - blockingThread = new Thread() { - public void run() { - try { - Channel channel = channelFuture.sync().channel(); - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } + blockingThread = new Thread() { + public void run() { + try { + channelFuture.channel().closeFuture().sync(); + LOG.info("FileServer exiting"); + } catch (InterruptedException e) { + LOG.error("File server start got interrupted", e); } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } finally { - bootstrap.shutdown(); - } + // NOTE: bootstrap is shutdown in stop() + } + }; + blockingThread.setDaemon(true); + blockingThread.start(); } public int getPort() { @@ -67,17 +71,16 @@ class FileServer { } public void stop() { - if (blockingThread != null) { - blockingThread.stop(); - blockingThread = null; - } + // Close the bound channel. if (channelFuture != null) { - channelFuture.channel().closeFuture(); + channelFuture.channel().close(); channelFuture = null; } + // Shutdown bootstrap. if (bootstrap != null) { bootstrap.shutdown(); bootstrap = null; } + // TODO: Shutdown all accepted channels as well ? } } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index fdee7ca384..a4fe14b9ae 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -305,9 +305,20 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(c.partitioner.get === p) } + test("shuffle local cluster") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + val a = sc.parallelize(1 to 10, 2) + val b = a.map { + x => (x, x * 2) + } + val c = new ShuffledRDD(b, new HashPartitioner(3)) + assert(c.count === 10) + } + test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[1,2,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test") val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new ShuffleSuite.NonJavaSerializableClass(x * 2)) @@ -317,6 +328,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName) assert(c.count === 10) } + } object ShuffleSuite { -- cgit v1.2.3 From b79b10a6d60a7f1f199e6bddd1243a05c57526ad Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 29 May 2013 00:52:55 -0700 Subject: Flush serializer to fix zero-size kryo blocks bug. Also convert the local-cluster test case to check for non-zero block sizes --- core/src/main/scala/spark/storage/DiskStore.scala | 2 ++ core/src/test/scala/spark/ShuffleSuite.scala | 22 +++++++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 57d4dafefc..1829c2f92e 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -59,6 +59,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() bs.flush() val prevPos = lastValidPosition lastValidPosition = channel.position() diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index a4fe14b9ae..271f4a4e44 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -305,15 +305,27 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(c.partitioner.get === p) } - test("shuffle local cluster") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks + test("shuffle non-zero block size") { sc = new SparkContext("local-cluster[2,1,512]", "test") + val NUM_BLOCKS = 3 + val a = sc.parallelize(1 to 10, 2) - val b = a.map { - x => (x, x * 2) + val b = a.map { x => + (x, new ShuffleSuite.NonJavaSerializableClass(x * 2)) } - val c = new ShuffledRDD(b, new HashPartitioner(3)) + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS), + classOf[spark.KryoSerializer].getName) + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 10) + + // All blocks must have non-zero size + (0 until NUM_BLOCKS).foreach { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + assert(statuses.forall(s => s._2 > 0)) + } } test("shuffle serializer") { -- cgit v1.2.3 From 618c8cae1ee5dede98824823e00f7863571c0e57 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 29 May 2013 13:09:58 -0700 Subject: Skip fetching zero-sized blocks in OIO. Also unify splitLocalRemoteBlocks for netty/nio and add a test case --- .../scala/spark/storage/BlockFetcherIterator.scala | 61 +++++----------------- core/src/test/scala/spark/ShuffleSuite.scala | 27 ++++++++++ 2 files changed, 39 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 95308c7282..1d69d658f7 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -124,6 +124,7 @@ object BlockFetcherIterator { protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. + val originalTotalBlocks = _totalBlocks val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { @@ -140,8 +141,15 @@ object BlockFetcherIterator { var curBlocks = new ArrayBuffer[(String, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() - curBlocks += ((blockId, size)) - curRequestSize += size + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + curRequestSize += size + } else if (size == 0) { + _totalBlocks -= 1 + } else { + throw new BlockException(blockId, "Negative block size " + size) + } if (curRequestSize >= minRequestSize) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) @@ -155,6 +163,8 @@ object BlockFetcherIterator { } } } + logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + + originalTotalBlocks + " blocks") remoteRequests } @@ -278,53 +288,6 @@ object BlockFetcherIterator { logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) } - override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = _totalBlocks; - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - } else { - remoteBlockIds ++= blockInfos.map(_._1) - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - if (size > 0) { - curBlocks += ((blockId, size)) - curRequestSize += size - } else if (size == 0) { - //here we changes the totalBlocks - _totalBlocks -= 1 - } else { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + - originalTotalBlocks + " blocks") - remoteRequests - } - private var copiers: List[_ <: Thread] = null override def initialize() { diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index fdee7ca384..4e50ae2ca9 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -317,6 +317,33 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName) assert(c.count === 10) } + + test("zero sized blocks") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + } object ShuffleSuite { -- cgit v1.2.3 From 19fd6d54c012bd9f73620e9b817f4975de162277 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 29 May 2013 17:29:34 -0700 Subject: Also flush serializer in revertPartialWrites --- core/src/main/scala/spark/storage/DiskStore.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 1829c2f92e..c7281200e7 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -70,6 +70,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def revertPartialWrites() { // Discard current writes. We do this by flushing the outstanding writes and // truncate the file to the last valid position. + objOut.flush() bs.flush() channel.truncate(lastValidPosition) } -- cgit v1.2.3 From ecceb101d3019ef511c42a8a8a3bb0e46520ffef Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Thu, 30 May 2013 10:43:01 +0800 Subject: implement FIFO and fair scheduler for spark local mode --- .../spark/scheduler/cluster/ClusterScheduler.scala | 2 +- .../scheduler/cluster/ClusterTaskSetManager.scala | 734 +++++++++++++++++++++ .../spark/scheduler/cluster/TaskSetManager.scala | 733 +------------------- .../spark/scheduler/local/LocalScheduler.scala | 386 +++++++++-- .../spark/scheduler/ClusterSchedulerSuite.scala | 2 +- 5 files changed, 1057 insertions(+), 800 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 053d4b8e4a..3a0c29b27f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet) + val manager = new ClusterTaskSetManager(this, taskSet) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) taskSetTaskIds(taskSet.id) = new HashSet[Long]() diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala new file mode 100644 index 0000000000..ec4041ab86 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -0,0 +1,734 @@ +package spark.scheduler.cluster + +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min + +import spark._ +import spark.scheduler._ +import spark.TaskState.TaskState +import java.nio.ByteBuffer + +private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + // process local is expected to be used ONLY within tasksetmanager for now. + val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + // Must not be the constraint. + assert (constraint != TaskLocality.PROCESS_LOCAL) + + constraint match { + case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + val retval = TaskLocality.withName(str) + // Must not specify PROCESS_LOCAL ! + assert (retval != TaskLocality.PROCESS_LOCAL) + + retval + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); + // default to preserve earlier behavior + NODE_LOCAL + } + } + } +} + +/** + * Schedules the tasks within a single TaskSet in the ClusterScheduler. + */ +private[spark] class ClusterTaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet) + extends TaskSetManager + with Logging { + + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + // CPUs to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = 4 + + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + + // Serializer for closures and tasks. + val ser = SparkEnv.get.closureSerializer.newInstance() + + val tasks = taskSet.tasks + val numTasks = tasks.length + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + var tasksFinished = 0 + + var weight = 1 + var minShare = 0 + var runningTasks = 0 + var priority = taskSet.priority + var stageId = taskSet.stageId + var name = "TaskSet_"+taskSet.stageId.toString + var parent:Schedulable = null + + // Last time when we launched a preferred task (for delay scheduling) + var lastPreferredLaunchTime = System.currentTimeMillis + + // List of pending tasks for each node (process local to container). These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. + // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List containing pending tasks with no locality preferences + val pendingTasksWithNoPrefs = new ArrayBuffer[Int] + + // List containing all pending tasks (also used as a stack, as above) + val allPendingTasks = new ArrayBuffer[Int] + + // Tasks that can be speculated. Since these will be a small fraction of total + // tasks, we'll just hold them in a HashSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[Long, TaskInfo] + + // Did the job fail? + var failed = false + var causeOfFailure = "" + + // How frequently to reprint duplicate exceptions in full, in milliseconds + val EXCEPTION_PRINT_INTERVAL = + System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong + // Map of recent exceptions (identified by string representation and + // top stack frame) to duplicate count (how many times the same + // exception has appeared) and time the full exception was + // printed. This should ideally be an LRU map that can drop old + // exceptions automatically. + val recentExceptions = HashMap[String, (Int, Long)]() + + // Figure out the current map output tracker generation and set it on all tasks + val generation = sched.mapOutputTracker.getGeneration + logDebug("Generation for " + taskSet.id + ": " + generation) + for (t <- tasks) { + t.generation = generation + } + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Note that it follows the hierarchy. + // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and + // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, + taskLocality: TaskLocality.TaskLocality): HashSet[String] = { + + if (TaskLocality.PROCESS_LOCAL == taskLocality) { + // straight forward comparison ! Special case it. + val retval = new HashSet[String]() + scheduler.synchronized { + for (location <- _taskPreferredLocations) { + if (scheduler.isExecutorAliveOnHostPort(location)) { + retval += location + } + } + } + + return retval + } + + val taskPreferredLocations = + if (TaskLocality.NODE_LOCAL == taskLocality) { + _taskPreferredLocations + } else { + assert (TaskLocality.RACK_LOCAL == taskLocality) + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new HashSet[String] + scheduler.synchronized { + for (prefLocation <- taskPreferredLocations) { + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) + if (aliveLocationsOpt.isDefined) { + retval ++= aliveLocationsOpt.get + } + } + } + + retval + } + + // Add a task to all the pending-task lists that it should be on. + private def addPendingTask(index: Int) { + // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate + // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. + val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + + if (rackLocalLocations.size == 0) { + // Current impl ensures this. + assert (processLocalLocations.size == 0) + assert (hostLocalLocations.size == 0) + pendingTasksWithNoPrefs += index + } else { + + // process local locality + for (hostPort <- processLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) + hostPortList += index + } + + // host locality (includes process local) + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val host = Utils.parseHostPort(hostPort)._1 + val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + hostList += index + } + + // rack locality (includes process local and host local) + for (rackLocalHostPort <- rackLocalLocations) { + // DEBUG Code + Utils.checkHostPort(rackLocalHostPort) + + val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 + val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) + list += index + } + } + + allPendingTasks += index + } + + // Return the pending tasks list for a given host port (process local), or an empty list if + // there is no map entry for that host + private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { + // DEBUG Code + Utils.checkHostPort(hostPort) + pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) + } + + // Return the pending tasks list for a given host, or an empty list if + // there is no map entry for that host + private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Return the pending tasks (rack level) list for a given host, or an empty list if + // there is no map entry for that host + private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Number of pending tasks for a given host Port (which would be process local) + def numPendingTasksForHostPort(hostPort: String): Int = { + getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending tasks for a given host (which would be data local) + def numPendingTasksForHost(hostPort: String): Int = { + getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending rack local tasks for a given host + def numRackLocalPendingTasksForHost(hostPort: String): Int = { + getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + + // Dequeue a pending task from the given list and return its index. + // Return None if the list is empty. + // This method also cleans up any tasks in the list that have already + // been launched, since we want that to happen lazily. + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (copiesRunning(index) == 0 && !finished(index)) { + return Some(index) + } + } + return None + } + + // Return a speculative task for a given host if any are available. The task should not have an + // attempt running on this host, in case the host is slow. In addition, if locality is set, the + // task must have a preference for this host/rack/no preferred locations at all. + private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + + assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) + speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set + + if (speculatableTasks.size > 0) { + val localTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + val attemptLocs = taskAttempts(index).map(_.hostPort) + (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) + } + + if (localTask != None) { + speculatableTasks -= localTask.get + return localTask + } + + // check for rack locality + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val attemptLocs = taskAttempts(index).map(_.hostPort) + locations.contains(hostPort) && !attemptLocs.contains(hostPort) + } + + if (rackTask != None) { + speculatableTasks -= rackTask.get + return rackTask + } + } + + // Any task ... + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + // Check for attemptLocs also ? + val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) + if (nonLocalTask != None) { + speculatableTasks -= nonLocalTask.get + return nonLocalTask + } + } + } + return None + } + + // Dequeue a pending task for a given node and return its index. + // If localOnly is set to false, allow non-local tasks as well. + private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) + if (processLocalTask != None) { + return processLocalTask + } + + val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) + if (localTask != None) { + return localTask + } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) + if (rackLocalTask != None) { + return rackLocalTask + } + } + + // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. + // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). + val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) + if (noPrefTask != None) { + return noPrefTask + } + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + val nonLocalTask = findTaskFromList(allPendingTasks) + if (nonLocalTask != None) { + return nonLocalTask + } + } + + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(hostPort, locality) + } + + private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = { + Utils.checkHostPort(hostPort) + + val locs = task.preferredLocations + + locs.contains(hostPort) + } + + private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = { + val locs = task.preferredLocations + + // If no preference, consider it as host local + if (locs.isEmpty) return true + + val host = Utils.parseHostPort(hostPort)._1 + locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined + } + + // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). + // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) + } + + // Respond to an offer of a single slave from the scheduler by finding a task + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... + val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY + } + + findTask(hostPort, locality) match { + case Some(index) => { + // Found a task; do some bookkeeping and return a Mesos task for it + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + val taskLocality = + if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else + if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else + TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) + // Do various bookkeeping + copiesRunning(index) += 1 + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) + taskInfos(taskId) = info + taskAttempts(index) = info :: taskAttempts(index) + if (TaskLocality.NODE_LOCAL == taskLocality) { + lastPreferredLaunchTime = time + } + // Serialize and return the task + val startTime = System.currentTimeMillis + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) + val timeTaken = System.currentTimeMillis - startTime + increaseRunningTasks(1) + logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) + val taskName = "task %s:%d".format(taskSet.id, index) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) + } + case _ => + } + } + return None + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskFinished(tid, state, serializedData) + case TaskState.LOST => + taskLost(tid, state, serializedData) + case TaskState.FAILED => + taskLost(tid, state, serializedData) + case TaskState.KILLED => + taskLost(tid, state, serializedData) + case _ => + } + } + + def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + decreaseRunningTasks(1) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markFailed() + decreaseRunningTasks(1) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. + if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + return + + case ef: ExceptionFailure => + val key = ef.description + val now = System.currentTimeMillis + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(message: String) { + // Save the error message + abort("Error: " + message) + } + + def abort(message: String) { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.listener.taskSetFailed(taskSet, message) + decreaseRunningTasks(runningTasks) + sched.taskSetFinished(this) + } + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def addSchedulable(schedulable:Schedulable) { + //nothing + } + + override def removeSchedulable(schedulable:Schedulable) { + //nothing + } + + override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + override def executorLost(execId: String, hostPort: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + + // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to + // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. + // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if + // there is no host local node for the task (not if there is no process local node for the task) + for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { + // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } + } + + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.executorId == execId) { + val index = taskInfos(tid).index + if (finished(index)) { + finished(index) = false + copiesRunning(index) -= 1 + tasksFinished -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + } + } + } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { + taskLost(tid, TaskState.KILLED, null) + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the ClusterScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + override def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. + if (numTasks == 1 || tasksFinished == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksFinished >= minFinishedForSpeculation) { + val time = System.currentTimeMillis() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo( + "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.hostPort, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks + } + + override def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } +} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 1c403ef323..2b5a74d4e5 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,734 +1,17 @@ package spark.scheduler.cluster -import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} - import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import spark._ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer - -private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging { - - // process local is expected to be used ONLY within tasksetmanager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value - - type TaskLocality = Value - - def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { - - // Must not be the constraint. - assert (constraint != TaskLocality.PROCESS_LOCAL) - - constraint match { - case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL - case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL - // For anything else, allow - case _ => true - } - } - - def parse(str: String): TaskLocality = { - // better way to do this ? - try { - val retval = TaskLocality.withName(str) - // Must not specify PROCESS_LOCAL ! - assert (retval != TaskLocality.PROCESS_LOCAL) - - retval - } catch { - case nEx: NoSuchElementException => { - logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL"); - // default to preserve earlier behavior - NODE_LOCAL - } - } - } -} - /** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet) - extends Schedulable - with Logging { - - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = 4 - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. - val ser = SparkEnv.get.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksFinished = 0 - - var weight = 1 - var minShare = 0 - var runningTasks = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent:Schedulable = null - - // Last time when we launched a preferred task (for delay scheduling) - var lastPreferredLaunchTime = System.currentTimeMillis - - // List of pending tasks for each node (process local to container). These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] - - // List of pending tasks for each node. - // Essentially, similar to pendingTasksForHostPort, except at host level - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List of pending tasks for each node based on rack locality. - // Essentially, similar to pendingTasksForHost, except at rack level - private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // List containing pending tasks with no locality preferences - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // List containing all pending tasks (also used as a stack, as above) - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the job fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - // Map of recent exceptions (identified by string representation and - // top stack frame) to duplicate count (how many times the same - // exception has appeared) and time the full exception was - // printed. This should ideally be an LRU map that can drop old - // exceptions automatically. - val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker generation and set it on all tasks - val generation = sched.mapOutputTracker.getGeneration - logDebug("Generation for " + taskSet.id + ": " + generation) - for (t <- tasks) { - t.generation = generation - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Note that it follows the hierarchy. - // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and - // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL - private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, - taskLocality: TaskLocality.TaskLocality): HashSet[String] = { - - if (TaskLocality.PROCESS_LOCAL == taskLocality) { - // straight forward comparison ! Special case it. - val retval = new HashSet[String]() - scheduler.synchronized { - for (location <- _taskPreferredLocations) { - if (scheduler.isExecutorAliveOnHostPort(location)) { - retval += location - } - } - } - - return retval - } - - val taskPreferredLocations = - if (TaskLocality.NODE_LOCAL == taskLocality) { - _taskPreferredLocations - } else { - assert (TaskLocality.RACK_LOCAL == taskLocality) - // Expand set to include all 'seen' rack local hosts. - // This works since container allocation/management happens within master - so any rack locality information is updated in msater. - // Best case effort, and maybe sort of kludge for now ... rework it later ? - val hosts = new HashSet[String] - _taskPreferredLocations.foreach(h => { - val rackOpt = scheduler.getRackForHost(h) - if (rackOpt.isDefined) { - val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) - if (hostsOpt.isDefined) { - hosts ++= hostsOpt.get - } - } - - // Ensure that irrespective of what scheduler says, host is always added ! - hosts += h - }) - - hosts - } - - val retval = new HashSet[String] - scheduler.synchronized { - for (prefLocation <- taskPreferredLocations) { - val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1) - if (aliveLocationsOpt.isDefined) { - retval ++= aliveLocationsOpt.get - } - } - } - - retval - } - - // Add a task to all the pending-task lists that it should be on. - private def addPendingTask(index: Int) { - // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate - // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. - val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL) - val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - - if (rackLocalLocations.size == 0) { - // Current impl ensures this. - assert (processLocalLocations.size == 0) - assert (hostLocalLocations.size == 0) - pendingTasksWithNoPrefs += index - } else { - - // process local locality - for (hostPort <- processLocalLocations) { - // DEBUG Code - Utils.checkHostPort(hostPort) - - val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) - hostPortList += index - } - - // host locality (includes process local) - for (hostPort <- hostLocalLocations) { - // DEBUG Code - Utils.checkHostPort(hostPort) - - val host = Utils.parseHostPort(hostPort)._1 - val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) - hostList += index - } - - // rack locality (includes process local and host local) - for (rackLocalHostPort <- rackLocalLocations) { - // DEBUG Code - Utils.checkHostPort(rackLocalHostPort) - - val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 - val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) - list += index - } - } - - allPendingTasks += index - } - - // Return the pending tasks list for a given host port (process local), or an empty list if - // there is no map entry for that host - private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { - // DEBUG Code - Utils.checkHostPort(hostPort) - pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) - } - - // Return the pending tasks list for a given host, or an empty list if - // there is no map entry for that host - private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { - val host = Utils.parseHostPort(hostPort)._1 - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Return the pending tasks (rack level) list for a given host, or an empty list if - // there is no map entry for that host - private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { - val host = Utils.parseHostPort(hostPort)._1 - pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) - } - - // Number of pending tasks for a given host Port (which would be process local) - def numPendingTasksForHostPort(hostPort: String): Int = { - getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - // Number of pending tasks for a given host (which would be data local) - def numPendingTasksForHost(hostPort: String): Int = { - getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - // Number of pending rack local tasks for a given host - def numRackLocalPendingTasksForHost(hostPort: String): Int = { - getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) - } - - - // Dequeue a pending task from the given list and return its index. - // Return None if the list is empty. - // This method also cleans up any tasks in the list that have already - // been launched, since we want that to happen lazily. - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { - return Some(index) - } - } - return None - } - - // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if locality is set, the - // task must have a preference for this host/rack/no preferred locations at all. - private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - - assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - - if (speculatableTasks.size > 0) { - val localTask = speculatableTasks.find { - index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - val attemptLocs = taskAttempts(index).map(_.hostPort) - (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) - } - - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - - // check for rack locality - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - val rackTask = speculatableTasks.find { - index => - val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val attemptLocs = taskAttempts(index).map(_.hostPort) - locations.contains(hostPort) && !attemptLocs.contains(hostPort) - } - - if (rackTask != None) { - speculatableTasks -= rackTask.get - return rackTask - } - } - - // Any task ... - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - // Check for attemptLocs also ? - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask - } - } - } - return None - } - - // Dequeue a pending task for a given node and return its index. - // If localOnly is set to false, allow non-local tasks as well. - private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { - val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort)) - if (processLocalTask != None) { - return processLocalTask - } - - val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) - if (localTask != None) { - return localTask - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) - if (rackLocalTask != None) { - return rackLocalTask - } - } - - // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. - // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). - val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) - if (noPrefTask != None) { - return noPrefTask - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - val nonLocalTask = findTaskFromList(allPendingTasks) - if (nonLocalTask != None) { - return nonLocalTask - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(hostPort, locality) - } - - private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = { - Utils.checkHostPort(hostPort) - - val locs = task.preferredLocations - - locs.contains(hostPort) - } - - private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = { - val locs = task.preferredLocations - - // If no preference, consider it as host local - if (locs.isEmpty) return true - - val host = Utils.parseHostPort(hostPort)._1 - locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined - } - - // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). - // This is true if either the task has preferred locations and this host is one, or it has - // no preferred locations (in which we still count the launch as preferred). - private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { - - val locs = task.preferredLocations - - val preferredRacks = new HashSet[String]() - for (preferredHost <- locs) { - val rack = sched.getRackForHost(preferredHost) - if (None != rack) preferredRacks += rack.get - } - - if (preferredRacks.isEmpty) return false - - val hostRack = sched.getRackForHost(hostPort) - - return None != hostRack && preferredRacks.contains(hostRack.get) - } - - // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { - - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - // If explicitly specified, use that - val locality = if (overrideLocality != null) overrideLocality else { - // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... - val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY - } - - findTask(hostPort, locality) match { - case Some(index) => { - // Found a task; do some bookkeeping and return a Mesos task for it - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - val taskLocality = - if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else - if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else - if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else - TaskLocality.ANY - val prefStr = taskLocality.toString - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, execId, hostPort, prefStr)) - // Do various bookkeeping - copiesRunning(index) += 1 - val time = System.currentTimeMillis - val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - if (TaskLocality.NODE_LOCAL == taskLocality) { - lastPreferredLaunchTime = time - } - // Serialize and return the task - val startTime = System.currentTimeMillis - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = System.currentTimeMillis - startTime - increaseRunningTasks(1) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) - } - case _ => - } - } - return None - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - state match { - case TaskState.FINISHED => - taskFinished(tid, state, serializedData) - case TaskState.LOST => - taskLost(tid, state, serializedData) - case TaskState.FAILED => - taskLost(tid, state, serializedData) - case TaskState.KILLED => - taskLost(tid, state, serializedData) - case _ => - } - } - - def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markSuccessful() - decreaseRunningTasks(1) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( - tid, info.duration, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - decreaseRunningTasks(1) - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. - if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - return - - case ef: ExceptionFailure => - val key = ef.description - val now = System.currentTimeMillis - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) - decreaseRunningTasks(runningTasks) - sched.taskSetFinished(this) - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable:Schedulable) { - //nothing - } - - override def removeSchedulable(schedulable:Schedulable) { - //nothing - } - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def executorLost(execId: String, hostPort: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // If some task has preferred locations only on hostname, and there are no more executors there, - // put it in the no-prefs list to avoid the wait from delay scheduling - - // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to - // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc. - // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if - // there is no host local node for the task (not if there is no process local node for the task) - for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) { - // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL) - val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL) - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - taskLost(tid, TaskState.KILLED, null) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = System.currentTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. - logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.hostPort, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksFinished < numTasks - } +private[spark] trait TaskSetManager extends Schedulable { + def taskSet: TaskSet + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] + def numPendingTasksForHostPort(hostPort: String): Int + def numRackLocalPendingTasksForHost(hostPort :String): Int + def numPendingTasksForHost(hostPort: String): Int + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) + def error(message: String) } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 37a67f9b1b..664dc9e886 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -2,19 +2,215 @@ package spark.scheduler.local import java.io.File import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import spark._ +import spark.TaskState.TaskState import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.{TaskLocality, TaskInfo} +import spark.scheduler.cluster._ +import akka.actor._ /** * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ -private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) + +private[spark] case class LocalReviveOffers() +private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) + +private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { + def receive = { + case LocalReviveOffers => + logInfo("LocalReviveOffers") + launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => + logInfo("LocalStatusUpdate") + freeCores += 1 + localScheduler.statusUpdate(taskId, state, serializeData) + launchTask(localScheduler.resourceOffer(freeCores)) + } + + def launchTask(tasks : Seq[TaskDescription]) { + for (task <- tasks) + { + freeCores -= 1 + localScheduler.threadPool.submit(new Runnable { + def run() { + localScheduler.runTask(task.taskId,task.serializedTask) + } + }) + } + } +} + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { + var parent: Schedulable = null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = "TaskSet_"+taskSet.stageId.toString + + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val ser = SparkEnv.get.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + + def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def removeSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + //nothing + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + def hasPendingTasks(): Boolean = { + return true + } + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + Thread.currentThread().setContextClassLoader(sched.classLoader) + SparkEnv.set(sched.env) + logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) + if (availableCpus > 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + logInfo(taskSet.tasks(index).toString) + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + logInfo("taskId:%d,task:%s".format(index,task)) + val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + return Some(new TaskDescription(taskId, null, taskName, bytes)) + case None => {} + } + } + return None + } + + def numPendingTasksForHostPort(hostPort: String): Int = { + return 0 + } + + def numRackLocalPendingTasksForHost(hostPort :String): Int = { + return 0 + } + + def numPendingTasksForHost(hostPort: String): Int = { + return 0 + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > 4) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) + //logError(errorMessage) + //sched.listener.taskEnded(task, reason, null, null, info, null) + sched.listener.taskSetFailed(taskSet, errorMessage) + sched.taskSetFinished(this) + decreaseRunningTasks(runningTasks) + } + } + } + + def error(message: String) { + } +} + +private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with Logging { @@ -30,90 +226,126 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) + var schedulableBuilder: SchedulableBuilder = null + var rootPool: Pool = null + val activeTaskSets = new HashMap[String, TaskSetManager] + val taskIdToTaskSetId = new HashMap[Long, String] + val taskSetTaskIds = new HashMap[String, HashSet[Long]] + + var localActor: ActorRef = null // TODO: Need to take into account stage priority in scheduling - override def start() { } + override def start() { + //default scheduler is FIFO + val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO") + //temporarily set rootPool name to empty + rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0) + schedulableBuilder = { + schedulingMode match { + case "FIFO" => + new FIFOSchedulableBuilder(rootPool) + case "FAIR" => + new FairSchedulableBuilder(rootPool) + } + } + schedulableBuilder.buildPools() + + //val properties = new ArrayBuffer[(String, String)] + localActor = env.actorSystem.actorOf( + Props(new LocalActor(this, threads)), "Test") + } override def setListener(listener: TaskSchedulerListener) { this.listener = listener } override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - val failCount = new Array[Int](tasks.size) + var manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! LocalReviveOffers + } - def submitTask(task: Task[_], idInJob: Int) { - val myAttemptId = attemptId.getAndIncrement() - threadPool.submit(new Runnable { - def run() { - runTask(task, idInJob, myAttemptId) - } - }) + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) } - def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser) - logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserStart = System.currentTimeMillis() - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - val deserTime = System.currentTimeMillis() - deserStart - - // Run it - val result: Any = deserializedTask.run(attemptId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = ser.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = ser.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - logInfo("Finished " + task) - info.markSuccessful() - deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - - // If the threadpool has not already been shutdown, notify DAGScheduler - if (!Thread.currentThread().isInterrupted) - listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null)) - } catch { - case t: Throwable => { - logError("Exception in task " + idInJob, t) - failCount.synchronized { - failCount(idInJob) += 1 - if (failCount(idInJob) <= maxFailures) { - submitTask(task, idInJob) - } else { - // TODO: Do something nicer here to return all the way to the user - if (!Thread.currentThread().isInterrupted) { - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) - listener.taskEnded(task, failure, null, null, info, null) - } - } + var launchTask = false + for (manager <- sortedTaskSetQueue) { + do { + launchTask = false + logInfo("freeCores is" + freeCpuCores) + manager.slaveOffer(null,null,freeCpuCores) match { + case Some(task) => + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true + case None => {} } - } - } + } while(launchTask) } + return tasks + } - for ((task, i) <- tasks.zipWithIndex) { - submitTask(task, i) - } + def taskSetFinished(manager: TaskSetManager) { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= manager.taskSet.id + } + + def runTask(taskId: Long, bytes: ByteBuffer) { + logInfo("Running " + taskId) + val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + // Set the Spark execution environment for the worker thread + SparkEnv.set(env) + val ser = SparkEnv.get.closureSerializer.newInstance() + try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; + // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile + val deserStart = System.currentTimeMillis() + val deserializedTask = ser.deserialize[Task[_]]( + taskBytes, Thread.currentThread.getContextClassLoader) + val deserTime = System.currentTimeMillis() - deserStart + + // Run it + val result: Any = deserializedTask.run(taskId) + + // Serialize and deserialize the result to emulate what the Mesos + // executor does. This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). + val serResult = ser.serialize(result) + deserializedTask.metrics.get.resultSize = serResult.limit() + val resultToReturn = ser.deserialize[Any](serResult) + val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( + ser.serialize(Accumulators.values)) + logInfo("Finished " + taskId) + deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough + deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt + + val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) + val serializedResult = ser.serialize(taskResult) + localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) + } catch { + case t: Throwable => { + val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) + localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) + } + } } /** @@ -128,6 +360,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) @@ -143,7 +376,14 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } } - override def stop() { + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) + { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetManager.statusUpdate(taskId, state, serializedData) + } + + override def stop() { threadPool.shutdownNow() } diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index a39418b716..e6ad90192e 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -16,7 +16,7 @@ class DummyTaskSetManager( initNumTasks: Int, clusterScheduler: ClusterScheduler, taskSet: TaskSet) - extends TaskSetManager(clusterScheduler,taskSet) { + extends ClusterTaskSetManager(clusterScheduler,taskSet) { parent = null weight = 1 -- cgit v1.2.3 From c3db3ea55467c3fb053453c8c567db357d939640 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Thu, 30 May 2013 20:49:40 +0800 Subject: 1. Add unit test for local scheduler 2. Move localTaskSetManager to a new file --- .../spark/scheduler/local/LocalScheduler.scala | 241 ++++----------------- .../scheduler/local/LocalTaskSetManager.scala | 173 +++++++++++++++ .../spark/scheduler/LocalSchedulerSuite.scala | 171 +++++++++++++++ 3 files changed, 385 insertions(+), 200 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala create mode 100644 core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 664dc9e886..69dacfc2bd 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -15,7 +15,7 @@ import spark.scheduler.cluster._ import akka.actor._ /** - * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally + * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ @@ -26,10 +26,8 @@ private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, seri private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { def receive = { case LocalReviveOffers => - logInfo("LocalReviveOffers") launchTask(localScheduler.resourceOffer(freeCores)) case LocalStatusUpdate(taskId, state, serializeData) => - logInfo("LocalStatusUpdate") freeCores += 1 localScheduler.statusUpdate(taskId, state, serializeData) launchTask(localScheduler.resourceOffer(freeCores)) @@ -48,168 +46,6 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I } } -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { - var parent: Schedulable = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_"+taskSet.stageId.toString - - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val ser = SparkEnv.get.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - - def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - def addSchedulable(schedulable: Schedulable): Unit = { - //nothing - } - - def removeSchedulable(schedulable: Schedulable): Unit = { - //nothing - } - - def getSchedulableByName(name: String): Schedulable = { - return null - } - - def executorLost(executorId: String, host: String): Unit = { - //nothing - } - - def checkSpeculatableTasks(): Boolean = { - return true - } - - def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - def hasPendingTasks(): Boolean = { - return true - } - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { - Thread.currentThread().setContextClassLoader(sched.classLoader) - SparkEnv.set(sched.env) - logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - logInfo(taskSet.tasks(index).toString) - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - logInfo("taskId:%d,task:%s".format(index,task)) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) - logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - return Some(new TaskDescription(taskId, null, taskName, bytes)) - case None => {} - } - } - return None - } - - def numPendingTasksForHostPort(hostPort: String): Int = { - return 0 - } - - def numRackLocalPendingTasksForHost(hostPort :String): Int = { - return 0 - } - - def numPendingTasksForHost(hostPort: String): Int = { - return 0 - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) - if (numFailures(index) > 4) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) - //logError(errorMessage) - //sched.listener.taskEnded(task, reason, null, null, info, null) - sched.listener.taskSetFailed(taskSet, errorMessage) - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - } - } - } - - def error(message: String) { - } -} - private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with Logging { @@ -233,7 +69,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val taskSetTaskIds = new HashMap[String, HashSet[Long]] var localActor: ActorRef = null - // TODO: Need to take into account stage priority in scheduling override def start() { //default scheduler is FIFO @@ -250,7 +85,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } schedulableBuilder.buildPools() - //val properties = new ArrayBuffer[(String, String)] localActor = env.actorSystem.actorOf( Props(new LocalActor(this, threads)), "Test") } @@ -260,51 +94,56 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } override def submitTasks(taskSet: TaskSet) { - var manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! LocalReviveOffers + synchronized { + var manager = new LocalTaskSetManager(this, taskSet) + schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) + activeTaskSets(taskSet.id) = manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + localActor ! LocalReviveOffers + } } def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) - } + synchronized { + var freeCpuCores = freeCores + val tasks = new ArrayBuffer[TaskDescription](freeCores) + val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() + for (manager <- sortedTaskSetQueue) { + logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + } - var launchTask = false - for (manager <- sortedTaskSetQueue) { + var launchTask = false + for (manager <- sortedTaskSetQueue) { do { launchTask = false - logInfo("freeCores is" + freeCpuCores) manager.slaveOffer(null,null,freeCpuCores) match { case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true + tasks += task + taskIdToTaskSetId(task.taskId) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += task.taskId + freeCpuCores -= 1 + launchTask = true case None => {} - } + } } while(launchTask) + } + return tasks } - return tasks } def taskSetFinished(manager: TaskSetManager) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id + synchronized { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds -= manager.taskSet.id + } } def runTask(taskId: Long, bytes: ByteBuffer) { logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) // Set the Spark execution environment for the worker thread SparkEnv.set(env) val ser = SparkEnv.get.closureSerializer.newInstance() @@ -344,8 +183,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: case t: Throwable => { val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace) localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } } + } } /** @@ -376,11 +215,13 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) - { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetManager.statusUpdate(taskId, state, serializedData) + def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { + synchronized { + val taskSetId = taskIdToTaskSetId(taskId) + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + taskSetManager.statusUpdate(taskId, state, serializedData) + } } override def stop() { diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala new file mode 100644 index 0000000000..f2e07d162a --- /dev/null +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -0,0 +1,173 @@ +package spark.scheduler.local + +import java.io.File +import java.util.concurrent.atomic.AtomicInteger +import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +import spark._ +import spark.TaskState.TaskState +import spark.scheduler._ +import spark.scheduler.cluster._ + +private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { + var parent: Schedulable = null + var weight: Int = 1 + var minShare: Int = 0 + var runningTasks: Int = 0 + var priority: Int = taskSet.priority + var stageId: Int = taskSet.stageId + var name: String = "TaskSet_"+taskSet.stageId.toString + + + var failCount = new Array[Int](taskSet.tasks.size) + val taskInfos = new HashMap[Long, TaskInfo] + val numTasks = taskSet.tasks.size + var numFinished = 0 + val ser = SparkEnv.get.closureSerializer.newInstance() + val copiesRunning = new Array[Int](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val MAX_TASK_FAILURES = sched.maxFailures + + def increaseRunningTasks(taskNum: Int): Unit = { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + def decreaseRunningTasks(taskNum: Int): Unit = { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + def addSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def removeSchedulable(schedulable: Schedulable): Unit = { + //nothing + } + + def getSchedulableByName(name: String): Schedulable = { + return null + } + + def executorLost(executorId: String, host: String): Unit = { + //nothing + } + + def checkSpeculatableTasks(): Boolean = { + return true + } + + def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { + var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] + sortedTaskSetQueue += this + return sortedTaskSetQueue + } + + def hasPendingTasks(): Boolean = { + return true + } + + def findTask(): Option[Int] = { + for (i <- 0 to numTasks-1) { + if (copiesRunning(i) == 0 && !finished(i)) { + return Some(i) + } + } + return None + } + + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + SparkEnv.set(sched.env) + logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks)) + if (availableCpus > 0 && numFinished < numTasks) { + findTask() match { + case Some(index) => + logInfo(taskSet.tasks(index).toString) + val taskId = sched.attemptId.getAndIncrement() + val task = taskSet.tasks(index) + val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) + taskInfos(taskId) = info + val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") + val taskName = "task %s:%d".format(taskSet.id, index) + copiesRunning(index) += 1 + increaseRunningTasks(1) + return Some(new TaskDescription(taskId, null, taskName, bytes)) + case None => {} + } + } + return None + } + + def numPendingTasksForHostPort(hostPort: String): Int = { + return 0 + } + + def numRackLocalPendingTasksForHost(hostPort :String): Int = { + return 0 + } + + def numPendingTasksForHost(hostPort: String): Int = { + return 0 + } + + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + state match { + case TaskState.FINISHED => + taskEnded(tid, state, serializedData) + case TaskState.FAILED => + taskFailed(tid, state, serializedData) + case _ => {} + } + } + + def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markSuccessful() + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + numFinished += 1 + decreaseRunningTasks(1) + finished(index) = true + if (numFinished == numTasks) { + sched.taskSetFinished(this) + } + } + + def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + val index = info.index + val task = taskSet.tasks(index) + info.markFailed() + decreaseRunningTasks(1) + val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader) + if (!finished(index)) { + copiesRunning(index) -= 1 + numFailures(index) += 1 + val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n"))) + if (numFailures(index) > MAX_TASK_FAILURES) { + val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description) + decreaseRunningTasks(runningTasks) + sched.listener.taskSetFailed(taskSet, errorMessage) + // need to delete failed Taskset from schedule queue + sched.taskSetFinished(this) + } + } + } + + def error(message: String) { + } +} diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala new file mode 100644 index 0000000000..37d14ed113 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala @@ -0,0 +1,171 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ConcurrentMap, HashMap} +import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +import java.util.Properties + +class Lock() { + var finished = false + def jobWait() = { + synchronized { + while(!finished) { + this.wait() + } + } + } + + def jobFinished() = { + synchronized { + finished = true + this.notifyAll() + } + } +} + +object TaskThreadInfo { + val threadToLock = HashMap[Int, Lock]() + val threadToRunning = HashMap[Int, Boolean]() +} + + +class LocalSchedulerSuite extends FunSuite with LocalSparkContext { + + def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { + + TaskThreadInfo.threadToRunning(threadIndex) = false + val nums = sc.parallelize(threadIndex to threadIndex, 1) + TaskThreadInfo.threadToLock(threadIndex) = new Lock() + new Thread { + if (poolName != null) { + sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToLock(number).jobWait() + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + TaskThreadInfo.threadToRunning(threadIndex) = false + } + }.start() + Thread.sleep(2000) + } + + test("Local FIFO scheduler end-to-end test") { + System.setProperty("spark.cluster.schedulingmode", "FIFO") + sc = new SparkContext("local[4]", "test") + val sem = new Semaphore(0) + + createThread(1,null,sc,sem) + createThread(2,null,sc,sem) + createThread(3,null,sc,sem) + createThread(4,null,sc,sem) + createThread(5,null,sc,sem) + createThread(6,null,sc,sem) + assert(TaskThreadInfo.threadToRunning(1) === true) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === false) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(1).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === true) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === false) + + TaskThreadInfo.threadToLock(3).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(1) === false) + assert(TaskThreadInfo.threadToRunning(2) === true) + assert(TaskThreadInfo.threadToRunning(3) === false) + assert(TaskThreadInfo.threadToRunning(4) === true) + assert(TaskThreadInfo.threadToRunning(5) === true) + assert(TaskThreadInfo.threadToRunning(6) === true) + + TaskThreadInfo.threadToLock(2).jobFinished() + TaskThreadInfo.threadToLock(4).jobFinished() + TaskThreadInfo.threadToLock(5).jobFinished() + TaskThreadInfo.threadToLock(6).jobFinished() + sem.acquire(6) + } + + test("Local fair scheduler end-to-end test") { + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + System.setProperty("spark.cluster.schedulingmode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + + createThread(10,"1",sc,sem) + createThread(20,"2",sc,sem) + createThread(30,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(10) === true) + assert(TaskThreadInfo.threadToRunning(20) === true) + assert(TaskThreadInfo.threadToRunning(30) === true) + + createThread(11,"1",sc,sem) + createThread(21,"2",sc,sem) + createThread(31,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(11) === true) + assert(TaskThreadInfo.threadToRunning(21) === true) + assert(TaskThreadInfo.threadToRunning(31) === true) + + createThread(12,"1",sc,sem) + createThread(22,"2",sc,sem) + createThread(32,"3",sc,sem) + + assert(TaskThreadInfo.threadToRunning(12) === true) + assert(TaskThreadInfo.threadToRunning(22) === true) + assert(TaskThreadInfo.threadToRunning(32) === false) + + TaskThreadInfo.threadToLock(10).jobFinished() + Thread.sleep(1000) + assert(TaskThreadInfo.threadToRunning(32) === true) + + createThread(23,"2",sc,sem) + createThread(33,"3",sc,sem) + + TaskThreadInfo.threadToLock(11).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(23) === true) + assert(TaskThreadInfo.threadToRunning(33) === false) + + TaskThreadInfo.threadToLock(12).jobFinished() + Thread.sleep(1000) + + assert(TaskThreadInfo.threadToRunning(33) === true) + + TaskThreadInfo.threadToLock(20).jobFinished() + TaskThreadInfo.threadToLock(21).jobFinished() + TaskThreadInfo.threadToLock(22).jobFinished() + TaskThreadInfo.threadToLock(23).jobFinished() + TaskThreadInfo.threadToLock(30).jobFinished() + TaskThreadInfo.threadToLock(31).jobFinished() + TaskThreadInfo.threadToLock(32).jobFinished() + TaskThreadInfo.threadToLock(33).jobFinished() + + sem.acquire(11) + } +} -- cgit v1.2.3 From bed1b08169df91e97cb9ebaf8e58daeb655ff55d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 May 2013 16:21:49 -0700 Subject: Do not create symlink for local add file. Instead, copy the file. This prevents Spark from changing the original file's permission, and also allow add file to work on non-posix operating systems. --- core/src/main/scala/spark/Utils.scala | 78 +++++++++++++++++------------------ 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 84626df553..ec15326014 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -4,20 +4,26 @@ import java.io._ import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} +import java.util.regex.Pattern + import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import scala.io.Source + import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder + +import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} + import spark.serializer.SerializerInstance import spark.deploy.SparkHadoopUtil -import java.util.regex.Pattern + /** * Various utility methods used by Spark. */ private object Utils extends Logging { + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -68,7 +74,6 @@ private object Utils extends Logging { return buf } - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() // Register the path to be deleted via shutdown hook @@ -87,19 +92,19 @@ private object Utils extends Logging { } } - // Note: if file is child of some registered path, while not equal to it, then return true; else false - // This is to ensure that two shutdown hooks do not try to delete each others paths - resulting in IOException - // and incomplete cleanup + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in IOException and incomplete cleanup. def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined + shutdownDeletePaths.find { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + }.isDefined + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") } - - if (retval) logInfo("path = " + file + ", already present as root for deletion.") - retval } @@ -131,7 +136,7 @@ private object Utils extends Logging { if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) } }) - return dir + dir } /** Copy all data from an InputStream to an OutputStream */ @@ -174,35 +179,30 @@ private object Utils extends Logging { Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { tempFile.delete() - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) + throw new SparkException( + "File " + targetFile + " exists and does not match contents of" + " " + url) } else { Files.move(tempFile, targetFile) } case "file" | null => - val sourceFile = if (uri.isAbsolute) { - new File(uri) - } else { - new File(url) - } - if (targetFile.exists && !Files.equal(sourceFile, targetFile)) { - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) - } else { - // Remove the file if it already exists - targetFile.delete() - // Symlink the file locally. - if (uri.isAbsolute) { - // url is absolute, i.e. it starts with "file:///". Extract the source - // file's absolute path from the url. - val sourceFile = new File(uri) - logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + // In the case of a local file, copy the local file to the target directory. + // Note the difference between uri vs url. + val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) + if (targetFile.exists) { + // If the target file already exists, warn the user if + if (!Files.equal(sourceFile, targetFile)) { + throw new SparkException( + "File " + targetFile + " exists and does not match contents of" + " " + url) } else { - // url is not absolute, i.e. itself is the path to the source file. - logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(url, targetFile.getAbsolutePath) + // Do nothing if the file contents are the same, i.e. this file has been copied + // previously. + logInfo(sourceFile.getAbsolutePath + " has been previously copied to " + + targetFile.getAbsolutePath) } + } else { + // The file does not exist in the target directory. Copy it there. + logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + Files.copy(sourceFile, targetFile) } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others @@ -323,8 +323,6 @@ private object Utils extends Logging { InetAddress.getByName(address).getHostName } - - def localHostPort(): String = { val retval = System.getProperty("spark.hostPort", null) if (retval == null) { @@ -382,6 +380,7 @@ private object Utils extends Logging { // Typically, this will be of order of number of nodes in cluster // If not, we should change it to LRUCache or something. private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() + def parseHostPort(hostPort: String): (String, Int) = { { // Check cache first. @@ -390,7 +389,8 @@ private object Utils extends Logging { } val indx: Int = hostPort.lastIndexOf(':') - // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... + // but then hadoop does not support ipv6 right now. // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 if (-1 == indx) { val retval = (hostPort, 0) -- cgit v1.2.3 From f6ad3781b1d9a044789f114d13787b9d05223da3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 May 2013 16:28:08 -0700 Subject: Fixed the flaky unpersist test in RDDSuite. --- core/src/test/scala/spark/RDDSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index a761dd77c5..3f69e99780 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -106,9 +106,9 @@ class RDDSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() rdd.count - assert(sc.persistentRdds.isEmpty == false) + assert(sc.persistentRdds.isEmpty === false) rdd.unpersist() - assert(sc.persistentRdds.isEmpty == true) + assert(sc.persistentRdds.isEmpty === true) failAfter(Span(3000, Millis)) { try { @@ -116,12 +116,12 @@ class RDDSuite extends FunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case e: Exception => + case _ => { Thread.sleep(10) } // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } } - assert(sc.getRDDStorageInfo.isEmpty == true) + assert(sc.getRDDStorageInfo.isEmpty === true) } test("caching with failures") { -- cgit v1.2.3 From 926f41cc522def181c167b71dc919a0759c5d3f6 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 30 May 2013 17:55:11 +0800 Subject: fix block manager UI display issue when enable spark.cleaner.ttl --- core/src/main/scala/spark/storage/StorageUtils.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 8f52168c24..81e607868d 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -64,12 +64,12 @@ object StorageUtils { // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rdd = sc.persistentRdds(rddId) - val rddName = Option(rdd.name).getOrElse(rddKey) - val rddStorageLevel = rdd.getStorageLevel - - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize) - }.toArray + sc.persistentRdds.get(rddId).map { r => + val rddName = Option(r.name).getOrElse(rddKey) + val rddStorageLevel = r.getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) + } + }.flatMap(x => x).toArray scala.util.Sorting.quickSort(rddInfos) -- cgit v1.2.3 From ba5e544461e8ca9216af703033f6b0de6dbc56ec Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 31 May 2013 01:48:16 -0700 Subject: More block manager cleanup. Implemented a removeRdd method in BlockManager, and use that to implement RDD.unpersist. Previously, unpersist needs to send B akka messages, where B = number of blocks. Now unpersist only needs to send W akka messages, where W = the number of workers. --- core/src/main/scala/spark/RDD.scala | 21 +-- .../main/scala/spark/storage/BlockManager.scala | 31 +++- .../scala/spark/storage/BlockManagerMaster.scala | 49 +++--- .../spark/storage/BlockManagerMasterActor.scala | 192 ++++++++++----------- .../scala/spark/storage/BlockManagerMessages.scala | 6 + .../spark/storage/BlockManagerSlaveActor.scala | 8 +- .../scala/spark/storage/BlockManagerWorker.scala | 10 +- .../scala/spark/storage/BlockManagerSuite.scala | 36 ++-- 8 files changed, 187 insertions(+), 166 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index dde131696f..e6c0438d76 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,13 +1,10 @@ package spark -import java.net.URL -import java.util.{Date, Random} -import java.util.{HashMap => JHashMap} +import java.util.Random import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable @@ -32,7 +29,6 @@ import spark.rdd.MapPartitionsWithIndexRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD import spark.rdd.ShuffledRDD -import spark.rdd.SubtractedRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.rdd.ZippedPartitionsRDD2 @@ -141,10 +137,15 @@ abstract class RDD[T: ClassManifest]( /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): RDD[T] = persist() - /** Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. */ - def unpersist(): RDD[T] = { + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + * @return This RDD. + */ + def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.env.blockManager.master.removeRdd(id) + sc.env.blockManager.master.removeRdd(id, blocking) sc.persistentRdds.remove(id) storageLevel = StorageLevel.NONE this @@ -269,8 +270,8 @@ abstract class RDD[T: ClassManifest]( def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { var fraction = 0.0 var total = 0 - var multiplier = 3.0 - var initialCount = count() + val multiplier = 3.0 + val initialCount = count() var maxSelected = 0 if (initialCount > Integer.MAX_VALUE - 1) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index d35c43f194..3a5d4ef448 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -3,8 +3,7 @@ package spark.storage import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue} -import scala.collection.JavaConversions._ +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} import akka.actor.{ActorSystem, Cancellable, Props} import akka.dispatch.{Await, Future} @@ -15,7 +14,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.{Logging, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -95,9 +94,11 @@ private[spark] class BlockManager( new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean - private val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt - private val nettyPort = if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + private val nettyPort: Int = { + val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean + val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt + if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + } val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext @@ -824,10 +825,24 @@ private[spark] class BlockManager( } } + /** + * Remove all blocks belonging to the given RDD. + * @return The number of blocks removed. + */ + def removeRdd(rddId: Int): Int = { + // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps + // from RDD.id to blocks. + logInfo("Removing RDD " + rddId) + val rddPrefix = "rdd_" + rddId + "_" + val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) + blocksToRemove.foreach(blockId => removeBlock(blockId, false)) + blocksToRemove.size + } + /** * Remove a block from both memory and disk. */ - def removeBlock(blockId: String) { + def removeBlock(blockId: String, tellMaster: Boolean = true) { logInfo("Removing block " + blockId) val info = blockInfo.get(blockId).orNull if (info != null) info.synchronized { @@ -839,7 +854,7 @@ private[spark] class BlockManager( "the disk or memory store") } blockInfo.remove(blockId) - if (info.tellMaster) { + if (tellMaster && info.tellMaster) { reportBlockStatus(blockId, info) } } else { diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index ac26c16867..7099e40618 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -1,19 +1,11 @@ package spark.storage -import java.io._ -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.util.Random - -import akka.actor.{Actor, ActorRef, ActorSystem, Props} -import akka.dispatch.Await +import akka.actor.ActorRef +import akka.dispatch.{Await, Future} import akka.pattern.ask -import akka.util.{Duration, Timeout} -import akka.util.duration._ +import akka.util.Duration -import spark.{Logging, SparkException, Utils} +import spark.{Logging, SparkException} private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { @@ -91,15 +83,28 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi /** * Remove all blocks belonging to the given RDD. */ - def removeRdd(rddId: Int) { - val rddBlockPrefix = "rdd_" + rddId + "_" - // Get the list of blocks in block manager, and remove ones that are part of this RDD. - // The runtime complexity is linear to the number of blocks persisted in the cluster. - // It could be expensive if the cluster is large and has a lot of blocks persisted. - getStorageStatus.flatMap(_.blocks).foreach { case(blockId, status) => - if (blockId.startsWith(rddBlockPrefix)) { - removeBlock(blockId) - } + def removeRdd(rddId: Int, blocking: Boolean) { + // The logic to remove an RDD is somewhat complicated: + // 1. Send BlockManagerMasterActor a RemoveRdd message. + // 2. Upon receiving the RemoveRdd message, BlockManagerMasterActor will forward the message + // to all workers to remove blocks belonging to the RDD, and return a Future for the results. + // 3. The Future is sent back here, and on successful completion of the Future, this function + // sends a RemoveRddMetaData message to BlockManagerMasterActor. + // 4. Upon receiving the RemoveRddMetaData message, BlockManagerMasterActor will delete the meta + // data for the given RDD. + // + // The reason we are doing it this way is to reduce the amount of messages the driver sends. + // The number of messages that need to be sent is only the number of workers the cluster has, + // rather than the number of blocks in the cluster. Note that we can further reduce the number + // of messages by tracking for a given RDD, where are its blocks. Then we can send only to the + // workers that have the given RDD. But this remains future work. + val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + future onComplete { + case Left(throwable) => logError("Failed to remove RDD " + rddId, throwable) + case Right(numBlocks) => tell(RemoveRddMetaData(rddId, numBlocks.sum)) + } + if (blocking) { + Await.result(future, timeout) } } @@ -114,7 +119,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi } def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray + askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } /** Stop the driver actor, called only on the Spark driver node */ diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 9b64f95df8..00aa97bf78 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -2,15 +2,16 @@ package spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.util.Random import akka.actor.{Actor, ActorRef, Cancellable} -import akka.util.{Duration, Timeout} +import akka.dispatch.Future +import akka.pattern.ask +import akka.util.Duration import akka.util.duration._ -import spark.{Logging, Utils} +import spark.{Logging, Utils, SparkException} /** * BlockManagerMasterActor is an actor on the master node to track statuses of @@ -21,13 +22,16 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = - new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] + new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId] + private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] + private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] + + val akkaTimeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") initLogging() @@ -50,28 +54,38 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { def receive = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) + sender ! true case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + // TODO: Ideally we want to handle all the message replies in receive instead of in the + // individual private methods. updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) case GetLocations(blockId) => - getLocations(blockId) + sender ! getLocations(blockId) case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) + sender ! getLocationsMultipleBlockIds(blockIds) case GetPeers(blockManagerId, size) => - getPeersDeterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ + sender ! getPeers(blockManagerId, size) case GetMemoryStatus => - getMemoryStatus + sender ! memoryStatus case GetStorageStatus => - getStorageStatus + sender ! storageStatus + + case RemoveRdd(rddId) => + sender ! removeRdd(rddId) + + case RemoveRddMetaData(rddId, numBlocks) => + removeRddMetaData(rddId, numBlocks) + sender ! true case RemoveBlock(blockId) => - removeBlock(blockId) + removeBlockFromWorkers(blockId) + sender ! true case RemoveExecutor(execId) => removeExecutor(execId) @@ -81,7 +95,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { logInfo("Stopping BlockManagerMaster") sender ! true if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel + timeoutCheckingTask.cancel() } context.stop(self) @@ -89,13 +103,34 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { expireDeadHosts() case HeartBeat(blockManagerId) => - heartBeat(blockManagerId) + sender ! heartBeat(blockManagerId) case other => - logInfo("Got unknown message: " + other) + logWarning("Got unknown message: " + other) + } + + private def removeRdd(rddId: Int): Future[Seq[Int]] = { + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. + import context.dispatcher + Future.sequence(blockManagerInfo.values.map { bm => + bm.slaveActor.ask(RemoveRdd(rddId))(akkaTimeout).mapTo[Int] + }.toSeq) + } + + private def removeRddMetaData(rddId: Int, numBlocks: Int) { + val prefix = "rdd_" + rddId + "_" + // Find all blocks for the given RDD, remove the block from both blockLocations and + // the blockManagerInfo that is tracking the blocks. + val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) + blocks.foreach { blockId => + val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) + bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) + blockLocations.remove(blockId) + } } - def removeBlockManager(blockManagerId: BlockManagerId) { + private def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) // Remove the block manager from blockManagerIdByExecutor. @@ -106,7 +141,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next - val locations = blockLocations.get(blockId)._2 + val locations = blockLocations.get(blockId) locations -= blockManagerId if (locations.size == 0) { blockLocations.remove(locations) @@ -114,11 +149,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } } - def expireDeadHosts() { + private def expireDeadHosts() { logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") val now = System.currentTimeMillis() val minSeenTime = now - slaveTimeout - val toRemove = new HashSet[BlockManagerId] + val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { if (info.lastSeenMs < minSeenTime) { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + @@ -129,31 +164,26 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { toRemove.foreach(removeBlockManager) } - def removeExecutor(execId: String) { + private def removeExecutor(execId: String) { logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - sender ! true } - def heartBeat(blockManagerId: BlockManagerId) { + private def heartBeat(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.executorId == "" && !isLocal) { - sender ! true - } else { - sender ! false - } + blockManagerId.executorId == "" && !isLocal } else { blockManagerInfo(blockManagerId).updateLastSeenMs() - sender ! true + true } } // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - private def removeBlock(blockId: String) { - val block = blockLocations.get(blockId) - if (block != null) { - block._2.foreach { blockManagerId: BlockManagerId => + private def removeBlockFromWorkers(blockId: String) { + val locations = blockLocations.get(blockId) + if (locations != null) { + locations.foreach { blockManagerId: BlockManagerId => val blockManager = blockManagerInfo.get(blockManagerId) if (blockManager.isDefined) { // Remove the block from the slave's BlockManager. @@ -163,23 +193,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } } } - sender ! true } // Return a map from the block manager id to max memory and remaining memory. - private def getMemoryStatus() { - val res = blockManagerInfo.map { case(blockManagerId, info) => + private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { + blockManagerInfo.map { case(blockManagerId, info) => (blockManagerId, (info.maxMem, info.remainingMem)) }.toMap - sender ! res } - private def getStorageStatus() { - val res = blockManagerInfo.map { case(blockManagerId, info) => + private def storageStatus: Array[StorageStatus] = { + blockManagerInfo.map { case(blockManagerId, info) => import collection.JavaConverters._ StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) - } - sender ! res + }.toArray } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { @@ -188,7 +215,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } else if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => - // A block manager of the same host name already exists + // A block manager of the same executor already exists. + // This should never happen. Let's just quit. logError("Got two different block manager registrations on " + id.executorId) System.exit(1) case None => @@ -197,7 +225,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveActor) } - sender ! true } private def updateBlockInfo( @@ -226,12 +253,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - var locations: HashSet[BlockManagerId] = null + var locations: mutable.HashSet[BlockManagerId] = null if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId)._2 + locations = blockLocations.get(blockId) } else { - locations = new HashSet[BlockManagerId] - blockLocations.put(blockId, (storageLevel.replication, locations)) + locations = new mutable.HashSet[BlockManagerId] + blockLocations.put(blockId, locations) } if (storageLevel.isValid) { @@ -247,70 +274,24 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! true } - private def getLocations(blockId: String) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockId + " " - if (blockLocations.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockLocations.get(blockId)._2) - sender ! res.toSeq - } else { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - sender ! res - } - } - - private def getLocationsMultipleBlockIds(blockIds: Array[String]) { - def getLocations(blockId: String): Seq[BlockManagerId] = { - val tmp = blockId - if (blockLocations.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockLocations.get(blockId)._2) - return res.toSeq - } else { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - return res.toSeq - } - } - - var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] - for (blockId <- blockIds) { - res.append(getLocations(blockId)) - } - sender ! res.toSeq + private def getLocations(blockId: String): Seq[BlockManagerId] = { + if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getPeers(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(peers) - res -= blockManagerId - val rand = new Random(System.currentTimeMillis()) - while (res.length > size) { - res.remove(rand.nextInt(res.length)) - } - sender ! res.toSeq + private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + blockIds.map(blockId => getLocations(blockId)) } - private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { + val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray val selfIndex = peers.indexOf(blockManagerId) if (selfIndex == -1) { - throw new Exception("Self index for " + blockManagerId + " not found") + throw new SparkException("Self index for " + blockManagerId + " not found") } // Note that this logic will select the same node multiple times if there aren't enough peers - var index = selfIndex - while (res.size < size) { - index += 1 - if (index == selfIndex) { - throw new Exception("More peer expected than available") - } - res += peers(index % peers.size) - } - sender ! res.toSeq + Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } } @@ -384,6 +365,13 @@ object BlockManagerMasterActor { } } + def removeBlock(blockId: String) { + if (_blocks.containsKey(blockId)) { + _remainingMem += _blocks.get(blockId).memSize + _blocks.remove(blockId) + } + } + def remainingMem: Long = _remainingMem def lastSeenMs: Long = _lastSeenMs diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index cff48d9909..88268fd41b 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -16,6 +16,12 @@ sealed trait ToBlockManagerSlave private[spark] case class RemoveBlock(blockId: String) extends ToBlockManagerSlave +// Remove all blocks belonging to a specific RDD. +private[spark] case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + +// Remove the meta data for a RDD. This is only sent to the master by the master. +private[spark] case class RemoveRddMetaData(rddId: Int, numBlocks: Int) extends ToBlockManagerMaster + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala index f570cdc52d..b264d1deb5 100644 --- a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala @@ -11,6 +11,12 @@ import spark.{Logging, SparkException, Utils} */ class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { override def receive = { - case RemoveBlock(blockId) => blockManager.removeBlock(blockId) + + case RemoveBlock(blockId) => + blockManager.removeBlock(blockId) + + case RemoveRdd(rddId) => + val numBlocksRemoved = blockManager.removeRdd(rddId) + sender ! numBlocksRemoved } } diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index 15225f93a6..3057ade233 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -2,13 +2,7 @@ package spark.storage import java.nio.ByteBuffer -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.util.Random - -import spark.{Logging, Utils, SparkEnv} +import spark.{Logging, Utils} import spark.network._ /** @@ -88,8 +82,6 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends private[spark] object BlockManagerWorker extends Logging { private var blockManagerWorker: BlockManagerWorker = null - private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 - private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 initLogging() diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index bff2475686..b9d5f9668e 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -15,10 +15,10 @@ import org.scalatest.time.SpanSugar._ import spark.JavaSerializer import spark.KryoSerializer import spark.SizeEstimator -import spark.Utils import spark.util.AkkaUtils import spark.util.ByteBufferInputStream + class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { var store: BlockManager = null var store2: BlockManager = null @@ -124,7 +124,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory assert(store.getSingle("a1") != None, "a1 was not in store") @@ -170,7 +170,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false) + store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) // Checking whether blocks are in memory and memory size val memStatus = master.getMemoryStatus.head._2 @@ -218,7 +218,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) - master.removeRdd(0) + master.removeRdd(0, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { store.getSingle("rdd_0_0") should be (None) @@ -232,6 +232,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store.getSingle("nonrddblock") should not be (None) master.getLocations("nonrddblock") should have size (1) } + + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = true) + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 } test("reregistration on heart beat") { @@ -262,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") @@ -280,7 +288,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { override def run() { - store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) } } val t2 = new Thread { @@ -490,9 +498,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, true) - store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, true) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.get("list2") != None, "list2 was not in store") assert(store.get("list2").get.size == 2) assert(store.get("list3") != None, "list3 was not in store") @@ -501,7 +509,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.get("list2") != None, "list2 was not in store") assert(store.get("list2").get.size == 2) // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) assert(store.get("list1") != None, "list1 was not in store") assert(store.get("list1").get.size == 2) assert(store.get("list2") != None, "list2 was not in store") @@ -516,9 +524,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val list3 = List(new Array[Byte](200), new Array[Byte](200)) val list4 = List(new Array[Byte](200), new Array[Byte](200)) // First store list1 and list2, both in memory, and list3, on disk only - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, true) - store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, true) + store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) + store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) // At this point LRU should not kick in because list3 is only on disk assert(store.get("list1") != None, "list2 was not in store") assert(store.get("list1").get.size === 2) @@ -533,7 +541,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.get("list3") != None, "list1 was not in store") assert(store.get("list3").get.size === 2) // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, true) + store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) assert(store.get("list1") === None, "list1 was in store") assert(store.get("list2") != None, "list3 was not in store") assert(store.get("list2").get.size === 2) -- cgit v1.2.3 From de1167bf2c32d52c865a4a0c7213b665ebd61f93 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 31 May 2013 15:54:57 -0700 Subject: Incorporated Charles' feedback to put rdd metadata removal in BlockManagerMasterActor. --- .../scala/spark/storage/BlockManagerMaster.scala | 21 +++------------------ .../spark/storage/BlockManagerMasterActor.scala | 22 ++++++++++------------ .../scala/spark/storage/BlockManagerMessages.scala | 3 --- 3 files changed, 13 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 7099e40618..58888b1ebb 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -84,24 +84,9 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi * Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - // The logic to remove an RDD is somewhat complicated: - // 1. Send BlockManagerMasterActor a RemoveRdd message. - // 2. Upon receiving the RemoveRdd message, BlockManagerMasterActor will forward the message - // to all workers to remove blocks belonging to the RDD, and return a Future for the results. - // 3. The Future is sent back here, and on successful completion of the Future, this function - // sends a RemoveRddMetaData message to BlockManagerMasterActor. - // 4. Upon receiving the RemoveRddMetaData message, BlockManagerMasterActor will delete the meta - // data for the given RDD. - // - // The reason we are doing it this way is to reduce the amount of messages the driver sends. - // The number of messages that need to be sent is only the number of workers the cluster has, - // rather than the number of blocks in the cluster. Note that we can further reduce the number - // of messages by tracking for a given RDD, where are its blocks. Then we can send only to the - // workers that have the given RDD. But this remains future work. val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) - future onComplete { - case Left(throwable) => logError("Failed to remove RDD " + rddId, throwable) - case Right(numBlocks) => tell(RemoveRddMetaData(rddId, numBlocks.sum)) + future onFailure { + case e: Throwable => logError("Failed to remove RDD " + rddId, e) } if (blocking) { Await.result(future, timeout) @@ -156,7 +141,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi val future = driverActor.ask(message)(timeout) val result = Await.result(future, timeout) if (result == null) { - throw new Exception("BlockManagerMaster returned null") + throw new SparkException("BlockManagerMaster returned null") } return result.asInstanceOf[T] } catch { diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 00aa97bf78..2d05e0ccf1 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -79,10 +79,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case RemoveRdd(rddId) => sender ! removeRdd(rddId) - case RemoveRddMetaData(rddId, numBlocks) => - removeRddMetaData(rddId, numBlocks) - sender ! true - case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -110,15 +106,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher - Future.sequence(blockManagerInfo.values.map { bm => - bm.slaveActor.ask(RemoveRdd(rddId))(akkaTimeout).mapTo[Int] - }.toSeq) - } + // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // from the slaves. - private def removeRddMetaData(rddId: Int, numBlocks: Int) { val prefix = "rdd_" + rddId + "_" // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. @@ -128,6 +118,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) blockLocations.remove(blockId) } + + // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. + // The dispatcher is used as an implicit argument into the Future sequence construction. + import context.dispatcher + val removeMsg = RemoveRdd(rddId) + Future.sequence(blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq) } private def removeBlockManager(blockManagerId: BlockManagerId) { diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 88268fd41b..0010726c8d 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -19,9 +19,6 @@ case class RemoveBlock(blockId: String) extends ToBlockManagerSlave // Remove all blocks belonging to a specific RDD. private[spark] case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave -// Remove the meta data for a RDD. This is only sent to the master by the master. -private[spark] case class RemoveRddMetaData(rddId: Int, numBlocks: Int) extends ToBlockManagerMaster - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. -- cgit v1.2.3 From 9f84315c055d7a53da8787eb26b336726fc33e8a Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Sat, 1 Jun 2013 00:26:10 +0000 Subject: enhance pipe to support what we can do in hadoop streaming --- core/src/main/scala/spark/RDD.scala | 18 ++++++++++++++++++ core/src/main/scala/spark/rdd/PipedRDD.scala | 25 +++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index dde131696f..5a41db23c2 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -348,17 +348,35 @@ abstract class RDD[T: ClassManifest]( */ def pipe(command: String): RDD[String] = new PipedRDD(this, command) + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = + new PipedRDD(this, command, transform, arguments) + /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = + new PipedRDD(this, command, transform, arguments) + /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: Seq[String], env: Map[String, String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = + new PipedRDD(this, command, env, transform, arguments) + /** * Return a new RDD by applying a function to each partition of this RDD. */ diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 962a1b21ad..969404c95f 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -18,14 +18,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext} class PipedRDD[T: ClassManifest]( prev: RDD[T], command: Seq[String], - envVars: Map[String, String]) + envVars: Map[String, String], + transform: (T, String => Unit) => Any, + arguments: Seq[String] + ) extends RDD[String](prev) { + def this(prev: RDD[T], command: Seq[String], envVars : Map[String, String]) = this(prev, command, envVars, null, null) def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) + def this(prev: RDD[T], command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, command, Map(), transform, arguments) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) + def this(prev: RDD[T], command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, PipedRDD.tokenize(command), Map(), transform, arguments) + override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -52,8 +59,22 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + + // input the arguments firstly + if ( arguments != null) { + for (elem <- arguments) { + out.println(elem) + } + // ^A \n as the marker of the end of the arguments + out.println("\u0001") + } for (elem <- firstParent[T].iterator(split, context)) { - out.println(elem) + if (transform != null) { + transform(elem, out.println(_)) + } + else { + out.println(elem) + } } out.close() } -- cgit v1.2.3 From 91aca9224936da84b16ea789cb81914579a0db03 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 31 May 2013 23:21:38 -0700 Subject: Another round of Netty fixes. 1. Avoid race condition between stop and copier completion 2. Handle socket exceptions by reporting them and filling in a failed FetchResult --- .../main/java/spark/network/netty/FileClient.java | 24 +++------ .../spark/network/netty/FileClientHandler.java | 8 +++ .../scala/spark/network/netty/ShuffleCopier.scala | 62 ++++++++++++++-------- .../scala/spark/storage/BlockFetcherIterator.scala | 9 ++-- 4 files changed, 58 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 3a62dacbc8..9c9b976ebe 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -8,9 +8,12 @@ import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; class FileClient { + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; @@ -25,25 +28,10 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } - public static final class ChannelCloseListener implements ChannelFutureListener { - private FileClient fc = null; - - public ChannelCloseListener(FileClient fc){ - this.fc = fc; - } - - @Override - public void operationComplete(ChannelFuture future) { - if (fc.bootstrap!=null){ - fc.bootstrap.shutdown(); - fc.bootstrap = null; - } - } - } - public void connect(String host, int port) { try { // Start the connection attempt. @@ -58,8 +46,8 @@ class FileClient { public void waitForClose() { try { channel.closeFuture().sync(); - } catch (InterruptedException e){ - e.printStackTrace(); + } catch (InterruptedException e) { + LOG.warn("FileClient interrupted", e); } } diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java index 2069dee5ca..9fc9449827 100644 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/spark/network/netty/FileClientHandler.java @@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { private FileHeader currentHeader = null; + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); @Override public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { @@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { // get file if(in.readableBytes() >= currentHeader.fileLen()) { handle(ctx, in, currentHeader); + handlerCalled = true; currentHeader = null; ctx.close(); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index a91f5a886d..8ec46d42fa 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -9,19 +9,35 @@ import io.netty.util.CharsetUtil import spark.Logging import spark.network.ConnectionManagerId +import scala.collection.JavaConverters._ + private[spark] class ShuffleCopier extends Logging { - def getBlock(cmId: ConnectionManagerId, blockId: String, + def getBlock(host: String, port: Int, blockId: String, resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val fc = new FileClient(handler) - fc.init() - fc.connect(cmId.host, cmId.port) - fc.sendRequest(blockId) - fc.waitForClose() - fc.close() + try { + fc.init() + fc.connect(host, port) + fc.sendRequest(blockId) + fc.waitForClose() + fc.close() + } catch { + // Handle any socket-related exceptions in FileClient + case e: Exception => { + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + + " failed", e) + handler.handleError(blockId) + } + } + } + + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { + getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, @@ -44,20 +60,18 @@ private[spark] object ShuffleCopier extends Logging { logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + override def handleError(blockId: String) { + if (!isComplete) { + resultCollectCallBack(blockId, -1, null) + } + } } - def runGetBlock(host:String, port:Int, file:String){ - val handler = new ShuffleClientHandler(echoResultCollectCallBack) - val fc = new FileClient(handler) - fc.init(); - fc.connect(host, port) - fc.sendRequest(file) - fc.waitForClose(); - fc.close() + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + if (size != -1) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") + } } def main(args: Array[String]) { @@ -71,14 +85,16 @@ private[spark] object ShuffleCopier extends Logging { val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - for (i <- Range(0, threads)) { - val runnable = new Runnable() { + val tasks = (for (i <- Range(0, threads)) yield { + Executors.callable(new Runnable() { def run() { - runGetBlock(host, port, file) + val copier = new ShuffleCopier() + copier.getBlock(host, port, file, echoResultCollectCallBack) } - } - copiers.execute(runnable) - } + }) + }).asJava + copiers.invokeAll(tasks) copiers.shutdown + System.exit(0) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 1d69d658f7..fac416a5b3 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -265,7 +265,7 @@ object BlockFetcherIterator { }).toList } - //keep this to interrupt the threads when necessary + // keep this to interrupt the threads when necessary private def stopCopiers() { for (copier <- copiers) { copier.interrupt() @@ -312,9 +312,10 @@ object BlockFetcherIterator { resultsGotten += 1 val result = results.take() // if all the results has been retrieved, shutdown the copiers - if (resultsGotten == _totalBlocks && copiers != null) { - stopCopiers() - } + // NO need to stop the copiers if we got all the blocks ? + // if (resultsGotten == _totalBlocks && copiers != null) { + // stopCopiers() + // } (result.blockId, if (result.failed) None else Some(result.deserialize())) } } -- cgit v1.2.3 From 038cfc1a9acb32f8c17d883ea64f8cbb324ed82c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 31 May 2013 23:32:18 -0700 Subject: Make connect timeout configurable --- core/src/main/java/spark/network/netty/FileClient.java | 6 ++++-- core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 9c9b976ebe..517772202f 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -17,9 +17,11 @@ class FileClient { private FileClientHandler handler = null; private Channel channel = null; private Bootstrap bootstrap = null; + private int connectTimeout = 60*1000; // 1 min - public FileClient(FileClientHandler handler) { + public FileClient(FileClientHandler handler, int connectTimeout) { this.handler = handler; + this.connectTimeout = connectTimeout; } public void init() { @@ -28,7 +30,7 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0) // Disable connect timeout + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) // Disable connect timeout .handler(new FileClientChannelInitializer(handler)); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index 8ec46d42fa..afb2cdbb3a 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -18,7 +18,8 @@ private[spark] class ShuffleCopier extends Logging { resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val fc = new FileClient(handler) + val fc = new FileClient(handler, + System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt) try { fc.init() fc.connect(host, port) -- cgit v1.2.3 From 3be7bdcefda13d67633f9b9f6d901722fd5649de Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Sat, 1 Jun 2013 19:32:17 +0530 Subject: Adding example to make Spark RDD from Cassandra --- .../main/scala/spark/examples/CassandraTest.scala | 154 +++++++++++++++++++++ project/SparkBuild.scala | 4 +- 2 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 examples/src/main/scala/spark/examples/CassandraTest.scala diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala new file mode 100644 index 0000000000..790b24e6f3 --- /dev/null +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -0,0 +1,154 @@ +package spark.examples + +import org.apache.hadoop.mapreduce.Job +import org.apache.cassandra.hadoop.{ConfigHelper, ColumnFamilyInputFormat} +import org.apache.cassandra.thrift.{IndexExpression, SliceRange, SlicePredicate} +import spark.{RDD, SparkContext} +import SparkContext._ +import java.nio.ByteBuffer +import java.util.SortedMap +import org.apache.cassandra.db.IColumn +import org.apache.cassandra.utils.ByteBufferUtil +import scala.collection.JavaConversions._ + + +/* + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra support for Hadoop. + * + * To run this example, run this file with the following command params - + * + * + * So if you want to run this on localhost this will be, + * local[3] localhost 9160 + * + * The example makes some assumptions: + * 1. You have already created a keyspace called casDemo and it has a column family named Words + * 2. There are column family has a column named "para" which has test content. + * + * You can create the content by running the following script at the bottom of this file with cassandra-cli. + * + */ +object CassandraTest { + def main(args: Array[String]) { + + //Get a SparkContext + val sc = new SparkContext(args(0), "casDemo") + + //Build the job configuration with ConfigHelper provided by Cassandra + val job = new Job() + job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) + + ConfigHelper.setInputInitialAddress(job.getConfiguration(), args(1)) + + ConfigHelper.setInputRpcPort(job.getConfiguration(), args(2)) + + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + + val predicate = new SlicePredicate() + val sliceRange = new SliceRange() + sliceRange.setStart(Array.empty[Byte]) + sliceRange.setFinish(Array.empty[Byte]) + predicate.setSlice_range(sliceRange) + ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) + + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + + //Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), + classOf[ColumnFamilyInputFormat], + classOf[ByteBuffer], + classOf[SortedMap[ByteBuffer, IColumn]]) + + // Let us first get all the paragraphs from the retrieved rows + val paraRdd = casRdd flatMap { + case (key, value) => { + value.filter(v => ByteBufferUtil.string(v._1).compareTo("para") == 0).map(v => ByteBufferUtil.string(v._2.value())) + } + } + + //Lets get the word count in paras + val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) + + counts.collect() foreach { + case(word, count) => println(word + ":" + count) + } + } +} + +/* +create keyspace casDemo; +use casDemo; + +create column family Words with comparator = UTF8Type; +update column family Words with column_metadata = [{column_name: book, validation_class: UTF8Type}, {column_name: para, validation_class: UTF8Type}]; + +assume Words keys as utf8; + +set Words['3musk001']['book'] = 'The Three Musketeers'; +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market town of + Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to + be in as perfect a state of revolution as if the Huguenots had just made + a second La Rochelle of it. Many citizens, seeing the women flying + toward the High Street, leaving their children crying at the open doors, + hastened to don the cuirass, and supporting their somewhat uncertain + courage with a musket or a partisan, directed their steps toward the + hostelry of the Jolly Miller, before which was gathered, increasing + every minute, a compact group, vociferous and full of curiosity.'; + +set Words['3musk002']['book'] = 'The Three Musketeers'; +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without some city + or other registering in its archives an event of this kind. There were + nobles, who made war against each other; there was the king, who made + war against the cardinal; there was Spain, which made war against the + king. Then, in addition to these concealed or public, secret or open + wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, + who made war upon everybody. The citizens always took up arms readily + against thieves, wolves or scoundrels, often against nobles or + Huguenots, sometimes against the king, but never against cardinal or + Spain. It resulted, then, from this habit that on the said first Monday + of April, 1625, the citizens, on hearing the clamor, and seeing neither + the red-and-yellow standard nor the livery of the Duc de Richelieu, + rushed toward the hostel of the Jolly Miller. When arrived there, the + cause of the hubbub was apparent to all'; + +set Words['3musk003']['book'] = 'The Three Musketeers'; +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however large + the sum may be; but you ought also to endeavor to perfect yourself in + the exercises becoming a gentleman. I will write a letter today to the + Director of the Royal Academy, and tomorrow he will admit you without + any expense to yourself. Do not refuse this little service. Our + best-born and richest gentlemen sometimes solicit it without being able + to obtain it. You will learn horsemanship, swordsmanship in all its + branches, and dancing. You will make some desirable acquaintances; and + from time to time you can call upon me, just to tell me how you are + getting on, and to say whether I can be of further service to you.'; + + +set Words['thelostworld001']['book'] = 'The Lost World'; +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined against the + red curtain. How beautiful she was! And yet how aloof! We had been + friends, quite good friends; but never could I get beyond the same + comradeship which I might have established with one of my + fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, + and perfectly unsexual. My instincts are all against a woman being too + frank and at her ease with me. It is no compliment to a man. Where + the real sex feeling begins, timidity and distrust are its companions, + heritage from old wicked days when love and violence went often hand in + hand. The bent head, the averted eye, the faltering voice, the wincing + figure--these, and not the unshrinking gaze and frank reply, are the + true signals of passion. Even in my short life I had learned as much + as that--or had inherited it in that race memory which we call instinct.'; + +set Words['thelostworld002']['book'] = 'The Lost World'; +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, red-headed news + editor, and I rather hoped that he liked me. Of course, Beaumont was + the real boss; but he lived in the rarefied atmosphere of some Olympian + height from which he could distinguish nothing smaller than an + international crisis or a split in the Cabinet. Sometimes we saw him + passing in lonely majesty to his inner sanctum, with his eyes staring + vaguely and his mind hovering over the Balkans or the Persian Gulf. He + was above and beyond us. But McArdle was his first lieutenant, and it + was he that we knew. The old man nodded as I entered the room, and he + pushed his spectacles far up on his bald forehead.'; + +*/ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 0ea23b446f..5152b7b79b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -201,8 +201,8 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11") - ) + libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11", + "org.apache.cassandra" % "cassandra-all" % "1.2.5" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru:1.3") exclude("com.ning","compress-lzf") exclude("io.netty","netty") exclude("jline","jline") exclude("log4j","log4j") exclude("org.apache.cassandra.deps", "avro"))) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From 81c2adc15c9e232846d4ad0adf14d007039409fa Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Sun, 2 Jun 2013 12:51:15 +0530 Subject: Removing infix call --- examples/src/main/scala/spark/examples/CassandraTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 790b24e6f3..49b940d8a7 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -60,7 +60,7 @@ object CassandraTest { classOf[SortedMap[ByteBuffer, IColumn]]) // Let us first get all the paragraphs from the retrieved rows - val paraRdd = casRdd flatMap { + val paraRdd = casRdd.flatMap { case (key, value) => { value.filter(v => ByteBufferUtil.string(v._1).compareTo("para") == 0).map(v => ByteBufferUtil.string(v._2.value())) } @@ -69,8 +69,8 @@ object CassandraTest { //Lets get the word count in paras val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) - counts.collect() foreach { - case(word, count) => println(word + ":" + count) + counts.collect().foreach { + case (word, count) => println(word + ":" + count) } } } -- cgit v1.2.3 From 6d8423fd1b490d541f0ea379068b8954002d624f Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Sun, 2 Jun 2013 13:03:45 +0530 Subject: Adding deps to examples/pom.xml Fixing exclusion in examples deps in SparkBuild.scala --- examples/pom.xml | 35 +++++++++++++++++++++++++++++++++++ project/SparkBuild.scala | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/examples/pom.xml b/examples/pom.xml index c42d2bcdb9..b4c5251d68 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,6 +34,41 @@ scalacheck_${scala.version} test + + org.apache.cassandra + cassandra-all + 1.2.5 + + + com.google.guava + guava + + + com.googlecode.concurrentlinkedhashmap + concurrentlinkedhashmap-lru + + + com.ning + compress-lzf + + + io.netty + netty + + + jline + jline + + + log4j + log4j + + + org.apache.cassandra.deps + avro + + + target/scala-${scala.version}/classes diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5152b7b79b..7f3e223c2e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -202,7 +202,7 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11", - "org.apache.cassandra" % "cassandra-all" % "1.2.5" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru:1.3") exclude("com.ning","compress-lzf") exclude("io.netty","netty") exclude("jline","jline") exclude("log4j","log4j") exclude("org.apache.cassandra.deps", "avro"))) + "org.apache.cassandra" % "cassandra-all" % "1.2.5" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") exclude("com.ning","compress-lzf") exclude("io.netty","netty") exclude("jline","jline") exclude("log4j","log4j") exclude("org.apache.cassandra.deps", "avro"))) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") -- cgit v1.2.3 From 4a9913d66a61ac9ef9cab0e08f6151dc2624fd11 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Sun, 2 Jun 2013 23:21:09 +0000 Subject: add ut for pipe enhancement --- core/src/test/scala/spark/PipedRDDSuite.scala | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index a6344edf8f..ee55952a94 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -19,6 +19,37 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { assert(c(3) === "4") } + test("advanced pipe") { + sc = new SparkContext("local", "test") + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + val piped = nums.pipe(Seq("cat"), (i:Int, f: String=> Unit) => f(i + "_"), Array("0")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str=>str.split("\t")(0)).pipe(Seq("cat"), (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}, Array("0")).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } + test("pipe with env variable") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) -- cgit v1.2.3 From 606bb1b450064a2b909e4275ce45325dbbef4eca Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Fri, 31 May 2013 15:40:41 +0800 Subject: Fix schedulingAlgorithm bugs for unit test --- .../spark/scheduler/cluster/SchedulingAlgorithm.scala | 17 +++++++++++++---- .../scala/spark/scheduler/ClusterSchedulerSuite.scala | 9 ++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index a5d6285c99..13120edf63 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -40,15 +40,24 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true + var compare:Int = 0 if (s1Needy && !s2Needy) { - res = true + return true } else if (!s1Needy && s2Needy) { - res = false + return false } else if (s1Needy && s2Needy) { - res = minShareRatio1 <= minShareRatio2 + compare = minShareRatio1.compareTo(minShareRatio2) + } else { + compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) + } + + if (compare < 0) { + res = true + } else if (compare > 0) { + res = false } else { - res = taskToWeightRatio1 <= taskToWeightRatio2 + return s1.name < s2.name } return res } diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index a39418b716..c861597c6b 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -88,7 +88,7 @@ class DummyTask(stageId: Int) extends Task[Int](stageId) } } -class ClusterSchedulerSuite extends FunSuite with LocalSparkContext { +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = { new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet) @@ -96,8 +96,11 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext { def resourceOffer(rootPool: Pool): Int = { val taskSetQueue = rootPool.getSortedTaskSetQueue() - for (taskSet <- taskSetQueue) - { + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { taskSet.slaveOffer("execId_1", "hostname_1", 1) match { case Some(task) => return taskSet.stageId -- cgit v1.2.3 From 56c64c403383e90a5fd33b6a1f72527377d9bee0 Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Mon, 3 Jun 2013 12:48:35 +0530 Subject: A better way to read column value if you are sure the column exists in every row. --- examples/src/main/scala/spark/examples/CassandraTest.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 49b940d8a7..6b9fd502e2 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -10,6 +10,8 @@ import java.util.SortedMap import org.apache.cassandra.db.IColumn import org.apache.cassandra.utils.ByteBufferUtil import scala.collection.JavaConversions._ +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /* @@ -60,9 +62,9 @@ object CassandraTest { classOf[SortedMap[ByteBuffer, IColumn]]) // Let us first get all the paragraphs from the retrieved rows - val paraRdd = casRdd.flatMap { + val paraRdd = casRdd.map { case (key, value) => { - value.filter(v => ByteBufferUtil.string(v._1).compareTo("para") == 0).map(v => ByteBufferUtil.string(v._2.value())) + ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) } } -- cgit v1.2.3 From b104c7f5c7e2b173fe1b10035efbc00e43df13ec Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Mon, 3 Jun 2013 15:15:52 +0530 Subject: Example to write the output to cassandra --- .../main/scala/spark/examples/CassandraTest.scala | 48 +++++++++++++++++++--- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 6b9fd502e2..2cc62b9fe9 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -1,17 +1,16 @@ package spark.examples import org.apache.hadoop.mapreduce.Job -import org.apache.cassandra.hadoop.{ConfigHelper, ColumnFamilyInputFormat} -import org.apache.cassandra.thrift.{IndexExpression, SliceRange, SlicePredicate} +import org.apache.cassandra.hadoop.{ColumnFamilyOutputFormat, ConfigHelper, ColumnFamilyInputFormat} +import org.apache.cassandra.thrift._ import spark.{RDD, SparkContext} -import SparkContext._ +import spark.SparkContext._ import java.nio.ByteBuffer import java.util.SortedMap import org.apache.cassandra.db.IColumn import org.apache.cassandra.utils.ByteBufferUtil import scala.collection.JavaConversions._ -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; + /* @@ -44,8 +43,15 @@ object CassandraTest { ConfigHelper.setInputRpcPort(job.getConfiguration(), args(2)) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), args(1)) + + ConfigHelper.setOutputRpcPort(job.getConfiguration(), args(2)) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") + + val predicate = new SlicePredicate() val sliceRange = new SliceRange() sliceRange.setStart(Array.empty[Byte]) @@ -55,6 +61,8 @@ object CassandraTest { ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + //Make a new Hadoop RDD val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[ColumnFamilyInputFormat], @@ -74,6 +82,33 @@ object CassandraTest { counts.collect().foreach { case (word, count) => println(word + ":" + count) } + + counts.map { + case (word, count) => { + val colWord = new org.apache.cassandra.thrift.Column() + colWord.setName(ByteBufferUtil.bytes("word")) + colWord.setValue(ByteBufferUtil.bytes(word)) + colWord.setTimestamp(System.currentTimeMillis) + + val colCount = new org.apache.cassandra.thrift.Column() + colCount.setName(ByteBufferUtil.bytes("wcount")) + colCount.setValue(ByteBufferUtil.bytes(count.toLong)) + colCount.setTimestamp(System.currentTimeMillis) + + + val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) + + val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(0).column_or_supercolumn.setColumn(colWord) + + mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(1).column_or_supercolumn.setColumn(colCount) + (outputkey, mutations) + } + }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], + classOf[ColumnFamilyOutputFormat], job.getConfiguration) + } } @@ -81,6 +116,9 @@ object CassandraTest { create keyspace casDemo; use casDemo; +create column family WordCount with comparator = UTF8Type; +update column family WordCount with column_metadata = [{column_name: word, validation_class: UTF8Type}, {column_name: wcount, validation_class: LongType}]; + create column family Words with comparator = UTF8Type; update column family Words with column_metadata = [{column_name: book, validation_class: UTF8Type}, {column_name: para, validation_class: UTF8Type}]; -- cgit v1.2.3 From a058b0acf3e5ae41e64640feeace3d4e32f47401 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:10:00 -0700 Subject: Delete a file for a block if it already exists. --- core/src/main/scala/spark/storage/DiskStore.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c7281200e7..2be5d01e31 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -195,9 +195,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) + var file = getFile(blockId) if (!allowAppendExisting && file.exists()) { - throw new Exception("File for block " + blockId + " already exists on disk: " + file) + // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task + // was rescheduled on the same machine as the old task ? + logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") + file.delete() + // Reopen the file + file = getFile(blockId) + // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } -- cgit v1.2.3 From cd347f547a9a9b7bdd0d3f4734ae5c13be54f75d Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:27:51 -0700 Subject: Reuse the file object as it is valid after delete --- core/src/main/scala/spark/storage/DiskStore.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 2be5d01e31..e51d258a21 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -201,8 +201,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // was rescheduled on the same machine as the old task ? logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") file.delete() - // Reopen the file - file = getFile(blockId) // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file -- cgit v1.2.3 From 96943a1cc054d7cf80eb8d3dfc7fb19ce48d3c0a Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 3 Jun 2013 12:29:38 -0700 Subject: var to val --- core/src/main/scala/spark/storage/DiskStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index e51d258a21..cd85fa1e9d 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -195,7 +195,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - var file = getFile(blockId) + val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task // was rescheduled on the same machine as the old task ? -- cgit v1.2.3 From d1286231e0db15e480bd7d6a600b419db3391b27 Mon Sep 17 00:00:00 2001 From: Konstantin Boudnik Date: Wed, 29 May 2013 20:14:59 -0700 Subject: Sometime Maven build runs out of PermGen space. --- pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pom.xml b/pom.xml index 6ee64d07c2..ce77ba37c6 100644 --- a/pom.xml +++ b/pom.xml @@ -59,6 +59,9 @@ 1.6.1 4.1.2 1.2.17 + + 0m + 512m @@ -392,6 +395,10 @@ -Xms64m -Xmx1024m + -XX:PermSize + ${PermGen} + -XX:MaxPermSize + ${MaxPermGen} -source -- cgit v1.2.3 From 8bd4e1210422d9985e6105fd9529e813fe45c14e Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 4 Jun 2013 18:14:24 -0400 Subject: Bump akka and blockmanager timeouts to 60 seconds --- core/src/main/scala/spark/storage/BlockManager.scala | 2 +- core/src/main/scala/spark/util/AkkaUtils.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 09572b19db..150c98f57c 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -950,7 +950,7 @@ object BlockManager extends Logging { } def getHeartBeatFrequencyFromSystemProperties: Long = - System.getProperty("spark.storage.blockManagerHeartBeatMs", "5000").toLong + System.getProperty("spark.storage.blockManagerHeartBeatMs", "60000").toLong def getDisableHeartBeatsForTesting: Boolean = System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 9fb7e001ba..def993236b 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -29,7 +29,7 @@ private[spark] object AkkaUtils { def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt + val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. -- cgit v1.2.3 From 061fd3ae369e744f076e21044de26a00982a408f Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 4 Jun 2013 18:50:15 -0400 Subject: Fixing bug in BlockManager timeout --- core/src/main/scala/spark/storage/BlockManager.scala | 2 +- core/src/main/scala/spark/storage/BlockManagerMasterActor.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 150c98f57c..65c789ea8f 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -950,7 +950,7 @@ object BlockManager extends Logging { } def getHeartBeatFrequencyFromSystemProperties: Long = - System.getProperty("spark.storage.blockManagerHeartBeatMs", "60000").toLong + System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 def getDisableHeartBeatsForTesting: Boolean = System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 9b64f95df8..0dcb9fb2ac 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -35,7 +35,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", - "5000").toLong + "60000").toLong var timeoutCheckingTask: Cancellable = null -- cgit v1.2.3 From 9d359043574f6801ba15ec9d016eba0f00ac2349 Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Tue, 4 Jun 2013 22:12:47 -0700 Subject: In the current code, when both partitions happen to have zero-length, the return mean will be NaN. Consequently, the result of mean after reducing over all partitions will also be NaN, which is not correct if there are partitions with non-zero length. This patch fixes this issue. --- core/src/main/scala/spark/util/StatCounter.scala | 26 +++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala index 5f80180339..2b980340b7 100644 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ b/core/src/main/scala/spark/util/StatCounter.scala @@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (other == this) { merge(other.copy()) // Avoid overwriting fields in a weird order } else { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n } - m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) - n += other.n - this + this } } -- cgit v1.2.3 From c851957fe4798d5dfb8deba7bf79a035a0543c74 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 5 Jun 2013 14:28:38 -0700 Subject: Don't write zero block files with java serializer --- .../scala/spark/storage/BlockFetcherIterator.scala | 5 ++- core/src/main/scala/spark/storage/DiskStore.scala | 46 ++++++++++++++-------- .../scala/spark/storage/ShuffleBlockManager.scala | 2 +- core/src/test/scala/spark/ShuffleSuite.scala | 26 ++++++++++++ 4 files changed, 61 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index fac416a5b3..843069239c 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -71,6 +71,7 @@ object BlockFetcherIterator { logDebug("Getting " + _totalBlocks + " blocks") protected var startTime = System.currentTimeMillis protected val localBlockIds = new ArrayBuffer[String]() + protected val localNonZeroBlocks = new ArrayBuffer[String]() protected val remoteBlockIds = new HashSet[String]() // A queue to hold our results. @@ -129,6 +130,8 @@ object BlockFetcherIterator { for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { localBlockIds ++= blockInfos.map(_._1) + localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) } else { remoteBlockIds ++= blockInfos.map(_._1) // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them @@ -172,7 +175,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localBlockIds) { + for (id <- localNonZeroBlocks) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index cd85fa1e9d..c1cff25552 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private var bs: OutputStream = null private var objOut: SerializationStream = null private var lastValidPosition = 0L + private var initialized = false override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) objOut = serializer.newInstance().serializeStream(bs) + initialized = true this } override def close() { - objOut.close() - bs.close() - channel = null - bs = null - objOut = null + if (initialized) { + objOut.close() + bs.close() + channel = null + bs = null + objOut = null + } // Invoke the close callback handler. super.close() } @@ -59,23 +63,33 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } } override def revertPartialWrites() { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } } override def write(value: Any) { + if (!initialized) { + open() + } objOut.writeObject(value) } diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala index 49eabfb0d2..44638e0c2d 100644 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala @@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open() + blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) } new ShuffleWriterGroup(mapId, writers) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index b967016cf7..33b02fff80 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -367,6 +367,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. + // So, use Kryo + val c = new ShuffledRDD(b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + } object ShuffleSuite { -- cgit v1.2.3 From cb2f5046ee99582a5038a78478c23468b14c134e Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 5 Jun 2013 15:09:02 -0700 Subject: Pass in bufferSize to BufferedOutputStream --- core/src/main/scala/spark/storage/DiskStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c1cff25552..0af6e4a359 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -40,7 +40,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def open(): DiskBlockObjectWriter = { val fos = new FileOutputStream(f, true) channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos)) + bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true this -- cgit v1.2.3 From e179ff8a32fc08cc308dc99bac2527d350d0d970 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Wed, 5 Jun 2013 22:41:05 +0000 Subject: update according to comments --- core/src/main/scala/spark/RDD.scala | 89 +++++++++++++++++++++++---- core/src/main/scala/spark/rdd/PipedRDD.scala | 33 +++++----- core/src/test/scala/spark/PipedRDDSuite.scala | 7 ++- 3 files changed, 99 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 5a41db23c2..a1c9604324 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -16,6 +16,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import spark.broadcast.Broadcast import spark.Partitioner._ import spark.partial.BoundedDouble import spark.partial.CountEvaluator @@ -351,31 +352,93 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = - new PipedRDD(this, command, transform, arguments) + def pipe(command: String, env: Map[String, String]): RDD[String] = + new PipedRDD(this, command, env) /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: String, + env: Map[String, String], + transform: (T,String => Unit) => Any, + pipeContext: Broadcast[U], + delimiter: String): RDD[String] = + new PipedRDD(this, command, env, transform, pipeContext, delimiter) /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = - new PipedRDD(this, command, transform, arguments) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: String, + transform: (T,String => Unit) => Any, + pipeContext: Broadcast[U]): RDD[String] = + new PipedRDD(this, command, Map[String, String](), transform, pipeContext, "\u0001") /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = - new PipedRDD(this, command, env) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: String, + env: Map[String, String], + transform: (T,String => Unit) => Any, + pipeContext: Broadcast[U]): RDD[String] = + new PipedRDD(this, command, env, transform, pipeContext, "\u0001") /** * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: Seq[String], env: Map[String, String], transform: (T,String => Unit) => Any, arguments: Seq[String]): RDD[String] = - new PipedRDD(this, command, env, transform, arguments) + * How each record in RDD is outputed to the process can be controled by providing a + * function trasnform(T, outputFunction: String => Unit). transform() will be called with + * the currnet record in RDD as the 1st parameter, and the function to output the record to + * the external process (like out.println()) as the 2nd parameter. + * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the records: + * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} + * pipeContext can be used to transfer additional context data to the external process + * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to + * external process with "^A" as the delimiter in the end of context data. Delimiter can also + * be customized by the last parameter delimiter. + */ + def pipe[U<: Seq[String]]( + command: Seq[String], + env: Map[String, String] = Map(), + transform: (T,String => Unit) => Any = null, + pipeContext: Broadcast[U] = null, + delimiter: String = "\u0001"): RDD[String] = + new PipedRDD(this, command, env, transform, pipeContext, delimiter) /** * Return a new RDD by applying a function to each partition of this RDD. diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 969404c95f..d58aaae709 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -9,29 +9,33 @@ import scala.collection.mutable.ArrayBuffer import scala.io.Source import spark.{RDD, SparkEnv, Partition, TaskContext} +import spark.broadcast.Broadcast /** * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest]( +class PipedRDD[T: ClassManifest, U <: Seq[String]]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], transform: (T, String => Unit) => Any, - arguments: Seq[String] + pipeContext: Broadcast[U], + delimiter: String ) extends RDD[String](prev) { - def this(prev: RDD[T], command: Seq[String], envVars : Map[String, String]) = this(prev, command, envVars, null, null) - def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) - def this(prev: RDD[T], command: Seq[String], transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, command, Map(), transform, arguments) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - def this(prev: RDD[T], command: String, transform: (T,String => Unit) => Any, arguments: Seq[String]) = this(prev, PipedRDD.tokenize(command), Map(), transform, arguments) + def this( + prev: RDD[T], + command: String, + envVars: Map[String, String] = Map(), + transform: (T, String => Unit) => Any = null, + pipeContext: Broadcast[U] = null, + delimiter: String = "\u0001") = + this(prev, PipedRDD.tokenize(command), envVars, transform, pipeContext, delimiter) override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -60,19 +64,18 @@ class PipedRDD[T: ClassManifest]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - // input the arguments firstly - if ( arguments != null) { - for (elem <- arguments) { + // input the pipeContext firstly + if ( pipeContext != null) { + for (elem <- pipeContext.value) { out.println(elem) } - // ^A \n as the marker of the end of the arguments - out.println("\u0001") + // delimiter\n as the marker of the end of the pipeContext + out.println(delimiter) } for (elem <- firstParent[T].iterator(split, context)) { if (transform != null) { transform(elem, out.println(_)) - } - else { + } else { out.println(elem) } } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index ee55952a94..d2852867de 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -23,7 +23,8 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat"), (i:Int, f: String=> Unit) => f(i + "_"), Array("0")) + val piped = nums.pipe(Seq("cat"), Map[String, String](), + (i:Int, f: String=> Unit) => f(i + "_"), sc.broadcast(List("0"))) val c = piped.collect() @@ -38,7 +39,9 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { assert(c(7) === "4_") val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str=>str.split("\t")(0)).pipe(Seq("cat"), (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}, Array("0")).collect() + val d = nums1.groupBy(str=>str.split("\t")(0)). + pipe(Seq("cat"), Map[String, String](), (i:Tuple2[String, Seq[String]], f: String=> Unit) => + {for (e <- i._2){ f(e + "_")}}, sc.broadcast(List("0"))).collect() assert(d.size === 8) assert(d(0) === "0") assert(d(1) === "\u0001") -- cgit v1.2.3 From ac480fd977e0de97bcfe646e39feadbd239c1c29 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 6 Jun 2013 16:34:27 -0700 Subject: Clean up variables and counters in BlockFetcherIterator --- .../scala/spark/storage/BlockFetcherIterator.scala | 54 +++++++++++++--------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index 843069239c..bb78207c9f 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -67,12 +67,20 @@ object BlockFetcherIterator { throw new IllegalArgumentException("BlocksByAddress is null") } - protected var _totalBlocks = blocksByAddress.map(_._2.size).sum - logDebug("Getting " + _totalBlocks + " blocks") + // Total number blocks fetched (local + remote). Also number of FetchResults expected + protected var _numBlocksToFetch = 0 + protected var startTime = System.currentTimeMillis - protected val localBlockIds = new ArrayBuffer[String]() - protected val localNonZeroBlocks = new ArrayBuffer[String]() - protected val remoteBlockIds = new HashSet[String]() + + // This represents the number of local blocks, also counting zero-sized blocks + private var numLocal = 0 + // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks + protected val localBlocksToFetch = new ArrayBuffer[String]() + + // This represents the number of remote blocks, also counting zero-sized blocks + private var numRemote = 0 + // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks + protected val remoteBlocksToFetch = new HashSet[String]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -125,15 +133,15 @@ object BlockFetcherIterator { protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. - val originalTotalBlocks = _totalBlocks val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { if (address == blockManagerId) { - localBlockIds ++= blockInfos.map(_._1) - localNonZeroBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) - _totalBlocks -= (localBlockIds.size - localNonZeroBlocks.size) + numLocal = blockInfos.size + // Filter out zero-sized blocks + localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) + _numBlocksToFetch += localBlocksToFetch.size } else { - remoteBlockIds ++= blockInfos.map(_._1) + numRemote += blockInfos.size // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. @@ -147,10 +155,10 @@ object BlockFetcherIterator { // Skip empty blocks if (size > 0) { curBlocks += ((blockId, size)) + remoteBlocksToFetch += blockId + _numBlocksToFetch += 1 curRequestSize += size - } else if (size == 0) { - _totalBlocks -= 1 - } else { + } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } if (curRequestSize >= minRequestSize) { @@ -166,8 +174,8 @@ object BlockFetcherIterator { } } } - logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " + - originalTotalBlocks + " blocks") + logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + + totalBlocks + " blocks") remoteRequests } @@ -175,7 +183,7 @@ object BlockFetcherIterator { // Get the local blocks while remote blocks are being fetched. Note that it's okay to do // these all at once because they will just memory-map some files, so they won't consume // any memory that might exceed our maxBytesInFlight - for (id <- localNonZeroBlocks) { + for (id <- localBlocksToFetch) { getLocalFromDisk(id, serializer) match { case Some(iter) => { // Pass 0 as size since it's not in flight @@ -201,7 +209,7 @@ object BlockFetcherIterator { sendRequest(fetchRequests.dequeue()) } - val numGets = remoteBlockIds.size - fetchRequests.size + val numGets = remoteRequests.size - fetchRequests.size logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) // Get Local Blocks @@ -213,7 +221,7 @@ object BlockFetcherIterator { //an iterator that will read fetched blocks off the queue as they arrive. @volatile protected var resultsGotten = 0 - override def hasNext: Boolean = resultsGotten < _totalBlocks + override def hasNext: Boolean = resultsGotten < _numBlocksToFetch override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 @@ -230,9 +238,9 @@ object BlockFetcherIterator { } // Implementing BlockFetchTracker trait. - override def totalBlocks: Int = _totalBlocks - override def numLocalBlocks: Int = localBlockIds.size - override def numRemoteBlocks: Int = remoteBlockIds.size + override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote override def remoteFetchTime: Long = _remoteFetchTime override def fetchWaitTime: Long = _fetchWaitTime override def remoteBytesRead: Long = _remoteBytesRead @@ -294,7 +302,7 @@ object BlockFetcherIterator { private var copiers: List[_ <: Thread] = null override def initialize() { - // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks + // Split Local Remote Blocks and set numBlocksToFetch val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order for (request <- Utils.randomize(remoteRequests)) { @@ -316,7 +324,7 @@ object BlockFetcherIterator { val result = results.take() // if all the results has been retrieved, shutdown the copiers // NO need to stop the copiers if we got all the blocks ? - // if (resultsGotten == _totalBlocks && copiers != null) { + // if (resultsGotten == _numBlocksToFetch && copiers != null) { // stopCopiers() // } (result.blockId, if (result.failed) None else Some(result.deserialize())) -- cgit v1.2.3 From c9ca0a4a588b4c7dc553b155336ae5b95aa9ddd4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 7 Jun 2013 22:40:44 -0700 Subject: Small code style fix to SchedulingAlgorithm.scala --- .../src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index 13120edf63..e071917c00 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -53,13 +53,12 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { } if (compare < 0) { - res = true + return true } else if (compare > 0) { - res = false + return false } else { return s1.name < s2.name } - return res } } -- cgit v1.2.3 From b58a29295b2e610cadf1cac44438337ce9b51537 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 7 Jun 2013 22:51:28 -0700 Subject: Small formatting and style fixes --- .../spark/scheduler/cluster/SchedulingAlgorithm.scala | 8 ++++---- core/src/main/scala/spark/storage/StorageUtils.scala | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala index e071917c00..f33310a34a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala @@ -13,11 +13,11 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { val priority1 = s1.priority val priority2 = s2.priority - var res = Math.signum(priority1 - priority2) + var res = math.signum(priority1 - priority2) if (res == 0) { val stageId1 = s1.stageId val stageId2 = s2.stageId - res = Math.signum(stageId1 - stageId2) + res = math.signum(stageId1 - stageId2) } if (res < 0) { return true @@ -35,8 +35,8 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val runningTasks2 = s2.runningTasks val s1Needy = runningTasks1 < minShare1 val s2Needy = runningTasks2 < minShare2 - val minShareRatio1 = runningTasks1.toDouble / Math.max(minShare1, 1.0).toDouble - val minShareRatio2 = runningTasks2.toDouble / Math.max(minShare2, 1.0).toDouble + val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble + val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble var res:Boolean = true diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 81e607868d..950c0cdf35 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -55,21 +55,21 @@ object StorageUtils { }.mapValues(_.values.toArray) // For each RDD, generate an RDDInfo object - val rddInfos = groupedRddBlocks.map { case(rddKey, rddBlocks) => - + val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => // Add up memory and disk sizes val memSize = rddBlocks.map(_.memSize).reduce(_ + _) val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt - // Get the friendly name for the rdd, if available. + + // Get the friendly name and storage level for the RDD, if available sc.persistentRdds.get(rddId).map { r => - val rddName = Option(r.name).getOrElse(rddKey) - val rddStorageLevel = r.getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) + val rddName = Option(r.name).getOrElse(rddKey) + val rddStorageLevel = r.getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) } - }.flatMap(x => x).toArray + }.flatten.toArray scala.util.Sorting.quickSort(rddInfos) -- cgit v1.2.3 From 1a4d93c025e5d3679257a622f49dfaade4ac18c2 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Sat, 8 Jun 2013 14:23:39 +0800 Subject: modify to pass job annotation by localProperties and use daeamon thread to do joblogger's work --- .../scala/spark/BlockStoreShuffleFetcher.scala | 1 + core/src/main/scala/spark/RDD.scala | 10 +- core/src/main/scala/spark/SparkContext.scala | 8 +- core/src/main/scala/spark/Utils.scala | 10 +- core/src/main/scala/spark/executor/Executor.scala | 1 + .../main/scala/spark/executor/TaskMetrics.scala | 12 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 8 + .../src/main/scala/spark/scheduler/JobLogger.scala | 317 +++++++++++++++++++++ .../main/scala/spark/scheduler/SparkListener.scala | 33 ++- .../scala/spark/scheduler/JobLoggerSuite.scala | 105 +++++++ .../scala/spark/scheduler/SparkListenerSuite.scala | 2 +- 11 files changed, 495 insertions(+), 12 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/JobLogger.scala create mode 100644 core/src/test/scala/spark/scheduler/JobLoggerSuite.scala diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index e1fb02157a..3239f4c385 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin CompletionIterator[(K,V), Iterator[(K,V)]](itr, { val shuffleMetrics = new ShuffleReadMetrics + shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..8c0b7ca417 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -114,6 +114,14 @@ abstract class RDD[T: ClassManifest]( this } + /**User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo._4 + + /**reset generator*/ + def setGenerator(_generator: String) = { + generator = _generator + } + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. This can only be used to assign a new storage level if the RDD does not @@ -788,7 +796,7 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite + private[spark] val origin = Utils.formatSparkCallSite private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bc05d08fd6..b67a2066c8 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -48,7 +48,7 @@ import spark.scheduler.local.LocalScheduler import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, TimeStampedHashMap} - +import spark.scheduler.JobLogger /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -510,7 +510,7 @@ class SparkContext( def addSparkListener(listener: SparkListener) { dagScheduler.sparkListeners += listener } - + addSparkListener(new JobLogger) /** * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. @@ -630,7 +630,7 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) @@ -713,7 +713,7 @@ class SparkContext( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], timeout: Long): PartialResult[R] = { - val callSite = Utils.getSparkCallSite + val callSite = Utils.formatSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ec15326014..1630b2b4b0 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -528,7 +528,7 @@ private object Utils extends Logging { * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. */ - def getSparkCallSite: String = { + def getCallSiteInfo = { val trace = Thread.currentThread.getStackTrace().filter( el => (!el.getMethodName.contains("getStackTrace"))) @@ -540,6 +540,7 @@ private object Utils extends Logging { var firstUserFile = "" var firstUserLine = 0 var finished = false + var firstUserClass = "" for (el <- trace) { if (!finished) { @@ -554,13 +555,18 @@ private object Utils extends Logging { else { firstUserLine = el.getLineNumber firstUserFile = el.getFileName + firstUserClass = el.getClassName finished = true } } } - "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) + (lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } + def formatSparkCallSite = { + val callSiteInfo = getCallSiteInfo + "%s at %s:%s".format(callSiteInfo._1, callSiteInfo._2, callSiteInfo._3) + } /** * Try to find a free port to bind to on the local host. This should ideally never be needed, * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 890938d48b..8bebfafce4 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -104,6 +104,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() task.metrics.foreach{ m => + m.hostname = Utils.localHostName m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt } diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index a7c56c2371..26e8029365 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -1,6 +1,11 @@ package spark.executor class TaskMetrics extends Serializable { + /** + * host's name the task runs on + */ + var hostname: String = _ + /** * Time taken on the executor to deserialize this task */ @@ -33,10 +38,15 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { + /** + * Time when shuffle finishs + */ + var shuffleFinishTime: Long = _ + /** * Total number of blocks fetched in a shuffle (remote or local) */ - var totalBlocksFetched : Int = _ + var totalBlocksFetched: Int = _ /** * Number of remote blocks fetched in a shuffle diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7feeb97542..43dd7d6534 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -289,6 +289,7 @@ class DAGScheduler( val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() + sparkListeners.foreach(_.onJobStart(job, properties)) logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -311,6 +312,7 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => + sparkListeners.foreach(_.onTaskEnd(completion)) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -321,6 +323,8 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) + val JobCancelEvent = new SparkListenerJobCancelled("SPARKCONTEXT_SHUTDOWN") + sparkListeners.foreach(_.onJobEnd(job, JobCancelEvent)) } return true } @@ -468,6 +472,7 @@ class DAGScheduler( } } if (tasks.size > 0) { + sparkListeners.foreach(_.onStageSubmitted(stage, "TASKS_SIZE=" + tasks.size)) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) @@ -522,6 +527,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + sparkListeners.foreach(_.onJobEnd(job, SparkListenerJobSuccess)) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -665,6 +671,8 @@ class DAGScheduler( job.listener.jobFailed(new SparkException("Job failed: " + reason)) activeJobs -= job resultStageToJob -= resultStage + val jobFailedEvent = new SparkListenerJobFailed(failedStage) + sparkListeners.foreach(_.onJobEnd(job, jobFailedEvent)) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala new file mode 100644 index 0000000000..f87acfd0b6 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -0,0 +1,317 @@ +package spark.scheduler + +import java.io.PrintWriter +import java.io.File +import java.io.FileNotFoundException +import java.text.SimpleDateFormat +import java.util.{Date, Properties} +import java.util.concurrent.LinkedBlockingQueue +import scala.collection.mutable.{Map, HashMap, ListBuffer} +import scala.io.Source +import spark._ +import spark.executor.TaskMetrics +import spark.scheduler.cluster.TaskInfo + +// used to record runtime information for each job, including RDD graph +// tasks' start/stop shuffle information and information from outside + +sealed trait JobLoggerEvent +case class JobLoggerOnJobStart(job: ActiveJob, properties: Properties) extends JobLoggerEvent +case class JobLoggerOnStageSubmitted(stage: Stage, info: String) extends JobLoggerEvent +case class JobLoggerOnStageCompleted(stageCompleted: StageCompleted) extends JobLoggerEvent +case class JobLoggerOnJobEnd(job: ActiveJob, event: SparkListenerEvents) extends JobLoggerEvent +case class JobLoggerOnTaskEnd(event: CompletionEvent) extends JobLoggerEvent + +class JobLogger(val logDirName: String) extends SparkListener with Logging { + private val logDir = + if (System.getenv("SPARK_LOG_DIR") != null) + System.getenv("SPARK_LOG_DIR") + else + "/tmp/spark" + private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] + private val stageIDToJobID = new HashMap[Int, Int] + private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] + private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + private val eventQueue = new LinkedBlockingQueue[JobLoggerEvent] + + createLogDir() + def this() = this(String.valueOf(System.currentTimeMillis())) + + def getLogDir = logDir + def getJobIDtoPrintWriter = jobIDToPrintWriter + def getStageIDToJobID = stageIDToJobID + def getJobIDToStages = jobIDToStages + def getEventQueue = eventQueue + + new Thread("JobLogger") { + setDaemon(true) + override def run() { + while (true) { + val event = eventQueue.take + if (event != null) { + logDebug("Got event of type " + event.getClass.getName) + event match { + case JobLoggerOnJobStart(job, info) => + processJobStartEvent(job, info) + case JobLoggerOnStageSubmitted(stage, info) => + processStageSubmittedEvent(stage, info) + case JobLoggerOnStageCompleted(stageCompleted) => + processStageCompletedEvent(stageCompleted) + case JobLoggerOnJobEnd(job, event) => + processJobEndEvent(job, event) + case JobLoggerOnTaskEnd(event) => + processTaskEndEvent(event) + case _ => + } + } + } + } + }.start() + + //create a folder for log files, the folder's name is the creation time of the jobLogger + protected def createLogDir() { + val dir = new File(logDir + "/" + logDirName + "/") + if (dir.exists()) { + return + } + if (dir.mkdirs() == false) { + logError("create log directory error:" + logDir + "/" + logDirName + "/") + } + } + + // create a log file for one job, the file name is the jobID + protected def createLogWriter(jobID: Int) { + try{ + val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) + jobIDToPrintWriter += (jobID -> fileWriter) + } catch { + case e: FileNotFoundException => e.printStackTrace() + } + } + + // close log file for one job, and clean the stage relationship in stageIDToJobID + protected def closeLogWriter(jobID: Int) = + jobIDToPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + jobIDToStages.get(jobID).foreach(_.foreach{ stage => + stageIDToJobID -= stage.id + }) + jobIDToPrintWriter -= jobID + jobIDToStages -= jobID + } + + // write log information to log file, withTime parameter controls whether to recored + // time stamp for the information + protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { + var writeInfo = info + if (withTime) { + val date = new Date(System.currentTimeMillis()) + writeInfo = DATE_FORMAT.format(date) + ": " +info + } + jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo)) + } + + protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) = + stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime)) + + protected def buildJobDep(jobID: Int, stage: Stage) { + if (stage.priority == jobID) { + jobIDToStages.get(jobID) match { + case Some(stageList) => stageList += stage + case None => val stageList = new ListBuffer[Stage] + stageList += stage + jobIDToStages += (jobID -> stageList) + } + stageIDToJobID += (stage.id -> jobID) + stage.parents.foreach(buildJobDep(jobID, _)) + } + } + + protected def recordStageDep(jobID: Int) { + def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = { + var rddList = new ListBuffer[RDD[_]] + rddList += rdd + rdd.dependencies.foreach{ dep => dep match { + case shufDep: ShuffleDependency[_,_] => + case _ => rddList ++= getRddsInStage(dep.rdd) + } + } + rddList + } + jobIDToStages.get(jobID).foreach {_.foreach { stage => + var depRddDesc: String = "" + getRddsInStage(stage.rdd).foreach { rdd => + depRddDesc += rdd.id + "," + } + var depStageDesc: String = "" + stage.parents.foreach { stage => + depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")" + } + jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + + depRddDesc.substring(0, depRddDesc.length - 1) + ")" + + " STAGE_DEP=" + depStageDesc, false) + } + } + } + + // generate indents and convert to String + protected def indentString(indent: Int) = { + val sb = new StringBuilder() + for (i <- 1 to indent) { + sb.append(" ") + } + sb.toString() + } + + protected def getRddName(rdd: RDD[_]) = { + var rddName = rdd.getClass.getName + if (rdd.name != null) { + rddName = rdd.name + } + rddName + } + + protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { + val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")" + jobLogInfo(jobID, indentString(indent) + rddInfo, false) + rdd.dependencies.foreach{ dep => dep match { + case shufDep: ShuffleDependency[_,_] => + val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId + jobLogInfo(jobID, indentString(indent + 1) + depInfo, false) + case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1) + } + } + } + + protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) { + var stageInfo: String = "" + if (stage.isShuffleMap) { + stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + + stage.shuffleDep.get.shuffleId + }else{ + stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE" + } + if (stage.priority == jobID) { + jobLogInfo(jobID, indentString(indent) + stageInfo, false) + recordRddInStageGraph(jobID, stage.rdd, indent) + stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2)) + } else + jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false) + } + + // record task metrics into job log files + protected def recordTaskMetrics(stageID: Int, status: String, + taskInfo: TaskInfo, taskMetrics: TaskMetrics) { + val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + + " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname + val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val readMetrics = + taskMetrics.shuffleReadMetrics match { + case Some(metrics) => + " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + + " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + + " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + + " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + case None => "" + } + val writeMetrics = + taskMetrics.shuffleWriteMetrics match { + case Some(metrics) => + " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case None => "" + } + stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) + } + + override def onStageSubmitted(stage: Stage, info: String = "") { + eventQueue.put(JobLoggerOnStageSubmitted(stage, info)) + } + + protected def processStageSubmittedEvent(stage: Stage, info: String) { + stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED " + info) + } + + override def onStageCompleted(stageCompleted: StageCompleted) { + eventQueue.put(JobLoggerOnStageCompleted(stageCompleted)) + } + + protected def processStageCompletedEvent(stageCompleted: StageCompleted) { + stageLogInfo(stageCompleted.stageInfo.stage.id, "STAGE_ID=" + + stageCompleted.stageInfo.stage.id + " STATUS=COMPLETED") + + } + + override def onTaskEnd(event: CompletionEvent) { + eventQueue.put(JobLoggerOnTaskEnd(event)) + } + + protected def processTaskEndEvent(event: CompletionEvent) { + var taskStatus = "" + event.task match { + case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" + case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" + } + event.reason match { + case Success => taskStatus += " STATUS=SUCCESS" + recordTaskMetrics(event.task.stageId, taskStatus, event.taskInfo, event.taskMetrics) + case Resubmitted => + taskStatus += " STATUS=RESUBMITTED TID=" + event.taskInfo.taskId + + " STAGE_ID=" + event.task.stageId + stageLogInfo(event.task.stageId, taskStatus) + case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + taskStatus += " STATUS=FETCHFAILED TID=" + event.taskInfo.taskId + " STAGE_ID=" + + event.task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + + mapId + " REDUCE_ID=" + reduceId + stageLogInfo(event.task.stageId, taskStatus) + case OtherFailure(message) => + taskStatus += " STATUS=FAILURE TID=" + event.taskInfo.taskId + + " STAGE_ID=" + event.task.stageId + " INFO=" + message + stageLogInfo(event.task.stageId, taskStatus) + case _ => + } + } + + override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { + eventQueue.put(JobLoggerOnJobEnd(job, event)) + } + + protected def processJobEndEvent(job: ActiveJob, event: SparkListenerEvents) { + var info = "JOB_ID=" + job.runId + " STATUS=" + var validEvent = true + event match { + case SparkListenerJobSuccess => info += "SUCCESS" + case SparkListenerJobFailed(failedStage) => + info += "FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" + failedStage.id + case SparkListenerJobCancelled(reason) => info += "CANCELLED REASON=" + reason + case _ => validEvent = false + } + if (validEvent) { + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) + } + } + + protected def recordJobProperties(jobID: Int, properties: Properties) { + if(properties != null) { + val annotation = properties.getProperty("spark.job.annotation", "") + jobLogInfo(jobID, annotation, false) + } + } + + override def onJobStart(job: ActiveJob, properties: Properties = null) { + eventQueue.put(JobLoggerOnJobStart(job, properties)) + } + + protected def processJobStartEvent(job: ActiveJob, properties: Properties) { + createLogWriter(job.runId) + recordJobProperties(job.runId, properties) + buildJobDep(job.runId, job.finalStage) + recordStageDep(job.runId) + recordStageDepGraph(job.runId, job.finalStage) + jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED") + } +} diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index a65140b145..9cf7f3ffc0 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -1,27 +1,54 @@ package spark.scheduler +import java.util.Properties import spark.scheduler.cluster.TaskInfo import spark.util.Distribution -import spark.{Utils, Logging} +import spark.{Utils, Logging, SparkContext, TaskEndReason} import spark.executor.TaskMetrics trait SparkListener { /** * called when a stage is completed, with information on the completed stage */ - def onStageCompleted(stageCompleted: StageCompleted) + def onStageCompleted(stageCompleted: StageCompleted) { } + + /** + * called when a stage is submitted + */ + def onStageSubmitted(stage: Stage, info: String = "") { } + + /** + * called when a task ends + */ + def onTaskEnd(event: CompletionEvent) { } + + /** + * called when a job starts + */ + def onJobStart(job: ActiveJob, properties: Properties = null) { } + + /** + * called when a job ends + */ + def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { } + } sealed trait SparkListenerEvents case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents +case object SparkListenerJobSuccess extends SparkListenerEvents + +case class SparkListenerJobFailed(failedStage: Stage) extends SparkListenerEvents + +case class SparkListenerJobCancelled(reason: String) extends SparkListenerEvents /** * Simple SparkListener that logs a few summary statistics when each stage completes */ class StatsReportListener extends SparkListener with Logging { - def onStageCompleted(stageCompleted: StageCompleted) { + override def onStageCompleted(stageCompleted: StageCompleted) { import spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted this.logInfo("Finished stage: " + stageCompleted.stageInfo) diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala new file mode 100644 index 0000000000..34fd8b995e --- /dev/null +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -0,0 +1,105 @@ +package spark.scheduler + +import java.util.Properties +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable +import spark._ +import spark.SparkContext._ + + +class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("inner method") { + sc = new SparkContext("local", "joblogger") + val joblogger = new JobLogger { + def createLogWriterTest(jobID: Int) = createLogWriter(jobID) + def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) + def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) + def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) + } + type MyRDD = RDD[(Int, Int)] + def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]] + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + } + } + val jobID = 5 + val parentRdd = makeRdd(4, Nil) + val shuffleDep = new ShuffleDependency(parentRdd, null) + val rootRdd = makeRdd(4, List(shuffleDep)) + val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID) + val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) + + joblogger.onStageSubmitted(rootStage) + joblogger.getEventQueue.size should be (1) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + parentRdd.setName("MyRDD") + joblogger.getRddNameTest(parentRdd) should be ("MyRDD") + joblogger.createLogWriterTest(jobID) + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.buildJobDepTest(jobID, rootStage) + joblogger.getJobIDToStages.get(jobID).get.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) + joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) + joblogger.closeLogWriterTest(jobID) + joblogger.getStageIDToJobID.size should be (0) + joblogger.getJobIDToStages.size should be (0) + joblogger.getJobIDtoPrintWriter.size should be (0) + } + + test("inner variables") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + override protected def closeLogWriter(jobID: Int) = + getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + } + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.getLogDir should be ("/tmp/spark") + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.getStageIDToJobID.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(0)) + joblogger.getStageIDToJobID.get(1) should be (Some(0)) + joblogger.getJobIDToStages.size should be (1) + } + + + test("interface functions") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + var onTaskEndCount = 0 + var onJobEndCount = 0 + var onJobStartCount = 0 + var onStageCompletedCount = 0 + var onStageSubmittedCount = 0 + override def onTaskEnd(event: CompletionEvent) = onTaskEndCount += 1 + override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) = onJobEndCount += 1 + override def onJobStart(job: ActiveJob, properties: Properties) = onJobStartCount += 1 + override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageSubmitted(stage: Stage, info: String = "") = onStageSubmittedCount += 1 + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.onJobStartCount should be (1) + joblogger.onJobEndCount should be (1) + joblogger.onTaskEndCount should be (8) + joblogger.onStageSubmittedCount should be (2) + joblogger.onStageCompletedCount should be (2) + } +} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala index 42a87d8b90..48aa67c543 100644 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala @@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc class SaveStageInfo extends SparkListener { val stageInfos = mutable.Buffer[StageInfo]() - def onStageCompleted(stage: StageCompleted) { + override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stageInfo } } -- cgit v1.2.3 From 4fd86e0e10149ad1803831a308a056c7105cbe67 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Sat, 8 Jun 2013 15:45:47 +0800 Subject: delete test code for joblogger in SparkContext --- core/src/main/scala/spark/SparkContext.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b67a2066c8..70a9d7698c 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -48,7 +48,6 @@ import spark.scheduler.local.LocalScheduler import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo} import spark.util.{MetadataCleaner, TimeStampedHashMap} -import spark.scheduler.JobLogger /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -510,7 +509,7 @@ class SparkContext( def addSparkListener(listener: SparkListener) { dagScheduler.sparkListeners += listener } - addSparkListener(new JobLogger) + /** * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. -- cgit v1.2.3 From ade822011d44bd43e9ac78c1d29ec924a1f6e8e7 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Sat, 8 Jun 2013 16:26:45 +0800 Subject: not check return value of eventQueue.take --- .../src/main/scala/spark/scheduler/JobLogger.scala | 28 ++++++++++------------ 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index f87acfd0b6..46b9fa974b 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -48,21 +48,19 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { override def run() { while (true) { val event = eventQueue.take - if (event != null) { - logDebug("Got event of type " + event.getClass.getName) - event match { - case JobLoggerOnJobStart(job, info) => - processJobStartEvent(job, info) - case JobLoggerOnStageSubmitted(stage, info) => - processStageSubmittedEvent(stage, info) - case JobLoggerOnStageCompleted(stageCompleted) => - processStageCompletedEvent(stageCompleted) - case JobLoggerOnJobEnd(job, event) => - processJobEndEvent(job, event) - case JobLoggerOnTaskEnd(event) => - processTaskEndEvent(event) - case _ => - } + logDebug("Got event of type " + event.getClass.getName) + event match { + case JobLoggerOnJobStart(job, info) => + processJobStartEvent(job, info) + case JobLoggerOnStageSubmitted(stage, info) => + processStageSubmittedEvent(stage, info) + case JobLoggerOnStageCompleted(stageCompleted) => + processStageCompletedEvent(stageCompleted) + case JobLoggerOnJobEnd(job, event) => + processJobEndEvent(job, event) + case JobLoggerOnTaskEnd(event) => + processTaskEndEvent(event) + case _ => } } } -- cgit v1.2.3 From d1bbcebae580220076ceaa65f84dcf984ab51a16 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 8 Jun 2013 16:58:42 -0700 Subject: Adding compression to Hadoop save functions --- core/src/main/scala/spark/PairRDDFunctions.scala | 39 +++++++++++++++++- core/src/main/scala/spark/RDD.scala | 9 ++++ .../scala/spark/SequenceFileRDDFunctions.scala | 15 ++++--- core/src/test/scala/spark/FileSuite.scala | 48 ++++++++++++++++++++++ 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 2b0e697337..9bf1227d65 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -10,6 +10,8 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.FileOutputCommitter import org.apache.hadoop.mapred.FileOutputFormat import org.apache.hadoop.mapred.HadoopWriter @@ -515,6 +517,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress the result with the + * supplied codec. + */ + def saveAsHadoopFile[F <: OutputFormat[K, V]]( + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + } + /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. @@ -574,6 +586,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( jobCommitter.cleanupJob(jobTaskContext) } + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. Compress with the supplied codec. + */ + def saveAsHadoopFile( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[_ <: OutputFormat[_, _]], + codec: Class[_ <: CompressionCodec]) { + saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, + new JobConf(self.context.hadoopConfiguration), Some(codec)) + } + /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. @@ -583,11 +609,22 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration), + codec: Option[Class[_ <: CompressionCodec]] = None) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug conf.set("mapred.output.format.class", outputFormatClass.getName) + codec match { + case Some(c) => { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + } + case _ => + } conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) saveAsHadoopDataset(conf) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..e5995bea22 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -7,6 +7,7 @@ import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextOutputFormat @@ -730,6 +731,14 @@ abstract class RDD[T: ClassManifest]( .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) } + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { + this.map(x => (NullWritable.get(), new Text(x.toString))) + .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) + } + /** * Save this RDD as a SequenceFile of serialized objects. */ diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index 518034e07b..2911f9036e 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.OutputCommitter import org.apache.hadoop.mapred.FileOutputCommitter +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.Writable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.BytesWritable @@ -62,7 +63,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported * file system. */ - def saveAsSequenceFile(path: String) { + def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { def anyToWritable[U <% Writable](u: U): Writable = u val keyClass = getWritableClass[K] @@ -72,14 +73,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) val format = classOf[SequenceFileOutputFormat[Writable, Writable]] + val jobConf = new JobConf(self.context.hadoopConfiguration) if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format) + self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format) + self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + path, keyClass, valueClass, format, jobConf, codec) } } } diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 91b48c7456..a5d2028591 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -7,6 +7,8 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io._ +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} + import SparkContext._ @@ -26,6 +28,29 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) } + test("text files (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize("a" * 10000, 1) + data.saveAsTextFile(normalDir) + data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.textFile(normalDir).collect + assert(normalContent === Array.fill(10000)("a")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.textFile(compressedOutputDir).collect + assert(compressedContent === Array.fill(10000)("a")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFiles") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() @@ -37,6 +62,29 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } + test("SequenceFile (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) + data.saveAsSequenceFile(normalDir) + data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.sequenceFile[String, String](normalDir).collect + assert(normalContent === Array.fill(100)("abc", "abc")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect + assert(compressedContent === Array.fill(100)("abc", "abc")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFile with writable key") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() -- cgit v1.2.3 From 083a3485abdcda5913c2186c4a7930ac07b061c4 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Jun 2013 11:49:33 -0700 Subject: Clean extra whitespace --- core/src/test/scala/spark/FileSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index a5d2028591..e61ff7793d 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -50,7 +50,6 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(compressedFile.length < normalFile.length) } - test("SequenceFiles") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() @@ -84,7 +83,6 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(compressedFile.length < normalFile.length) } - test("SequenceFile with writable key") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() -- cgit v1.2.3 From df592192e736edca9e382a7f92e15bead390ef65 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Jun 2013 18:09:24 -0700 Subject: Monads FTW --- core/src/main/scala/spark/PairRDDFunctions.scala | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 9bf1227d65..15593db0d9 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -615,15 +615,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug conf.set("mapred.output.format.class", outputFormatClass.getName) - codec match { - case Some(c) => { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) - } - case _ => + for (c <- codec) { + conf.setCompressMapOutput(true) + conf.set("mapred.output.compress", "true") + conf.setMapOutputCompressorClass(c) + conf.set("mapred.output.compression.codec", c.getCanonicalName) + conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) } conf.setOutputCommitter(classOf[FileOutputCommitter]) FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) -- cgit v1.2.3 From ef14dc2e7736732932d4edceb3be8d81ba9f8bc7 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 9 Jun 2013 18:09:46 -0700 Subject: Adding Java-API version of compression codec --- .../main/scala/spark/api/java/JavaPairRDD.scala | 11 ++++++ .../main/scala/spark/api/java/JavaRDDLike.scala | 8 ++++ core/src/test/scala/spark/JavaAPISuite.java | 46 ++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 30084df4e2..76051597b6 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -6,6 +6,7 @@ import java.util.Comparator import scala.Tuple2 import scala.collection.JavaConversions._ +import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} @@ -459,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) } + /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */ + def saveAsHadoopFile[F <: OutputFormat[_, _]]( + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + codec: Class[_ <: CompressionCodec]) { + rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) + } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( path: String, diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 9b74d1226f..76b14e2e04 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -4,6 +4,7 @@ import java.util.{List => JList} import scala.Tuple2 import scala.collection.JavaConversions._ +import org.apache.hadoop.io.compress.CompressionCodec import spark.{SparkContext, Partition, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} @@ -310,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) + + /** + * Save this RDD as a compressed text file, using string representations of elements. + */ + def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) = + rdd.saveAsTextFile(path, codec) + /** * Save this RDD as a SequenceFile of serialized objects. */ diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 93bb69b41c..6caa85119a 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -8,6 +8,7 @@ import java.util.*; import scala.Tuple2; import com.google.common.base.Charsets; +import org.apache.hadoop.io.compress.DefaultCodec; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -473,6 +474,19 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, readRDD.collect()); } + @Test + public void textFilesCompressed() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir, DefaultCodec.class); + + // Try reading it in as a text file RDD + List expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + @Test public void sequenceFile() { File tempDir = Files.createTempDir(); @@ -619,6 +633,38 @@ public class JavaAPISuite implements Serializable { }).collect().toString()); } + @Test + public void hadoopFileCompressed() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, + DefaultCodec.class); + + System.out.println(outputDir); + JavaPairRDD output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + @Test public void zip() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); -- cgit v1.2.3 From 190ec617997d621c11ed1aab662a6e3a06815d2f Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Mon, 10 Jun 2013 15:27:02 +0800 Subject: change code style and debug info --- core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 8 +++----- .../main/scala/spark/scheduler/local/LocalTaskSetManager.scala | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 69dacfc2bd..93d4318b29 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -34,8 +34,7 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I } def launchTask(tasks : Seq[TaskDescription]) { - for (task <- tasks) - { + for (task <- tasks) { freeCores -= 1 localScheduler.threadPool.submit(new Runnable { def run() { @@ -85,8 +84,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } schedulableBuilder.buildPools() - localActor = env.actorSystem.actorOf( - Props(new LocalActor(this, threads)), "Test") + localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") } override def setListener(listener: TaskSchedulerListener) { @@ -109,7 +107,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val tasks = new ArrayBuffer[TaskDescription](freeCores) val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() for (manager <- sortedTaskSetQueue) { - logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) + logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks)) } var launchTask = false diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala index f2e07d162a..70b69bb26f 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -91,7 +91,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas if (availableCpus > 0 && numFinished < numTasks) { findTask() match { case Some(index) => - logInfo(taskSet.tasks(index).toString) val taskId = sched.attemptId.getAndIncrement() val task = taskSet.tasks(index) val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) -- cgit v1.2.3 From fd6148c8b20bc051786ff574d3b8f3b5e79b391a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 10 Jun 2013 10:27:25 -0700 Subject: Removing print statement --- core/src/test/scala/spark/JavaAPISuite.java | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 6caa85119a..d306124fca 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -652,7 +652,6 @@ public class JavaAPISuite implements Serializable { }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); - System.out.println(outputDir); JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); -- cgit v1.2.3 From dc4073654b1707f115de30088938f6e53efda0ba Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 11 Jun 2013 00:08:02 -0400 Subject: Revert "Fix start-slave not passing instance number to spark-daemon." This reverts commit a674d67c0aebb940e3b816e2307206115baec175. --- bin/start-slave.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/start-slave.sh b/bin/start-slave.sh index dfcbc6981b..26b5b9d462 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -12,4 +12,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.worker.Worker 1 "$@" +"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@" -- cgit v1.2.3 From db5bca08ff00565732946a9c0a0244a9f7021d82 Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Wed, 12 Jun 2013 10:54:16 -0700 Subject: add a new top K method to RDD using a bounded priority queue --- core/src/main/scala/spark/RDD.scala | 24 +++++++++++ .../scala/spark/util/BoundedPriorityQueue.scala | 48 ++++++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 19 +++++++++ 3 files changed, 91 insertions(+) create mode 100644 core/src/main/scala/spark/util/BoundedPriorityQueue.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e6c0438d76..ec5e5e2433 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -35,6 +35,7 @@ import spark.rdd.ZippedPartitionsRDD2 import spark.rdd.ZippedPartitionsRDD3 import spark.rdd.ZippedPartitionsRDD4 import spark.storage.StorageLevel +import spark.util.BoundedPriorityQueue import SparkContext._ @@ -722,6 +723,29 @@ abstract class RDD[T: ClassManifest]( case _ => throw new UnsupportedOperationException("empty collection") } + /** + * Returns the top K elements from this RDD as defined by + * the specified implicit Ordering[T]. + * @param num the number of top elements to return + * @param ord the implicit ordering for T + * @return an array of top elements + */ + def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { + val topK = mapPartitions { items => + val queue = new BoundedPriorityQueue[T](num) + queue ++= items + Iterator(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + } + + val builder = Array.newBuilder[T] + builder.sizeHint(topK.size) + builder ++= topK + builder.result() + } + /** * Save this RDD as a text file, using string representations of elements. */ diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..53ee95a02e --- /dev/null +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,48 @@ +package spark.util + +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable + +/** + * Bounded priority queue. This class modifies the original PriorityQueue's + * add/offer methods such that only the top K elements are retained. The top + * K elements are defined by an implicit Ordering[A]. + */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A]) + extends JPriorityQueue[A](maxSize, ord) with Growable[A] { + + override def offer(a: A): Boolean = { + if (size < maxSize) super.offer(a) + else maybeReplaceLowest(a) + } + + override def add(a: A): Boolean = offer(a) + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach(add) + this + } + + override def +=(elem: A): this.type = { + add(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = peek() + if (head != null && ord.gt(a, head)) { + poll() + super.offer(a) + } else false + } +} + +object BoundedPriorityQueue { + import scala.collection.JavaConverters._ + implicit def asIterable[A](queue: BoundedPriorityQueue[A]): Iterable[A] = queue.asScala +} + diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 3f69e99780..67f3332d44 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -317,4 +317,23 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(sample.size === checkSample.size) for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) } + + test("top with predefined ordering") { + sc = new SparkContext("local", "test") + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK.sorted === nums.sorted.takeRight(5)) + } + + test("top with custom ordering") { + sc = new SparkContext("local", "test") + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } } -- cgit v1.2.3 From 3f96c6f27b08039fb7b8d295f5de2083544e979f Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 12 Jun 2013 17:20:05 -0700 Subject: Fixed jvmArgs in maven build. --- pom.xml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index ce77ba37c6..c893ec755e 100644 --- a/pom.xml +++ b/pom.xml @@ -60,7 +60,7 @@ 4.1.2 1.2.17 - 0m + 64m 512m @@ -395,10 +395,8 @@ -Xms64m -Xmx1024m - -XX:PermSize - ${PermGen} - -XX:MaxPermSize - ${MaxPermGen} + -XX:PermSize=${PermGen} + -XX:MaxPermSize=${MaxPermGen} -source -- cgit v1.2.3 From 967a6a699da7da007f51e59d085a357da5ec14da Mon Sep 17 00:00:00 2001 From: Mingfei Date: Thu, 13 Jun 2013 14:36:07 +0800 Subject: modify sparklister function interface according to comments --- .../main/scala/spark/scheduler/DAGScheduler.scala | 15 ++-- .../src/main/scala/spark/scheduler/JobLogger.scala | 89 +++++++++++----------- .../main/scala/spark/scheduler/SparkListener.scala | 38 +++++---- .../scala/spark/scheduler/JobLoggerSuite.scala | 10 +-- 4 files changed, 79 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 43dd7d6534..e281e5a8db 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -289,7 +289,7 @@ class DAGScheduler( val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - sparkListeners.foreach(_.onJobStart(job, properties)) + sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -312,7 +312,7 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => - sparkListeners.foreach(_.onTaskEnd(completion)) + sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion))) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -323,8 +323,8 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - val JobCancelEvent = new SparkListenerJobCancelled("SPARKCONTEXT_SHUTDOWN") - sparkListeners.foreach(_.onJobEnd(job, JobCancelEvent)) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobCancelled(job, + "SPARKCONTEXT_SHUTDOWN"))) } return true } @@ -472,7 +472,7 @@ class DAGScheduler( } } if (tasks.size > 0) { - sparkListeners.foreach(_.onStageSubmitted(stage, "TASKS_SIZE=" + tasks.size)) + sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size))) logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) @@ -527,7 +527,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - sparkListeners.foreach(_.onJobEnd(job, SparkListenerJobSuccess)) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobSuccess(job))) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -671,8 +671,7 @@ class DAGScheduler( job.listener.jobFailed(new SparkException("Job failed: " + reason)) activeJobs -= job resultStageToJob -= resultStage - val jobFailedEvent = new SparkListenerJobFailed(failedStage) - sparkListeners.foreach(_.onJobEnd(job, jobFailedEvent)) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobFailed(job, failedStage))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index 46b9fa974b..002c5826cb 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -15,13 +15,6 @@ import spark.scheduler.cluster.TaskInfo // used to record runtime information for each job, including RDD graph // tasks' start/stop shuffle information and information from outside -sealed trait JobLoggerEvent -case class JobLoggerOnJobStart(job: ActiveJob, properties: Properties) extends JobLoggerEvent -case class JobLoggerOnStageSubmitted(stage: Stage, info: String) extends JobLoggerEvent -case class JobLoggerOnStageCompleted(stageCompleted: StageCompleted) extends JobLoggerEvent -case class JobLoggerOnJobEnd(job: ActiveJob, event: SparkListenerEvents) extends JobLoggerEvent -case class JobLoggerOnTaskEnd(event: CompletionEvent) extends JobLoggerEvent - class JobLogger(val logDirName: String) extends SparkListener with Logging { private val logDir = if (System.getenv("SPARK_LOG_DIR") != null) @@ -32,7 +25,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { private val stageIDToJobID = new HashMap[Int, Int] private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - private val eventQueue = new LinkedBlockingQueue[JobLoggerEvent] + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] createLogDir() def this() = this(String.valueOf(System.currentTimeMillis())) @@ -50,15 +43,19 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { val event = eventQueue.take logDebug("Got event of type " + event.getClass.getName) event match { - case JobLoggerOnJobStart(job, info) => - processJobStartEvent(job, info) - case JobLoggerOnStageSubmitted(stage, info) => - processStageSubmittedEvent(stage, info) - case JobLoggerOnStageCompleted(stageCompleted) => - processStageCompletedEvent(stageCompleted) - case JobLoggerOnJobEnd(job, event) => - processJobEndEvent(job, event) - case JobLoggerOnTaskEnd(event) => + case SparkListenerJobStart(job, properties) => + processJobStartEvent(job, properties) + case SparkListenerStageSubmitted(stage, taskSize) => + processStageSubmittedEvent(stage, taskSize) + case StageCompleted(stageInfo) => + processStageCompletedEvent(stageInfo) + case SparkListenerJobSuccess(job) => + processJobEndEvent(job) + case SparkListenerJobFailed(job, failedStage) => + processJobEndEvent(job, failedStage) + case SparkListenerJobCancelled(job, reason) => + processJobEndEvent(job, reason) + case SparkListenerTaskEnd(event) => processTaskEndEvent(event) case _ => } @@ -225,26 +222,26 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) } - override def onStageSubmitted(stage: Stage, info: String = "") { - eventQueue.put(JobLoggerOnStageSubmitted(stage, info)) + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + eventQueue.put(stageSubmitted) } - protected def processStageSubmittedEvent(stage: Stage, info: String) { - stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED " + info) + protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) { + stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize) } override def onStageCompleted(stageCompleted: StageCompleted) { - eventQueue.put(JobLoggerOnStageCompleted(stageCompleted)) + eventQueue.put(stageCompleted) } - protected def processStageCompletedEvent(stageCompleted: StageCompleted) { - stageLogInfo(stageCompleted.stageInfo.stage.id, "STAGE_ID=" + - stageCompleted.stageInfo.stage.id + " STATUS=COMPLETED") + protected def processStageCompletedEvent(stageInfo: StageInfo) { + stageLogInfo(stageInfo.stage.id, "STAGE_ID=" + + stageInfo.stage.id + " STATUS=COMPLETED") } - override def onTaskEnd(event: CompletionEvent) { - eventQueue.put(JobLoggerOnTaskEnd(event)) + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + eventQueue.put(taskEnd) } protected def processTaskEndEvent(event: CompletionEvent) { @@ -273,24 +270,26 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { - eventQueue.put(JobLoggerOnJobEnd(job, event)) + override def onJobEnd(jobEnd: SparkListenerEvents) { + eventQueue.put(jobEnd) } - protected def processJobEndEvent(job: ActiveJob, event: SparkListenerEvents) { - var info = "JOB_ID=" + job.runId + " STATUS=" - var validEvent = true - event match { - case SparkListenerJobSuccess => info += "SUCCESS" - case SparkListenerJobFailed(failedStage) => - info += "FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" + failedStage.id - case SparkListenerJobCancelled(reason) => info += "CANCELLED REASON=" + reason - case _ => validEvent = false - } - if (validEvent) { - jobLogInfo(job.runId, info) - closeLogWriter(job.runId) - } + protected def processJobEndEvent(job: ActiveJob) { + val info = "JOB_ID=" + job.runId + " STATUS=SUCCESS" + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) + } + + protected def processJobEndEvent(job: ActiveJob, failedStage: Stage) { + val info = "JOB_ID=" + job.runId + " STATUS=FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" + + failedStage.id + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) + } + protected def processJobEndEvent(job: ActiveJob, reason: String) { + var info = "JOB_ID=" + job.runId + " STATUS=CANCELLED REASON=" + reason + jobLogInfo(job.runId, info) + closeLogWriter(job.runId) } protected def recordJobProperties(jobID: Int, properties: Properties) { @@ -300,8 +299,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - override def onJobStart(job: ActiveJob, properties: Properties = null) { - eventQueue.put(JobLoggerOnJobStart(job, properties)) + override def onJobStart(jobStart: SparkListenerJobStart) { + eventQueue.put(jobStart) } protected def processJobStartEvent(job: ActiveJob, properties: Properties) { diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index 9cf7f3ffc0..9265261dc1 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -6,6 +6,24 @@ import spark.util.Distribution import spark.{Utils, Logging, SparkContext, TaskEndReason} import spark.executor.TaskMetrics + +sealed trait SparkListenerEvents + +case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents + +case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents + +case class SparkListenerTaskEnd(event: CompletionEvent) extends SparkListenerEvents + +case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) + extends SparkListenerEvents + +case class SparkListenerJobSuccess(job: ActiveJob) extends SparkListenerEvents + +case class SparkListenerJobFailed(job: ActiveJob, failedStage: Stage) extends SparkListenerEvents + +case class SparkListenerJobCancelled(job: ActiveJob, reason: String) extends SparkListenerEvents + trait SparkListener { /** * called when a stage is completed, with information on the completed stage @@ -15,35 +33,25 @@ trait SparkListener { /** * called when a stage is submitted */ - def onStageSubmitted(stage: Stage, info: String = "") { } - + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + /** * called when a task ends */ - def onTaskEnd(event: CompletionEvent) { } + def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } /** * called when a job starts */ - def onJobStart(job: ActiveJob, properties: Properties = null) { } + def onJobStart(jobStart: SparkListenerJobStart) { } /** * called when a job ends */ - def onJobEnd(job: ActiveJob, event: SparkListenerEvents) { } + def onJobEnd(jobEnd: SparkListenerEvents) { } } -sealed trait SparkListenerEvents - -case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents - -case object SparkListenerJobSuccess extends SparkListenerEvents - -case class SparkListenerJobFailed(failedStage: Stage) extends SparkListenerEvents - -case class SparkListenerJobCancelled(reason: String) extends SparkListenerEvents - /** * Simple SparkListener that logs a few summary statistics when each stage completes */ diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala index 34fd8b995e..a654bf3ffd 100644 --- a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -40,7 +40,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID) val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) - joblogger.onStageSubmitted(rootStage) + joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4)) joblogger.getEventQueue.size should be (1) joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) parentRdd.setName("MyRDD") @@ -86,11 +86,11 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers var onJobStartCount = 0 var onStageCompletedCount = 0 var onStageSubmittedCount = 0 - override def onTaskEnd(event: CompletionEvent) = onTaskEndCount += 1 - override def onJobEnd(job: ActiveJob, event: SparkListenerEvents) = onJobEndCount += 1 - override def onJobStart(job: ActiveJob, properties: Properties) = onJobStartCount += 1 + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerEvents) = onJobEndCount += 1 + override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 - override def onStageSubmitted(stage: Stage, info: String = "") = onStageSubmittedCount += 1 + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 } sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } -- cgit v1.2.3 From b5b12823faf62766d880e497c90b44b21f5a433a Mon Sep 17 00:00:00 2001 From: Rohit Rai Date: Thu, 13 Jun 2013 14:05:46 +0530 Subject: Fixing the style as per feedback --- .../main/scala/spark/examples/CassandraTest.scala | 72 +++++++++++----------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala index 2cc62b9fe9..0fe1833e83 100644 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/spark/examples/CassandraTest.scala @@ -1,9 +1,11 @@ package spark.examples import org.apache.hadoop.mapreduce.Job -import org.apache.cassandra.hadoop.{ColumnFamilyOutputFormat, ConfigHelper, ColumnFamilyInputFormat} +import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat import org.apache.cassandra.thrift._ -import spark.{RDD, SparkContext} +import spark.SparkContext import spark.SparkContext._ import java.nio.ByteBuffer import java.util.SortedMap @@ -12,9 +14,9 @@ import org.apache.cassandra.utils.ByteBufferUtil import scala.collection.JavaConversions._ - /* - * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra support for Hadoop. + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra + * support for Hadoop. * * To run this example, run this file with the following command params - * @@ -26,32 +28,31 @@ import scala.collection.JavaConversions._ * 1. You have already created a keyspace called casDemo and it has a column family named Words * 2. There are column family has a column named "para" which has test content. * - * You can create the content by running the following script at the bottom of this file with cassandra-cli. + * You can create the content by running the following script at the bottom of this file with + * cassandra-cli. * */ object CassandraTest { + def main(args: Array[String]) { - //Get a SparkContext + // Get a SparkContext val sc = new SparkContext(args(0), "casDemo") - //Build the job configuration with ConfigHelper provided by Cassandra + // Build the job configuration with ConfigHelper provided by Cassandra val job = new Job() job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) - ConfigHelper.setInputInitialAddress(job.getConfiguration(), args(1)) - - ConfigHelper.setInputRpcPort(job.getConfiguration(), args(2)) - - ConfigHelper.setOutputInitialAddress(job.getConfiguration(), args(1)) - - ConfigHelper.setOutputRpcPort(job.getConfiguration(), args(2)) + val host: String = args(1) + val port: String = args(2) + ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setInputRpcPort(job.getConfiguration(), port) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") - ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") - val predicate = new SlicePredicate() val sliceRange = new SliceRange() sliceRange.setStart(Array.empty[Byte]) @@ -60,11 +61,11 @@ object CassandraTest { ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - //Make a new Hadoop RDD - val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), + // Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD( + job.getConfiguration(), classOf[ColumnFamilyInputFormat], classOf[ByteBuffer], classOf[SortedMap[ByteBuffer, IColumn]]) @@ -76,7 +77,7 @@ object CassandraTest { } } - //Lets get the word count in paras + // Lets get the word count in paras val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) counts.collect().foreach { @@ -95,20 +96,17 @@ object CassandraTest { colCount.setValue(ByteBufferUtil.bytes(count.toLong)) colCount.setTimestamp(System.currentTimeMillis) - val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) - mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(1).column_or_supercolumn.setColumn(colCount) (outputkey, mutations) } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) - } } @@ -117,16 +115,20 @@ create keyspace casDemo; use casDemo; create column family WordCount with comparator = UTF8Type; -update column family WordCount with column_metadata = [{column_name: word, validation_class: UTF8Type}, {column_name: wcount, validation_class: LongType}]; +update column family WordCount with column_metadata = + [{column_name: word, validation_class: UTF8Type}, + {column_name: wcount, validation_class: LongType}]; create column family Words with comparator = UTF8Type; -update column family Words with column_metadata = [{column_name: book, validation_class: UTF8Type}, {column_name: para, validation_class: UTF8Type}]; +update column family Words with column_metadata = + [{column_name: book, validation_class: UTF8Type}, + {column_name: para, validation_class: UTF8Type}]; assume Words keys as utf8; set Words['3musk001']['book'] = 'The Three Musketeers'; -set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market town of - Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market + town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to be in as perfect a state of revolution as if the Huguenots had just made a second La Rochelle of it. Many citizens, seeing the women flying toward the High Street, leaving their children crying at the open doors, @@ -136,8 +138,8 @@ set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625 every minute, a compact group, vociferous and full of curiosity.'; set Words['3musk002']['book'] = 'The Three Musketeers'; -set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without some city - or other registering in its archives an event of this kind. There were +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without + some city or other registering in its archives an event of this kind. There were nobles, who made war against each other; there was the king, who made war against the cardinal; there was Spain, which made war against the king. Then, in addition to these concealed or public, secret or open @@ -152,8 +154,8 @@ set Words['3musk002']['para'] = 'In those times panics were common, and few days cause of the hubbub was apparent to all'; set Words['3musk003']['book'] = 'The Three Musketeers'; -set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however large - the sum may be; but you ought also to endeavor to perfect yourself in +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however + large the sum may be; but you ought also to endeavor to perfect yourself in the exercises becoming a gentleman. I will write a letter today to the Director of the Royal Academy, and tomorrow he will admit you without any expense to yourself. Do not refuse this little service. Our @@ -165,8 +167,8 @@ set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means yo set Words['thelostworld001']['book'] = 'The Lost World'; -set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined against the - red curtain. How beautiful she was! And yet how aloof! We had been +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined + against the red curtain. How beautiful she was! And yet how aloof! We had been friends, quite good friends; but never could I get beyond the same comradeship which I might have established with one of my fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, @@ -180,8 +182,8 @@ set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profil as that--or had inherited it in that race memory which we call instinct.'; set Words['thelostworld002']['book'] = 'The Lost World'; -set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, red-headed news - editor, and I rather hoped that he liked me. Of course, Beaumont was +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, + red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was the real boss; but he lived in the rarefied atmosphere of some Olympian height from which he could distinguish nothing smaller than an international crisis or a split in the Cabinet. Sometimes we saw him -- cgit v1.2.3 From 1d9f0df0652f455145d2dfed43a9407df6de6c43 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 13 Jun 2013 14:46:25 -0700 Subject: Fix some comments and style --- core/src/main/java/spark/network/netty/FileClient.java | 2 +- core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 8 ++++---- core/src/main/scala/spark/storage/BlockFetcherIterator.scala | 6 +----- core/src/main/scala/spark/storage/DiskStore.scala | 3 +-- core/src/test/scala/spark/ShuffleSuite.scala | 3 +-- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java index 517772202f..a4bb4bc701 100644 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ b/core/src/main/java/spark/network/netty/FileClient.java @@ -30,7 +30,7 @@ class FileClient { .channel(OioSocketChannel.class) .option(ChannelOption.SO_KEEPALIVE, true) .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) // Disable connect timeout + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) .handler(new FileClientChannelInitializer(handler)); } diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index afb2cdbb3a..8d5194a737 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -18,8 +18,9 @@ private[spark] class ShuffleCopier extends Logging { resultCollectCallback: (String, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val fc = new FileClient(handler, - System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt) + val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt + val fc = new FileClient(handler, connectTimeout) + try { fc.init() fc.connect(host, port) @@ -29,8 +30,7 @@ private[spark] class ShuffleCopier extends Logging { } catch { // Handle any socket-related exceptions in FileClient case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + - " failed", e) + logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) handler.handleError(blockId) } } diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala index bb78207c9f..bec876213e 100644 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala @@ -322,11 +322,7 @@ object BlockFetcherIterator { override def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() - // if all the results has been retrieved, shutdown the copiers - // NO need to stop the copiers if we got all the blocks ? - // if (resultsGotten == _numBlocksToFetch && copiers != null) { - // stopCopiers() - // } + // If all the results has been retrieved, copiers will exit automatically (result.blockId, if (result.failed) None else Some(result.deserialize())) } } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 0af6e4a359..15ab840155 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -212,10 +212,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) if (!allowAppendExisting && file.exists()) { // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task ? + // was rescheduled on the same machine as the old task. logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") file.delete() - // throw new Exception("File for block " + blockId + " already exists on disk: " + file) } file } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 33b02fff80..1916885a73 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -376,8 +376,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val a = sc.parallelize(1 to 4, NUM_BLOCKS) val b = a.map(x => (x, x*2)) - // NOTE: The default Java serializer doesn't create zero-sized blocks. - // So, use Kryo + // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD(b, new HashPartitioner(10)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId -- cgit v1.2.3 From 44b8dbaedeb88f12ea911968c524883805f7ad95 Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Thu, 13 Jun 2013 16:23:15 -0700 Subject: use Iterator.single(elem) instead of Iterator(elem) for improved performance based on scaladocs --- core/src/main/scala/spark/RDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index ec5e5e2433..bc9c17d507 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -734,7 +734,7 @@ abstract class RDD[T: ClassManifest]( val topK = mapPartitions { items => val queue = new BoundedPriorityQueue[T](num) queue ++= items - Iterator(queue) + Iterator.single(queue) }.reduce { (queue1, queue2) => queue1 ++= queue2 queue1 -- cgit v1.2.3 From 93b3f5e535c509a017a433b72249fc49c79d4a0f Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Thu, 13 Jun 2013 16:26:35 -0700 Subject: drop unneeded ClassManifest implicit --- core/src/main/scala/spark/util/BoundedPriorityQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala index 53ee95a02e..ef01beaea5 100644 --- a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -8,7 +8,7 @@ import scala.collection.generic.Growable * add/offer methods such that only the top K elements are retained. The top * K elements are defined by an implicit Ordering[A]. */ -class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A], mf: ClassManifest[A]) +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) extends JPriorityQueue[A](maxSize, ord) with Growable[A] { override def offer(a: A): Boolean = { -- cgit v1.2.3 From 6738178d0daf1bbe7441db7c0c773a29bb2ec388 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 13 Jun 2013 23:59:42 -0700 Subject: SPARK-772: groupByKey should disable map side combine. --- core/src/main/scala/spark/PairRDDFunctions.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 15593db0d9..fa4bbfc76f 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -19,7 +19,7 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil} import spark.partial.BoundedDouble import spark.partial.PartialResult @@ -187,11 +187,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * partitioning of the resulting key-value pair RDD by passing a Partitioner. */ def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + // groupByKey shouldn't use map side combine because map side combine does not + // reduce the amount of data shuffled and requires all map side data be inserted + // into a hash table, leading to more objects in the old gen. def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner) + createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } -- cgit v1.2.3 From 2cc188fd546fa061812f9fd4f72cf936bd01a0e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 14 Jun 2013 00:10:54 -0700 Subject: SPARK-774: cogroup should also disable map side combine by default --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 7599ba1a02..8966f9f86e 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext} +import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -49,12 +49,16 @@ private[spark] class CoGroupAggregator * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output. - * @param mapSideCombine flag indicating whether to merge values before shuffle step. + * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag + * is on, Spark does an extra pass over the data on the map side to merge + * all values belonging to the same key together. This can reduce the amount + * of data shuffled if and only if the number of distinct keys is very small, + * and the ratio of key size to value size is also very small. */ class CoGroupedRDD[K]( @transient var rdds: Seq[RDD[(K, _)]], part: Partitioner, - val mapSideCombine: Boolean = true, + val mapSideCombine: Boolean = false, val serializerClass: String = null) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { -- cgit v1.2.3 From 53add598f2fe09759a0df1e08f87f70503f808c5 Mon Sep 17 00:00:00 2001 From: Andrew xia Date: Sat, 15 Jun 2013 01:34:17 +0800 Subject: Update LocalSchedulerSuite to avoid using sleep for task launch --- .../spark/scheduler/LocalSchedulerSuite.scala | 83 +++++++++++++++------- 1 file changed, 59 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala index 37d14ed113..8bd813fd14 100644 --- a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala @@ -9,9 +9,7 @@ import spark.scheduler.cluster._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.{ConcurrentMap, HashMap} import java.util.concurrent.Semaphore -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger - +import java.util.concurrent.CountDownLatch import java.util.Properties class Lock() { @@ -35,9 +33,19 @@ class Lock() { object TaskThreadInfo { val threadToLock = HashMap[Int, Lock]() val threadToRunning = HashMap[Int, Boolean]() + val threadToStarted = HashMap[Int, CountDownLatch]() } - +/* + * 1. each thread contains one job. + * 2. each job contains one stage. + * 3. each stage only contains one task. + * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure + * it will get cpu core resource, and will wait to finished after user manually + * release "Lock" and then cluster will contain another free cpu cores. + * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, + * thus it will be scheduled later when cluster has free cpu cores. + */ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { @@ -45,22 +53,23 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { TaskThreadInfo.threadToRunning(threadIndex) = false val nums = sc.parallelize(threadIndex to threadIndex, 1) TaskThreadInfo.threadToLock(threadIndex) = new Lock() + TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) new Thread { - if (poolName != null) { - sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToLock(number).jobWait() - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - TaskThreadInfo.threadToRunning(threadIndex) = false - } + if (poolName != null) { + sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName) + } + override def run() { + val ans = nums.map(number => { + TaskThreadInfo.threadToRunning(number) = true + TaskThreadInfo.threadToStarted(number).countDown() + TaskThreadInfo.threadToLock(number).jobWait() + TaskThreadInfo.threadToRunning(number) = false + number + }).collect() + assert(ans.toList === List(threadIndex)) + sem.release() + } }.start() - Thread.sleep(2000) } test("Local FIFO scheduler end-to-end test") { @@ -69,11 +78,24 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { val sem = new Semaphore(0) createThread(1,null,sc,sem) + TaskThreadInfo.threadToStarted(1).await() createThread(2,null,sc,sem) + TaskThreadInfo.threadToStarted(2).await() createThread(3,null,sc,sem) + TaskThreadInfo.threadToStarted(3).await() createThread(4,null,sc,sem) + TaskThreadInfo.threadToStarted(4).await() + // thread 5 and 6 (stage pending)must meet following two points + // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager + // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() + // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 + // So I just use "sleep" 1s here for each thread. + // TODO: any better solution? createThread(5,null,sc,sem) + Thread.sleep(1000) createThread(6,null,sc,sem) + Thread.sleep(1000) + assert(TaskThreadInfo.threadToRunning(1) === true) assert(TaskThreadInfo.threadToRunning(2) === true) assert(TaskThreadInfo.threadToRunning(3) === true) @@ -82,8 +104,8 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { assert(TaskThreadInfo.threadToRunning(6) === false) TaskThreadInfo.threadToLock(1).jobFinished() - Thread.sleep(1000) - + TaskThreadInfo.threadToStarted(5).await() + assert(TaskThreadInfo.threadToRunning(1) === false) assert(TaskThreadInfo.threadToRunning(2) === true) assert(TaskThreadInfo.threadToRunning(3) === true) @@ -92,7 +114,7 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { assert(TaskThreadInfo.threadToRunning(6) === false) TaskThreadInfo.threadToLock(3).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(6).await() assert(TaskThreadInfo.threadToRunning(1) === false) assert(TaskThreadInfo.threadToRunning(2) === true) @@ -116,23 +138,31 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { System.setProperty("spark.fairscheduler.allocation.file", xmlPath) createThread(10,"1",sc,sem) + TaskThreadInfo.threadToStarted(10).await() createThread(20,"2",sc,sem) + TaskThreadInfo.threadToStarted(20).await() createThread(30,"3",sc,sem) + TaskThreadInfo.threadToStarted(30).await() assert(TaskThreadInfo.threadToRunning(10) === true) assert(TaskThreadInfo.threadToRunning(20) === true) assert(TaskThreadInfo.threadToRunning(30) === true) createThread(11,"1",sc,sem) + TaskThreadInfo.threadToStarted(11).await() createThread(21,"2",sc,sem) + TaskThreadInfo.threadToStarted(21).await() createThread(31,"3",sc,sem) + TaskThreadInfo.threadToStarted(31).await() assert(TaskThreadInfo.threadToRunning(11) === true) assert(TaskThreadInfo.threadToRunning(21) === true) assert(TaskThreadInfo.threadToRunning(31) === true) createThread(12,"1",sc,sem) + TaskThreadInfo.threadToStarted(12).await() createThread(22,"2",sc,sem) + TaskThreadInfo.threadToStarted(22).await() createThread(32,"3",sc,sem) assert(TaskThreadInfo.threadToRunning(12) === true) @@ -140,20 +170,25 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { assert(TaskThreadInfo.threadToRunning(32) === false) TaskThreadInfo.threadToLock(10).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(32).await() + assert(TaskThreadInfo.threadToRunning(32) === true) + //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager + // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. + //2. priority of 23 and 33 will be meaningless as using fair scheduler here. createThread(23,"2",sc,sem) createThread(33,"3",sc,sem) + Thread.sleep(1000) TaskThreadInfo.threadToLock(11).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(23).await() assert(TaskThreadInfo.threadToRunning(23) === true) assert(TaskThreadInfo.threadToRunning(33) === false) TaskThreadInfo.threadToLock(12).jobFinished() - Thread.sleep(1000) + TaskThreadInfo.threadToStarted(33).await() assert(TaskThreadInfo.threadToRunning(33) === true) -- cgit v1.2.3 From e8801d44900153eae6412963d2f3e2f19bfdc4e9 Mon Sep 17 00:00:00 2001 From: ryanlecompte Date: Fri, 14 Jun 2013 23:39:05 -0700 Subject: use delegation for BoundedPriorityQueue, add Java API --- core/src/main/scala/spark/RDD.scala | 9 ++---- core/src/main/scala/spark/api/java/JavaRDD.scala | 1 - .../main/scala/spark/api/java/JavaRDDLike.scala | 27 ++++++++++++++++- .../scala/spark/util/BoundedPriorityQueue.scala | 35 ++++++++++------------ 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index bc9c17d507..4a4616c843 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -731,19 +731,14 @@ abstract class RDD[T: ClassManifest]( * @return an array of top elements */ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { - val topK = mapPartitions { items => + mapPartitions { items => val queue = new BoundedPriorityQueue[T](num) queue ++= items Iterator.single(queue) }.reduce { (queue1, queue2) => queue1 ++= queue2 queue1 - } - - val builder = Array.newBuilder[T] - builder.sizeHint(topK.size) - builder ++= topK - builder.result() + }.toArray } /** diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index eb81ed64cd..626b499454 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] { */ def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = wrapRDD(rdd.subtract(other, p)) - } object JavaRDD { diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 9b74d1226f..3e9c779d7b 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -1,6 +1,6 @@ package spark.api.java -import java.util.{List => JList} +import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ @@ -351,4 +351,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def toDebugString(): String = { rdd.toDebugString } + + /** + * Returns the top K elements from this RDD as defined by + * the specified Comparator[T]. + * @param num the number of top elements to return + * @param comp the comparator that defines the order + * @return an array of top elements + */ + def top(num: Int, comp: Comparator[T]): JList[T] = { + import scala.collection.JavaConversions._ + val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) + val arr: java.util.Collection[T] = topElems.toSeq + new java.util.ArrayList(arr) + } + + /** + * Returns the top K elements from this RDD using the + * natural ordering for T. + * @param num the number of top elements to return + * @return an array of top elements + */ + def top(num: Int): JList[T] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]] + top(num, comp) + } } diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala index ef01beaea5..4bc5db8bb7 100644 --- a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala @@ -1,30 +1,30 @@ package spark.util +import java.io.Serializable import java.util.{PriorityQueue => JPriorityQueue} import scala.collection.generic.Growable +import scala.collection.JavaConverters._ /** - * Bounded priority queue. This class modifies the original PriorityQueue's - * add/offer methods such that only the top K elements are retained. The top - * K elements are defined by an implicit Ordering[A]. + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. */ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) - extends JPriorityQueue[A](maxSize, ord) with Growable[A] { + extends Iterable[A] with Growable[A] with Serializable { - override def offer(a: A): Boolean = { - if (size < maxSize) super.offer(a) - else maybeReplaceLowest(a) - } + private val underlying = new JPriorityQueue[A](maxSize, ord) - override def add(a: A): Boolean = offer(a) + override def iterator: Iterator[A] = underlying.iterator.asScala override def ++=(xs: TraversableOnce[A]): this.type = { - xs.foreach(add) + xs.foreach { this += _ } this } override def +=(elem: A): this.type = { - add(elem) + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) this } @@ -32,17 +32,14 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) this += elem1 += elem2 ++= elems } + override def clear() { underlying.clear() } + private def maybeReplaceLowest(a: A): Boolean = { - val head = peek() + val head = underlying.peek() if (head != null && ord.gt(a, head)) { - poll() - super.offer(a) + underlying.poll() + underlying.offer(a) } else false } } -object BoundedPriorityQueue { - import scala.collection.JavaConverters._ - implicit def asIterable[A](queue: BoundedPriorityQueue[A]): Iterable[A] = queue.asScala -} - -- cgit v1.2.3 From 479442a9b913b08a64da4bd5848111d950105336 Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Sat, 15 Jun 2013 17:35:55 -0700 Subject: Add zeroLengthPartitions() test to make sure, e.g., StatCounter.scala can handle empty partitions without incorrectly returning NaN --- core/src/test/scala/spark/JavaAPISuite.java | 22 ++++++++++++++++++++++ project/plugins.sbt | 2 ++ 2 files changed, 24 insertions(+) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 93bb69b41c..3190a43e73 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -314,6 +314,28 @@ public class JavaAPISuite implements Serializable { List take = rdd.take(5); } + @Test + public void zeroLengthPartitions() { + // Create RDD with some consecutive empty partitions (including the "first" one) + JavaDoubleRDD rdd = sc + .parallelizeDoubles(Arrays.asList(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(new Function() { + @Override + public Boolean call(Double x) { + return x > 0.0; + } + }); + + // Run the partitions, including the consecutive empty ones, through StatCounter + StatCounter stats = rdd.stats(); + Assert.assertEquals(6.0, stats.sum(), 0.01); + Assert.assertEquals(6.0/2, rdd.mean(), 0.01); + Assert.assertEquals(1.0, rdd.variance(), 0.01); + Assert.assertEquals(1.0, rdd.stdev(), 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } + @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/project/plugins.sbt b/project/plugins.sbt index d4f2442872..25b812a28d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") + +libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M4" % "test" -- cgit v1.2.3 From 5c886194e458c64fcf24066af351bde47dd8bf12 Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Sun, 16 Jun 2013 01:23:48 -0700 Subject: Move zero-length partition testing from JavaAPISuite.java to PartitioningSuite.scala --- core/src/test/scala/spark/JavaAPISuite.java | 22 ---------------------- core/src/test/scala/spark/PartitioningSuite.scala | 21 +++++++++++++++++++-- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 3190a43e73..93bb69b41c 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -314,28 +314,6 @@ public class JavaAPISuite implements Serializable { List take = rdd.take(5); } - @Test - public void zeroLengthPartitions() { - // Create RDD with some consecutive empty partitions (including the "first" one) - JavaDoubleRDD rdd = sc - .parallelizeDoubles(Arrays.asList(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) - .filter(new Function() { - @Override - public Boolean call(Double x) { - return x > 0.0; - } - }); - - // Run the partitions, including the consecutive empty ones, through StatCounter - StatCounter stats = rdd.stats(); - Assert.assertEquals(6.0, stats.sum(), 0.01); - Assert.assertEquals(6.0/2, rdd.mean(), 0.01); - Assert.assertEquals(1.0, rdd.variance(), 0.01); - Assert.assertEquals(1.0, rdd.stdev(), 0.01); - - // Add other tests here for classes that should be able to handle empty partitions correctly - } - @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 60db759c25..e5745c81b3 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,10 +1,10 @@ package spark import org.scalatest.FunSuite - import scala.collection.mutable.ArrayBuffer - import SparkContext._ +import spark.util.StatCounter +import scala.math._ class PartitioningSuite extends FunSuite with LocalSparkContext { @@ -120,4 +120,21 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) } + + test("Zero-length partitions should be correctly handled") { + // Create RDD with some consecutive empty partitions (including the "first" one) + sc = new SparkContext("local", "test") + val rdd: RDD[Double] = sc + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(_ >= 0.0) + + // Run the partitions, including the consecutive empty ones, through StatCounter + val stats: StatCounter = rdd.stats(); + assert(abs(6.0 - stats.sum) < 0.01); + assert(abs(6.0/2 - rdd.mean) < 0.01); + assert(abs(1.0 - rdd.variance) < 0.01); + assert(abs(1.0 - rdd.stdev) < 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } } -- cgit v1.2.3 From f91195cc150a3ead122046d14bd35b4fcf28c9cb Mon Sep 17 00:00:00 2001 From: Christopher Nguyen Date: Sun, 16 Jun 2013 01:29:53 -0700 Subject: Import just scala.math.abs rather than scala.math._ --- core/src/test/scala/spark/PartitioningSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index e5745c81b3..16f93e71a3 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -4,7 +4,7 @@ import org.scalatest.FunSuite import scala.collection.mutable.ArrayBuffer import SparkContext._ import spark.util.StatCounter -import scala.math._ +import scala.math.abs class PartitioningSuite extends FunSuite with LocalSparkContext { -- cgit v1.2.3 From fb6d733fa88aa124deecf155af40cc095ecca5b3 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Sun, 16 Jun 2013 22:32:55 +0000 Subject: update according to comments --- core/src/main/scala/spark/RDD.scala | 71 ++------------------------- core/src/main/scala/spark/rdd/PipedRDD.scala | 29 +++++------ core/src/test/scala/spark/PipedRDDSuite.scala | 13 +++-- 3 files changed, 24 insertions(+), 89 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index a1c9604324..152f7be9bb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -355,68 +355,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** - * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. - */ - def pipe[U<: Seq[String]]( - command: String, - env: Map[String, String], - transform: (T,String => Unit) => Any, - pipeContext: Broadcast[U], - delimiter: String): RDD[String] = - new PipedRDD(this, command, env, transform, pipeContext, delimiter) - - /** - * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. - */ - def pipe[U<: Seq[String]]( - command: String, - transform: (T,String => Unit) => Any, - pipeContext: Broadcast[U]): RDD[String] = - new PipedRDD(this, command, Map[String, String](), transform, pipeContext, "\u0001") - - /** - * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. - */ - def pipe[U<: Seq[String]]( - command: String, - env: Map[String, String], - transform: (T,String => Unit) => Any, - pipeContext: Broadcast[U]): RDD[String] = - new PipedRDD(this, command, env, transform, pipeContext, "\u0001") /** * Return an RDD created by piping elements to a forked external process. @@ -432,13 +370,12 @@ abstract class RDD[T: ClassManifest]( * external process with "^A" as the delimiter in the end of context data. Delimiter can also * be customized by the last parameter delimiter. */ - def pipe[U<: Seq[String]]( + def pipe( command: Seq[String], env: Map[String, String] = Map(), - transform: (T,String => Unit) => Any = null, - pipeContext: Broadcast[U] = null, - delimiter: String = "\u0001"): RDD[String] = - new PipedRDD(this, command, env, transform, pipeContext, delimiter) + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, printRDDElement) /** * Return a new RDD by applying a function to each partition of this RDD. diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d58aaae709..b2c07891ab 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -16,14 +16,12 @@ import spark.broadcast.Broadcast * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest, U <: Seq[String]]( +class PipedRDD[T: ClassManifest]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], - transform: (T, String => Unit) => Any, - pipeContext: Broadcast[U], - delimiter: String - ) + printPipeContext: (String => Unit) => Unit, + printRDDElement: (T, String => Unit) => Unit) extends RDD[String](prev) { // Similar to Runtime.exec(), if we are given a single string, split it into words @@ -32,10 +30,9 @@ class PipedRDD[T: ClassManifest, U <: Seq[String]]( prev: RDD[T], command: String, envVars: Map[String, String] = Map(), - transform: (T, String => Unit) => Any = null, - pipeContext: Broadcast[U] = null, - delimiter: String = "\u0001") = - this(prev, PipedRDD.tokenize(command), envVars, transform, pipeContext, delimiter) + printPipeContext: (String => Unit) => Unit = null, + printRDDElement: (T, String => Unit) => Unit = null) = + this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) override def getPartitions: Array[Partition] = firstParent[T].partitions @@ -64,17 +61,13 @@ class PipedRDD[T: ClassManifest, U <: Seq[String]]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - // input the pipeContext firstly - if ( pipeContext != null) { - for (elem <- pipeContext.value) { - out.println(elem) - } - // delimiter\n as the marker of the end of the pipeContext - out.println(delimiter) + // input the pipe context firstly + if ( printPipeContext != null) { + printPipeContext(out.println(_)) } for (elem <- firstParent[T].iterator(split, context)) { - if (transform != null) { - transform(elem, out.println(_)) + if (printRDDElement != null) { + printRDDElement(elem, out.println(_)) } else { out.println(elem) } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index d2852867de..ed075f93ec 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -22,9 +22,12 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { test("advanced pipe") { sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) - val piped = nums.pipe(Seq("cat"), Map[String, String](), - (i:Int, f: String=> Unit) => f(i + "_"), sc.broadcast(List("0"))) + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Int, f: String=> Unit) => f(i + "_")) val c = piped.collect() @@ -40,8 +43,10 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) val d = nums1.groupBy(str=>str.split("\t")(0)). - pipe(Seq("cat"), Map[String, String](), (i:Tuple2[String, Seq[String]], f: String=> Unit) => - {for (e <- i._2){ f(e + "_")}}, sc.broadcast(List("0"))).collect() + pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() assert(d.size === 8) assert(d(0) === "0") assert(d(1) === "\u0001") -- cgit v1.2.3 From 4508089fc342802a2f37fea6893cd47abd81fdd7 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Mon, 17 Jun 2013 05:23:46 +0000 Subject: refine comments and add sc.clean --- core/src/main/scala/spark/RDD.scala | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 05ff399a7b..223dcdc19d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -361,24 +361,30 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD created by piping elements to a forked external process. - * How each record in RDD is outputed to the process can be controled by providing a - * function trasnform(T, outputFunction: String => Unit). transform() will be called with - * the currnet record in RDD as the 1st parameter, and the function to output the record to - * the external process (like out.println()) as the 2nd parameter. - * Here's an example on how to pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the records: - * def tranform(record:(String, Seq[String]), f:String=>Unit) = for (e <- record._2){f(e)} - * pipeContext can be used to transfer additional context data to the external process - * besides the RDD. pipeContext is a broadcast Seq[String], each line would be piped to - * external process with "^A" as the delimiter in the end of context data. Delimiter can also - * be customized by the last parameter delimiter. + * The print behavior can be customized by providing two functions. + * + * @param command command to run in forked process. + * @param env environment variables to set. + * @param printPipeContext Before piping elements, this function is called as an oppotunity + * to pipe context data. Print line function (like out.println) will be + * passed as printPipeContext's parameter. + * @param printPipeContext Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. + * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} + * @return the result RDD */ def pipe( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = - new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, printRDDElement) + new PipedRDD(this, command, env, + if (printPipeContext ne null) sc.clean(printPipeContext) else null, + if (printRDDElement ne null) sc.clean(printRDDElement) else null) /** * Return a new RDD by applying a function to each partition of this RDD. -- cgit v1.2.3 From 1450296797e53f1a01166c885050091df9c96e2e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 17 Jun 2013 16:58:23 -0400 Subject: SPARK-781: Log the temp directory path when Spark says "Failed to create temp directory". --- core/src/main/scala/spark/Utils.scala | 4 +-- core/src/main/scala/spark/storage/DiskStore.scala | 34 +++++++++++------------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ec15326014..fd7b8cc8d5 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -116,8 +116,8 @@ private object Utils extends Logging { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory after " + maxAttempts + - " attempts!") + throw new IOException("Failed to create a temp directory under (" + root + ") after " + + maxAttempts + " attempts!") } try { dir = new File(root, "spark-" + UUID.randomUUID.toString) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index c7281200e7..9914beec99 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -82,15 +82,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def size(): Long = lastValidPosition } - val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - var shuffleSender : ShuffleSender = null + private var shuffleSender : ShuffleSender = null // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid // having really large inodes at the top level. - val localDirs = createLocalDirs() - val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private val localDirs: Array[File] = createLocalDirs() + private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) addShutdownHook() @@ -99,7 +99,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) new DiskBlockObjectWriter(blockId, serializer, bufferSize) } - override def getSize(blockId: String): Long = { getFile(blockId).length() } @@ -232,8 +231,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map(rootDir => { - var foundLocalDir: Boolean = false + rootDirs.split(",").map { rootDir => + var foundLocalDir = false var localDir: File = null var localDirId: String = null var tries = 0 @@ -248,7 +247,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } } catch { case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) } } if (!foundLocalDir) { @@ -258,7 +257,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } logInfo("Created local directory at " + localDir) localDir - }) + } } private def addShutdownHook() { @@ -266,15 +265,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") - try { - localDirs.foreach { localDir => + localDirs.foreach { localDir => + try { if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) } - if (shuffleSender != null) { - shuffleSender.stop - } - } catch { - case t: Throwable => logError("Exception while deleting local spark dirs", t) + } + if (shuffleSender != null) { + shuffleSender.stop } } }) -- cgit v1.2.3 From be3c406edf06d5ab9da98097c28ce3eebc958b8e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 17 Jun 2013 17:07:51 -0400 Subject: Fixed the typo pointed out by Matei. --- core/src/main/scala/spark/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index fd7b8cc8d5..645c18541e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -116,7 +116,7 @@ private object Utils extends Logging { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory under (" + root + ") after " + + throw new IOException("Failed to create a temp directory (under " + root + ") after " + maxAttempts + " attempts!") } try { -- cgit v1.2.3 From 2ab311f4cee3f918dc28daaebd287b11c9f63429 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 18 Jun 2013 00:40:25 +0200 Subject: Removed second version of junit test plugin from plugins.sbt --- project/plugins.sbt | 2 -- 1 file changed, 2 deletions(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 25b812a28d..d4f2442872 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,5 +16,3 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") - -libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M4" % "test" -- cgit v1.2.3 From 1e9269c3eeeaa3a481b95521c703032ed84abd68 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 13 Jun 2013 10:46:22 +0800 Subject: reduce ZippedPartitionsRDD's getPreferredLocations complexity --- core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index dd9f3c2680..b234428ab2 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -53,14 +53,10 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y)) // Remove exact match and then do host local match. - val otherNodePreferredLocations = rddSplitZip.map(x => { - x._1.preferredLocations(x._2).map(hostPort => { - val host = Utils.parseHostPort(hostPort)._1 - - if (exactMatchLocations.contains(host)) null else host - }).filter(_ != null) - }) - val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y)) + val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1) + val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1)) + .reduce((x, y) => x.intersect(y)) + val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) } otherNodeLocalLocations ++ exactMatchLocations } -- cgit v1.2.3 From 0a2a9bce1e83e891334985c29176c6426b8b1751 Mon Sep 17 00:00:00 2001 From: Gavin Li Date: Tue, 18 Jun 2013 21:30:13 +0000 Subject: fix typo and coding style --- core/src/main/scala/spark/RDD.scala | 14 +++++++------- core/src/main/scala/spark/rdd/PipedRDD.scala | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 223dcdc19d..709271d4eb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -368,13 +368,13 @@ abstract class RDD[T: ClassManifest]( * @param printPipeContext Before piping elements, this function is called as an oppotunity * to pipe context data. Print line function (like out.println) will be * passed as printPipeContext's parameter. - * @param printPipeContext Use this function to customize how to pipe elements. This function - * will be called with each RDD element as the 1st parameter, and the - * print line function (like out.println()) as the 2nd parameter. - * An example of pipe the RDD data of groupBy() in a streaming way, - * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} + * @param printRDDElement Use this function to customize how to pipe elements. This function + * will be called with each RDD element as the 1st parameter, and the + * print line function (like out.println()) as the 2nd parameter. + * An example of pipe the RDD data of groupBy() in a streaming way, + * instead of constructing a huge String to concat all the elements: + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} * @return the result RDD */ def pipe( diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index b2c07891ab..c0baf43d43 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -62,7 +62,7 @@ class PipedRDD[T: ClassManifest]( val out = new PrintWriter(proc.getOutputStream) // input the pipe context firstly - if ( printPipeContext != null) { + if (printPipeContext != null) { printPipeContext(out.println(_)) } for (elem <- firstParent[T].iterator(split, context)) { -- cgit v1.2.3 From 7902baddc797f86f5bdbcc966f5cd60545638bf7 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 19 Jun 2013 13:34:30 +0200 Subject: Update ASM to version 4.0 --- core/pom.xml | 4 ++-- core/src/main/scala/spark/ClosureCleaner.scala | 11 +++++------ pom.xml | 6 +++--- project/SparkBuild.scala | 2 +- repl/src/main/scala/spark/repl/ExecutorClassLoader.scala | 3 +-- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index d8687bf991..88f0ed70f3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -32,8 +32,8 @@ compress-lzf - asm - asm-all + org.ow2.asm + asm com.google.protobuf diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 50d6a1c5c9..d5e7132ff9 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -5,8 +5,7 @@ import java.lang.reflect.Field import scala.collection.mutable.Map import scala.collection.mutable.Set -import org.objectweb.asm.{ClassReader, MethodVisitor, Type} -import org.objectweb.asm.commons.EmptyVisitor +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import org.objectweb.asm.Opcodes._ import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} @@ -162,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging { } } -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { @@ -188,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten } } -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null override def visit(version: Int, access: Int, name: String, sig: String, @@ -198,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new EmptyVisitor { + return new MethodVisitor(ASM4) { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { val argTypes = Type.getArgumentTypes(desc) diff --git a/pom.xml b/pom.xml index c893ec755e..3bcb2a3f34 100644 --- a/pom.xml +++ b/pom.xml @@ -190,9 +190,9 @@ 0.8.4 - asm - asm-all - 3.3.1 + org.ow2.asm + asm + 4.0 com.google.protobuf diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 824af821f9..b1f3f9a2ea 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -148,7 +148,7 @@ object SparkBuild extends Build { "org.slf4j" % "slf4j-log4j12" % slf4jVersion, "commons-daemon" % "commons-daemon" % "1.0.10", "com.ning" % "compress-lzf" % "0.8.4", - "asm" % "asm-all" % "3.3.1", + "org.ow2.asm" % "asm" % "4.0", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", "com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty), diff --git a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala index 13d81ec1cf..0e9aa863b5 100644 --- a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala @@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.objectweb.asm._ -import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ @@ -83,7 +82,7 @@ extends ClassLoader(parent) { } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassAdapter(cv) { +extends ClassVisitor(ASM4, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) -- cgit v1.2.3 From ae7a5da6b31f5bf64f713b3d9bff6e441d8615b4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 20 Jun 2013 18:44:46 +0200 Subject: Fix some dependency issues in SBT build (same will be needed for Maven): - Exclude a version of ASM 3.x that comes from HBase - Don't use a special ASF repo for HBase - Update SLF4J version - Add sbt-dependency-graph plugin so we can easily find dependency trees --- project/SparkBuild.scala | 10 +++++----- project/plugins.sbt | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b1f3f9a2ea..24c8b734d0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -125,12 +125,13 @@ object SparkBuild extends Build { publishMavenStyle in MavenCompile := true, publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn - ) + ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings - val slf4jVersion = "1.6.1" + val slf4jVersion = "1.7.2" val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson") val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeAsm = ExclusionRule(organization = "asm") def coreSettings = sharedSettings ++ Seq( name := "spark-core", @@ -201,11 +202,10 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - resolvers ++= Seq("Apache HBase" at "https://repository.apache.org/content/repositories/releases"), libraryDependencies ++= Seq( "com.twitter" % "algebird-core_2.9.2" % "0.1.11", - "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty), + "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm), "org.apache.cassandra" % "cassandra-all" % "1.2.5" exclude("com.google.guava", "guava") @@ -224,7 +224,7 @@ object SparkBuild extends Build { name := "spark-streaming", libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty), - "com.github.sgroschupf" % "zkclient" % "0.1", + "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty), "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), "com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty) ) diff --git a/project/plugins.sbt b/project/plugins.sbt index d4f2442872..f806e66481 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1") //resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns) //addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6") + +addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3") -- cgit v1.2.3 From 52407951541399e60a5292394b3a443a5e7ff281 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 17:38:23 +0800 Subject: edit according to comments --- core/src/main/scala/spark/RDD.scala | 6 +- core/src/main/scala/spark/Utils.scala | 10 +-- .../main/scala/spark/executor/TaskMetrics.scala | 2 +- .../main/scala/spark/scheduler/DAGScheduler.scala | 14 +++-- .../src/main/scala/spark/scheduler/JobLogger.scala | 72 ++++++++++------------ .../main/scala/spark/scheduler/SparkListener.scala | 25 ++++---- .../scala/spark/scheduler/JobLoggerSuite.scala | 2 +- 7 files changed, 62 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 8c0b7ca417..b17398953b 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -114,10 +114,10 @@ abstract class RDD[T: ClassManifest]( this } - /**User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo._4 + /** User-defined generator of this RDD*/ + var generator = Utils.getCallSiteInfo.firstUserClass - /**reset generator*/ + /** Reset generator*/ def setGenerator(_generator: String) = { generator = _generator } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1630b2b4b0..1cfaee79b1 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -522,13 +522,14 @@ private object Utils extends Logging { execute(command, new File(".")) } - + class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. */ - def getCallSiteInfo = { + def getCallSiteInfo: CallSiteInfo = { val trace = Thread.currentThread.getStackTrace().filter( el => (!el.getMethodName.contains("getStackTrace"))) @@ -560,12 +561,13 @@ private object Utils extends Logging { } } } - (lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) + new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } def formatSparkCallSite = { val callSiteInfo = getCallSiteInfo - "%s at %s:%s".format(callSiteInfo._1, callSiteInfo._2, callSiteInfo._3) + "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, + callSiteInfo.firstUserLine) } /** * Try to find a free port to bind to on the local host. This should ideally never be needed, diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala index 26e8029365..1dc13754f9 100644 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/spark/executor/TaskMetrics.scala @@ -2,7 +2,7 @@ package spark.executor class TaskMetrics extends Serializable { /** - * host's name the task runs on + * Host's name the task runs on */ var hostname: String = _ diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index e281e5a8db..4336f2f36d 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -312,7 +312,8 @@ class DAGScheduler( handleExecutorLost(execId) case completion: CompletionEvent => - sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion))) + sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task, + completion.reason, completion.taskInfo, completion.taskMetrics))) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -323,8 +324,8 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - sparkListeners.foreach(_.onJobEnd(SparkListenerJobCancelled(job, - "SPARKCONTEXT_SHUTDOWN"))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, + JobFailed(error)))) } return true } @@ -527,7 +528,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - sparkListeners.foreach(_.onJobEnd(SparkListenerJobSuccess(job))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded))) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -668,10 +669,11 @@ class DAGScheduler( val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - job.listener.jobFailed(new SparkException("Job failed: " + reason)) + val error = new SparkException("Job failed: " + reason) + job.listener.jobFailed(error) activeJobs -= job resultStageToJob -= resultStage - sparkListeners.foreach(_.onJobEnd(SparkListenerJobFailed(job, failedStage))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index 002c5826cb..760a0252b7 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -12,7 +12,7 @@ import spark._ import spark.executor.TaskMetrics import spark.scheduler.cluster.TaskInfo -// used to record runtime information for each job, including RDD graph +// Used to record runtime information for each job, including RDD graph // tasks' start/stop shuffle information and information from outside class JobLogger(val logDirName: String) extends SparkListener with Logging { @@ -49,21 +49,17 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { processStageSubmittedEvent(stage, taskSize) case StageCompleted(stageInfo) => processStageCompletedEvent(stageInfo) - case SparkListenerJobSuccess(job) => - processJobEndEvent(job) - case SparkListenerJobFailed(job, failedStage) => - processJobEndEvent(job, failedStage) - case SparkListenerJobCancelled(job, reason) => - processJobEndEvent(job, reason) - case SparkListenerTaskEnd(event) => - processTaskEndEvent(event) + case SparkListenerJobEnd(job, result) => + processJobEndEvent(job, result) + case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) => + processTaskEndEvent(task, reason, taskInfo, taskMetrics) case _ => } } } }.start() - //create a folder for log files, the folder's name is the creation time of the jobLogger + // Create a folder for log files, the folder's name is the creation time of the jobLogger protected def createLogDir() { val dir = new File(logDir + "/" + logDirName + "/") if (dir.exists()) { @@ -244,54 +240,50 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { eventQueue.put(taskEnd) } - protected def processTaskEndEvent(event: CompletionEvent) { + protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason, + taskInfo: TaskInfo, taskMetrics: TaskMetrics) { var taskStatus = "" - event.task match { + task match { case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" } - event.reason match { + reason match { case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(event.task.stageId, taskStatus, event.taskInfo, event.taskMetrics) + recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics) case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + event.taskInfo.taskId + - " STAGE_ID=" + event.task.stageId - stageLogInfo(event.task.stageId, taskStatus) + taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + stageLogInfo(task.stageId, taskStatus) case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - taskStatus += " STATUS=FETCHFAILED TID=" + event.taskInfo.taskId + " STAGE_ID=" + - event.task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + + taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + + task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId - stageLogInfo(event.task.stageId, taskStatus) + stageLogInfo(task.stageId, taskStatus) case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + event.taskInfo.taskId + - " STAGE_ID=" + event.task.stageId + " INFO=" + message - stageLogInfo(event.task.stageId, taskStatus) + taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + " INFO=" + message + stageLogInfo(task.stageId, taskStatus) case _ => } } - override def onJobEnd(jobEnd: SparkListenerEvents) { + override def onJobEnd(jobEnd: SparkListenerJobEnd) { eventQueue.put(jobEnd) } - protected def processJobEndEvent(job: ActiveJob) { - val info = "JOB_ID=" + job.runId + " STATUS=SUCCESS" - jobLogInfo(job.runId, info) - closeLogWriter(job.runId) - } - - protected def processJobEndEvent(job: ActiveJob, failedStage: Stage) { - val info = "JOB_ID=" + job.runId + " STATUS=FAILED REASON=STAGE_FAILED FAILED_STAGE_ID=" - + failedStage.id - jobLogInfo(job.runId, info) - closeLogWriter(job.runId) - } - protected def processJobEndEvent(job: ActiveJob, reason: String) { - var info = "JOB_ID=" + job.runId + " STATUS=CANCELLED REASON=" + reason - jobLogInfo(job.runId, info) + protected def processJobEndEvent(job: ActiveJob, reason: JobResult) { + var info = "JOB_ID=" + job.runId + reason match { + case JobSucceeded => info += " STATUS=SUCCESS" + case JobFailed(exception) => + info += " STATUS=FAILED REASON=" + exception.getMessage.split("\\s+").foreach(info += _ + "_") + case _ => + } + jobLogInfo(job.runId, info.substring(0, info.length - 1).toUpperCase) closeLogWriter(job.runId) } - + protected def recordJobProperties(jobID: Int, properties: Properties) { if(properties != null) { val annotation = properties.getProperty("spark.job.annotation", "") diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala index 9265261dc1..bac984b5c9 100644 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/spark/scheduler/SparkListener.scala @@ -3,52 +3,49 @@ package spark.scheduler import java.util.Properties import spark.scheduler.cluster.TaskInfo import spark.util.Distribution -import spark.{Utils, Logging, SparkContext, TaskEndReason} +import spark.{Logging, SparkContext, TaskEndReason, Utils} import spark.executor.TaskMetrics - sealed trait SparkListenerEvents case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents -case class SparkListenerTaskEnd(event: CompletionEvent) extends SparkListenerEvents +case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, + taskMetrics: TaskMetrics) extends SparkListenerEvents case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) extends SparkListenerEvents - -case class SparkListenerJobSuccess(job: ActiveJob) extends SparkListenerEvents - -case class SparkListenerJobFailed(job: ActiveJob, failedStage: Stage) extends SparkListenerEvents -case class SparkListenerJobCancelled(job: ActiveJob, reason: String) extends SparkListenerEvents +case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) + extends SparkListenerEvents trait SparkListener { /** - * called when a stage is completed, with information on the completed stage + * Called when a stage is completed, with information on the completed stage */ def onStageCompleted(stageCompleted: StageCompleted) { } /** - * called when a stage is submitted + * Called when a stage is submitted */ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } /** - * called when a task ends + * Called when a task ends */ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } /** - * called when a job starts + * Called when a job starts */ def onJobStart(jobStart: SparkListenerJobStart) { } /** - * called when a job ends + * Called when a job ends */ - def onJobEnd(jobEnd: SparkListenerEvents) { } + def onJobEnd(jobEnd: SparkListenerJobEnd) { } } diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala index a654bf3ffd..4000c4d520 100644 --- a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -87,7 +87,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers var onStageCompletedCount = 0 var onStageSubmittedCount = 0 override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 - override def onJobEnd(jobEnd: SparkListenerEvents) = onJobEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 -- cgit v1.2.3 From aa7aa587beff22e2db50d2afadd95097856a299a Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 17:48:41 +0800 Subject: some format modification --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 5 ++--- core/src/main/scala/spark/scheduler/JobLogger.scala | 10 +++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 4336f2f36d..e412baa803 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -324,8 +324,7 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, - JobFailed(error)))) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } return true } @@ -671,9 +670,9 @@ class DAGScheduler( val job = resultStageToJob(resultStage) val error = new SparkException("Job failed: " + reason) job.listener.jobFailed(error) + sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) activeJobs -= job resultStageToJob -= resultStage - sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala index 760a0252b7..178bfaba3d 100644 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/spark/scheduler/JobLogger.scala @@ -70,7 +70,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - // create a log file for one job, the file name is the jobID + // Create a log file for one job, the file name is the jobID protected def createLogWriter(jobID: Int) { try{ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) @@ -80,7 +80,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - // close log file for one job, and clean the stage relationship in stageIDToJobID + // Close log file, and clean the stage relationship in stageIDToJobID protected def closeLogWriter(jobID: Int) = jobIDToPrintWriter.get(jobID).foreach { fileWriter => fileWriter.close() @@ -91,7 +91,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { jobIDToStages -= jobID } - // write log information to log file, withTime parameter controls whether to recored + // Write log information to log file, withTime parameter controls whether to recored // time stamp for the information protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { var writeInfo = info @@ -145,7 +145,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { } } - // generate indents and convert to String + // Generate indents and convert to String protected def indentString(indent: Int) = { val sb = new StringBuilder() for (i <- 1 to indent) { @@ -190,7 +190,7 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false) } - // record task metrics into job log files + // Record task metrics into job log files protected def recordTaskMetrics(stageID: Int, status: String, taskInfo: TaskInfo, taskMetrics: TaskMetrics) { val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + -- cgit v1.2.3 From 4b9862ac9cf2d00c5245e9a8b0fcb05b82030c98 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 17:55:32 +0800 Subject: small format modification --- core/src/main/scala/spark/Utils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1cfaee79b1..96d86647f8 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -522,8 +522,8 @@ private object Utils extends Logging { execute(command, new File(".")) } - class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. -- cgit v1.2.3 From 2fc794a6c7f1b86e5c0103a9c82af2be7fafb347 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Fri, 21 Jun 2013 18:21:35 +0800 Subject: small modify in DAGScheduler --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index e412baa803..f7d60be5db 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -289,7 +289,6 @@ class DAGScheduler( val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions (allowLocal=" + allowLocal + ")") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -299,6 +298,7 @@ class DAGScheduler( // Compute very short actions like first() or take() with no parent stages locally. runLocally(job) } else { + sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties))) idToActiveJob(runId) = job activeJobs += job resultStageToJob(finalStage) = job -- cgit v1.2.3 From 93a1643405d7c1a1fffe8210130341f34d64ea72 Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Fri, 21 Jun 2013 14:21:52 +0100 Subject: Allow other twitter authorizations than username/password --- .../src/main/scala/spark/streaming/StreamingContext.scala | 15 ++++++++++++++- .../spark/streaming/dstream/TwitterInputDStream.scala | 14 ++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b8b60aab43..f97e47ada0 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import twitter4j.Status +import twitter4j.auth.{Authorization, BasicAuthorization} /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -372,8 +373,20 @@ class StreamingContext private ( password: String, filters: Seq[String] = Nil, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[Status] = twitterStream(new BasicAuthorization(username, password), filters, storageLevel) + + /** + * Create a input stream that returns tweets received from Twitter. + * @param twitterAuth Twitter4J authentication + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + twitterAuth: Authorization, + filters: Seq[String] = Nil, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[Status] = { - val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel) + val inputStream = new TwitterInputDStream(this, twitterAuth, filters, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index c697498862..0b01091a52 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -6,6 +6,7 @@ import storage.StorageLevel import twitter4j._ import twitter4j.auth.BasicAuthorization +import twitter4j.auth.Authorization /* A stream of Twitter statuses, potentially filtered by one or more keywords. * @@ -16,21 +17,19 @@ import twitter4j.auth.BasicAuthorization private[streaming] class TwitterInputDStream( @transient ssc_ : StreamingContext, - username: String, - password: String, + twitterAuth: Authorization, filters: Seq[String], storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { - + override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(username, password, filters, storageLevel) + new TwitterReceiver(twitterAuth, filters, storageLevel) } } private[streaming] class TwitterReceiver( - username: String, - password: String, + twitterAuth: Authorization, filters: Seq[String], storageLevel: StorageLevel ) extends NetworkReceiver[Status] { @@ -40,8 +39,7 @@ class TwitterReceiver( protected override def onStart() { blockGenerator.start() - twitterStream = new TwitterStreamFactory() - .getInstance(new BasicAuthorization(username, password)) + twitterStream = new TwitterStreamFactory().getInstance(twitterAuth) twitterStream.addListener(new StatusListener { def onStatus(status: Status) = { blockGenerator += status -- cgit v1.2.3 From 40afe0d2a5562738ef2ff37ed1d448ae2d0cc927 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Sun, 10 Mar 2013 13:54:46 -0700 Subject: Add Python timing instrumentation --- core/src/main/scala/spark/api/python/PythonRDD.scala | 12 ++++++++++++ python/pyspark/serializers.py | 4 ++++ python/pyspark/worker.py | 16 +++++++++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 807119ca8c..e9978d713f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -47,6 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest]( currentEnvVars.put(variable, value) } + val startTime = System.currentTimeMillis val proc = pb.start() val env = SparkEnv.get @@ -108,6 +109,17 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj + case -3 => + // Timing data from child + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + read case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 115cf28cc2..5a95144983 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -46,6 +46,10 @@ def read_long(stream): return struct.unpack("!q", length)[0] +def write_long(value, stream): + stream.write(struct.pack("!q", value)) + + def read_int(stream): length = stream.read(4) if length == "": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 812e7a9da5..4c33ae49dc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,6 +1,8 @@ """ Worker that receives input from Piped RDD. """ +import time +preboot_time = time.time() import os import sys import traceback @@ -12,7 +14,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file # Redirect stdout to stderr so that users must return values from functions. @@ -24,7 +26,16 @@ def load_obj(): return load_pickle(standard_b64decode(sys.stdin.readline().strip())) +def report_times(preboot, boot, init, finish): + write_int(-3, old_stdout) + write_long(1000 * preboot, old_stdout) + write_long(1000 * boot, old_stdout) + write_long(1000 * init, old_stdout) + write_long(1000 * finish, old_stdout) + + def main(): + boot_time = time.time() split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir @@ -41,6 +52,7 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle + init_time = time.time() iterator = read_from_pickle_file(sys.stdin) try: for obj in func(split_index, iterator): @@ -49,6 +61,8 @@ def main(): write_int(-2, old_stdout) write_with_length(traceback.format_exc(), old_stdout) sys.exit(-1) + finish_time = time.time() + report_times(preboot_time, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output write_int(-1, old_stdout) for aid, accum in _accumulatorRegistry.items(): -- cgit v1.2.3 From c79a6078c34c207ad9f9910252f5849424828bf1 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Mon, 6 May 2013 16:34:30 -0700 Subject: Prefork Python worker processes --- core/src/main/scala/spark/SparkEnv.scala | 11 +++ .../main/scala/spark/api/python/PythonRDD.scala | 66 +++++-------- .../main/scala/spark/api/python/PythonWorker.scala | 89 +++++++++++++++++ python/pyspark/daemon.py | 109 +++++++++++++++++++++ python/pyspark/worker.py | 61 ++++++------ 5 files changed, 263 insertions(+), 73 deletions(-) create mode 100644 core/src/main/scala/spark/api/python/PythonWorker.scala create mode 100644 python/pyspark/daemon.py diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index be1a04d619..5691e24c32 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,5 +1,8 @@ package spark +import collection.mutable +import serializer.Serializer + import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} import akka.remote.RemoteActorRefProvider @@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster import spark.network.ConnectionManager import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils +import spark.api.python.PythonWorker /** @@ -37,6 +41,8 @@ class SparkEnv ( // If executorId is NOT found, return defaultHostPort var executorIdToHostPort: Option[(String, String) => String]) { + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]() + def stop() { httpFileServer.stop() mapOutputTracker.stop() @@ -50,6 +56,11 @@ class SparkEnv ( actorSystem.awaitTermination() } + def getPythonWorker(pythonExec: String, envVars: Map[String, String]): PythonWorker = { + synchronized { + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorker(pythonExec, envVars)) + } + } def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { val env = SparkEnv.get diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index e9978d713f..e5acc54c01 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -2,10 +2,9 @@ package spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ -import scala.io.Source import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast @@ -16,7 +15,7 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], - envVars: java.util.Map[String, String], + envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], @@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest]( // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + def this(parent: RDD[T], command: String, envVars: JMap[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]]) = @@ -36,36 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) - // Add the environmental variables to the process. - val currentEnvVars = pb.environment() - - for ((variable, value) <- envVars) { - currentEnvVars.put(variable, value) - } + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis - val proc = pb.start() + val worker = SparkEnv.get.getPythonWorker(pythonExec, envVars.toMap).create val env = SparkEnv.get - // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val out = new PrintWriter(proc.getOutputStream) - val dOut = new DataOutputStream(proc.getOutputStream) + val out = new PrintWriter(worker.getOutputStream) + val dOut = new DataOutputStream(worker.getOutputStream) // Partition index dOut.writeInt(split.index) // sparkFilesDir @@ -89,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest]( } dOut.flush() out.flush() - proc.getOutputStream.close() + worker.shutdownOutput() } }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(proc.getInputStream) + val stream = new DataInputStream(worker.getInputStream) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj - _nextObj = read() + if (hasNext) { + // FIXME: can deadlock if worker is waiting for us to + // respond to current message (currently irrelevant because + // output is shutdown before we read any input) + _nextObj = read() + } obj } @@ -110,7 +96,7 @@ private[spark] class PythonRDD[T: ClassManifest]( stream.readFully(obj) obj case -3 => - // Timing data from child + // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() val finishTime = stream.readLong() @@ -127,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest]( stream.readFully(obj) throw new PythonException(new String(obj)) case -1 => - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() + // We've finished the data section of the output, but we can still + // read some accumulator updates; let's do that, breaking when we + // get a negative length record. + var len2 = stream.readInt + while (len2 >= 0) { val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + len2 = stream.readInt } new Array[Byte](0) } } catch { case eof: EOFException => { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - new Array[Byte](0) + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } case e => throw e } @@ -171,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PairwiseRDD: unexpected value: " + x) + case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -227,7 +211,7 @@ private[spark] object PythonRDD { dOut.write(s) dOut.writeByte(Pickle.STOP) } else { - throw new Exception("Unexpected RDD type") + throw new SparkException("Unexpected RDD type") } } diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala new file mode 100644 index 0000000000..8ee3c6884f --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonWorker.scala @@ -0,0 +1,89 @@ +package spark.api.python + +import java.io.DataInputStream +import java.net.{Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import spark._ + +private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String]) + extends Logging { + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon + + // Attempt to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon + startDaemon + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + daemon = pb.start() + daemonPort = new DataInputStream(daemon.getInputStream).readInt + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + override def run() { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + val in = daemon.getErrorStream + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + }.start() + } catch { + case e => { + stopDaemon + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. + } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy + } + + daemon = null + daemonPort = 0 + } + } +} diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py new file mode 100644 index 0000000000..642f30b2b9 --- /dev/null +++ b/python/pyspark/daemon.py @@ -0,0 +1,109 @@ +import os +import sys +import multiprocessing +from errno import EINTR, ECHILD +from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN +from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN +from pyspark.worker import main as worker_main +from pyspark.serializers import write_int + +try: + POOLSIZE = multiprocessing.cpu_count() +except NotImplementedError: + POOLSIZE = 4 + +should_exit = False + + +def worker(listen_sock): + # Redirect stdout to stderr + os.dup2(2, 1) + + # Manager sends SIGHUP to request termination of workers in the pool + def handle_sighup(signum, frame): + global should_exit + should_exit = True + signal(SIGHUP, handle_sighup) + + while not should_exit: + # Wait until a client arrives or we have to exit + sock = None + while not should_exit and sock is None: + try: + sock, addr = listen_sock.accept() + except EnvironmentError as err: + if err.errno != EINTR: + raise + + if sock is not None: + # Fork a child to handle the client + if os.fork() == 0: + # Leave the worker pool + signal(SIGHUP, SIG_DFL) + listen_sock.close() + # Handle the client then exit + sockfile = sock.makefile() + worker_main(sockfile, sockfile) + sockfile.close() + sock.close() + os._exit(0) + else: + sock.close() + + assert should_exit + os._exit(0) + + +def manager(): + # Create a new process group to corral our children + os.setpgid(0, 0) + + # Create a listening socket on the AF_INET loopback interface + listen_sock = socket(AF_INET, SOCK_STREAM) + listen_sock.bind(('127.0.0.1', 0)) + listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) + listen_host, listen_port = listen_sock.getsockname() + write_int(listen_port, sys.stdout) + + # Launch initial worker pool + for idx in range(POOLSIZE): + if os.fork() == 0: + worker(listen_sock) + raise RuntimeError("worker() unexpectedly returned") + listen_sock.close() + + def shutdown(): + global should_exit + os.kill(0, SIGHUP) + should_exit = True + + # Gracefully exit on SIGTERM, don't die on SIGHUP + signal(SIGTERM, lambda signum, frame: shutdown()) + signal(SIGHUP, SIG_IGN) + + # Cleanup zombie children + def handle_sigchld(signum, frame): + try: + pid, status = os.waitpid(0, os.WNOHANG) + if (pid, status) != (0, 0) and not should_exit: + raise RuntimeError("pool member crashed: %s, %s" % (pid, status)) + except EnvironmentError as err: + if err.errno not in (ECHILD, EINTR): + raise + signal(SIGCHLD, handle_sigchld) + + # Initialization complete + sys.stdout.close() + while not should_exit: + try: + # Spark tells us to exit by closing stdin + if sys.stdin.read() == '': + shutdown() + except EnvironmentError as err: + if err.errno != EINTR: + shutdown() + raise + + +if __name__ == '__main__': + manager() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4c33ae49dc..94d612ea6e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,10 +1,9 @@ """ Worker that receives input from Piped RDD. """ -import time -preboot_time = time.time() import os import sys +import time import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the @@ -17,57 +16,55 @@ from pyspark.serializers import write_with_length, read_with_length, write_int, read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file -# Redirect stdout to stderr so that users must return values from functions. -old_stdout = os.fdopen(os.dup(1), 'w') -os.dup2(2, 1) - - -def load_obj(): - return load_pickle(standard_b64decode(sys.stdin.readline().strip())) +def load_obj(infile): + return load_pickle(standard_b64decode(infile.readline().strip())) -def report_times(preboot, boot, init, finish): - write_int(-3, old_stdout) - write_long(1000 * preboot, old_stdout) - write_long(1000 * boot, old_stdout) - write_long(1000 * init, old_stdout) - write_long(1000 * finish, old_stdout) +def report_times(outfile, boot, init, finish): + write_int(-3, outfile) + write_long(1000 * boot, outfile) + write_long(1000 * init, outfile) + write_long(1000 * finish, outfile) -def main(): +def main(infile, outfile): boot_time = time.time() - split_index = read_int(sys.stdin) - spark_files_dir = load_pickle(read_with_length(sys.stdin)) + split_index = read_int(infile) + spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) - num_broadcast_variables = read_int(sys.stdin) + num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): - bid = read_long(sys.stdin) - value = read_with_length(sys.stdin) + bid = read_long(infile) + value = read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) - func = load_obj() - bypassSerializer = load_obj() + func = load_obj(infile) + bypassSerializer = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(sys.stdin) + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + write_with_length(dumps(obj), outfile) except Exception as e: - write_int(-2, old_stdout) - write_with_length(traceback.format_exc(), old_stdout) - sys.exit(-1) + write_int(-2, outfile) + write_with_length(traceback.format_exc(), outfile) + raise finish_time = time.time() - report_times(preboot_time, boot_time, init_time, finish_time) + report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, old_stdout) + write_int(-1, outfile) for aid, accum in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), old_stdout) + write_with_length(dump_pickle((aid, accum._value)), outfile) + write_int(-1, outfile) if __name__ == '__main__': - main() + # Redirect stdout to stderr so that users must return values from functions. + old_stdout = os.fdopen(os.dup(1), 'w') + os.dup2(2, 1) + main(sys.stdin, old_stdout) -- cgit v1.2.3 From 62c4781400dd908c2fccdcebf0dc816ff0cb8ed4 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Fri, 10 May 2013 15:48:48 -0700 Subject: Add tests and fixes for Python daemon shutdown --- core/src/main/scala/spark/SparkEnv.scala | 1 + .../main/scala/spark/api/python/PythonWorker.scala | 4 ++ python/pyspark/daemon.py | 46 +++++++++++----------- python/pyspark/tests.py | 43 ++++++++++++++++++++ python/pyspark/worker.py | 2 + 5 files changed, 74 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 5691e24c32..5b55d45212 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -44,6 +44,7 @@ class SparkEnv ( private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]() def stop() { + pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() shuffleFetcher.stop() diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala index 8ee3c6884f..74c8c6d37a 100644 --- a/core/src/main/scala/spark/api/python/PythonWorker.scala +++ b/core/src/main/scala/spark/api/python/PythonWorker.scala @@ -33,6 +33,10 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, Strin } } + def stop() { + stopDaemon + } + private def startDaemon() { synchronized { // Is it already running? diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 642f30b2b9..ab9c19df57 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -12,7 +12,7 @@ try: except NotImplementedError: POOLSIZE = 4 -should_exit = False +should_exit = multiprocessing.Event() def worker(listen_sock): @@ -21,14 +21,13 @@ def worker(listen_sock): # Manager sends SIGHUP to request termination of workers in the pool def handle_sighup(signum, frame): - global should_exit - should_exit = True + assert should_exit.is_set() signal(SIGHUP, handle_sighup) - while not should_exit: + while not should_exit.is_set(): # Wait until a client arrives or we have to exit sock = None - while not should_exit and sock is None: + while not should_exit.is_set() and sock is None: try: sock, addr = listen_sock.accept() except EnvironmentError as err: @@ -36,8 +35,8 @@ def worker(listen_sock): raise if sock is not None: - # Fork a child to handle the client - if os.fork() == 0: + # Fork to handle the client + if os.fork() != 0: # Leave the worker pool signal(SIGHUP, SIG_DFL) listen_sock.close() @@ -50,7 +49,7 @@ def worker(listen_sock): else: sock.close() - assert should_exit + assert should_exit.is_set() os._exit(0) @@ -73,9 +72,7 @@ def manager(): listen_sock.close() def shutdown(): - global should_exit - os.kill(0, SIGHUP) - should_exit = True + should_exit.set() # Gracefully exit on SIGTERM, don't die on SIGHUP signal(SIGTERM, lambda signum, frame: shutdown()) @@ -85,8 +82,8 @@ def manager(): def handle_sigchld(signum, frame): try: pid, status = os.waitpid(0, os.WNOHANG) - if (pid, status) != (0, 0) and not should_exit: - raise RuntimeError("pool member crashed: %s, %s" % (pid, status)) + if status != 0 and not should_exit.is_set(): + raise RuntimeError("worker crashed: %s, %s" % (pid, status)) except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): raise @@ -94,15 +91,20 @@ def manager(): # Initialization complete sys.stdout.close() - while not should_exit: - try: - # Spark tells us to exit by closing stdin - if sys.stdin.read() == '': - shutdown() - except EnvironmentError as err: - if err.errno != EINTR: - shutdown() - raise + try: + while not should_exit.is_set(): + try: + # Spark tells us to exit by closing stdin + if os.read(0, 512) == '': + shutdown() + except EnvironmentError as err: + if err.errno != EINTR: + shutdown() + raise + finally: + should_exit.set() + # Send SIGHUP to notify workers of shutdown + os.kill(0, SIGHUP) if __name__ == '__main__': diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6a1962d267..1e34d47365 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -12,6 +12,7 @@ import unittest from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME +from pyspark.serializers import read_int class PySparkTestCase(unittest.TestCase): @@ -117,5 +118,47 @@ class TestIO(PySparkTestCase): self.sc.parallelize([1]).foreach(func) +class TestDaemon(unittest.TestCase): + def connect(self, port): + from socket import socket, AF_INET, SOCK_STREAM + sock = socket(AF_INET, SOCK_STREAM) + sock.connect(('127.0.0.1', port)) + # send a split index of -1 to shutdown the worker + sock.send("\xFF\xFF\xFF\xFF") + sock.close() + return True + + def do_termination_test(self, terminator): + from subprocess import Popen, PIPE + from errno import ECONNREFUSED + + # start daemon + daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") + daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + + # read the port number + port = read_int(daemon.stdout) + + # daemon should accept connections + self.assertTrue(self.connect(port)) + + # request shutdown + terminator(daemon) + time.sleep(1) + + # daemon should no longer accept connections + with self.assertRaises(EnvironmentError) as trap: + self.connect(port) + self.assertEqual(trap.exception.errno, ECONNREFUSED) + + def test_termination_stdin(self): + """Ensure that daemon and workers terminate when stdin is closed.""" + self.do_termination_test(lambda daemon: daemon.stdin.close()) + + def test_termination_sigterm(self): + """Ensure that daemon and workers terminate on SIGTERM.""" + from signal import SIGTERM + self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 94d612ea6e..f76ee3c236 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,6 +30,8 @@ def report_times(outfile, boot, init, finish): def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) + if split_index == -1: # for unit tests + return spark_files_dir = load_pickle(read_with_length(infile)) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True -- cgit v1.2.3 From edb18ca928c988a713b9228bb74af1737f2b614b Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Mon, 13 May 2013 08:53:47 -0700 Subject: Rename PythonWorker to PythonWorkerFactory --- core/src/main/scala/spark/SparkEnv.scala | 8 +- .../main/scala/spark/api/python/PythonRDD.scala | 2 +- .../main/scala/spark/api/python/PythonWorker.scala | 93 --------------------- .../spark/api/python/PythonWorkerFactory.scala | 95 ++++++++++++++++++++++ 4 files changed, 100 insertions(+), 98 deletions(-) delete mode 100644 core/src/main/scala/spark/api/python/PythonWorker.scala create mode 100644 core/src/main/scala/spark/api/python/PythonWorkerFactory.scala diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 5b55d45212..0a23c45658 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -12,7 +12,7 @@ import spark.storage.BlockManagerMaster import spark.network.ConnectionManager import spark.serializer.{Serializer, SerializerManager} import spark.util.AkkaUtils -import spark.api.python.PythonWorker +import spark.api.python.PythonWorkerFactory /** @@ -41,7 +41,7 @@ class SparkEnv ( // If executorId is NOT found, return defaultHostPort var executorIdToHostPort: Option[(String, String) => String]) { - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]() + private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() def stop() { pythonWorkers.foreach { case(key, worker) => worker.stop() } @@ -57,9 +57,9 @@ class SparkEnv ( actorSystem.awaitTermination() } - def getPythonWorker(pythonExec: String, envVars: Map[String, String]): PythonWorker = { + def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { synchronized { - pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorker(pythonExec, envVars)) + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create } } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index e5acc54c01..3c48071b3f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -38,8 +38,8 @@ private[spark] class PythonRDD[T: ClassManifest]( override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis - val worker = SparkEnv.get.getPythonWorker(pythonExec, envVars.toMap).create val env = SparkEnv.get + val worker = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + pythonExec) { diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala deleted file mode 100644 index 74c8c6d37a..0000000000 --- a/core/src/main/scala/spark/api/python/PythonWorker.scala +++ /dev/null @@ -1,93 +0,0 @@ -package spark.api.python - -import java.io.DataInputStream -import java.net.{Socket, SocketException, InetAddress} - -import scala.collection.JavaConversions._ - -import spark._ - -private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String]) - extends Logging { - var daemon: Process = null - val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) - var daemonPort: Int = 0 - - def create(): Socket = { - synchronized { - // Start the daemon if it hasn't been started - startDaemon - - // Attempt to connect, restart and retry once if it fails - try { - new Socket(daemonHost, daemonPort) - } catch { - case exc: SocketException => { - logWarning("Python daemon unexpectedly quit, attempting to restart") - stopDaemon - startDaemon - new Socket(daemonHost, daemonPort) - } - case e => throw e - } - } - } - - def stop() { - stopDaemon - } - - private def startDaemon() { - synchronized { - // Is it already running? - if (daemon != null) { - return - } - - try { - // Create and start the daemon - val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) - val workerEnv = pb.environment() - workerEnv.putAll(envVars) - daemon = pb.start() - daemonPort = new DataInputStream(daemon.getInputStream).readInt - - // Redirect the stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - // FIXME HACK: We copy the stream on the level of bytes to - // attempt to dodge encoding problems. - val in = daemon.getErrorStream - var buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - }.start() - } catch { - case e => { - stopDaemon - throw e - } - } - - // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly - // detect our disappearance. - } - } - - private def stopDaemon() { - synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy - } - - daemon = null - daemonPort = 0 - } - } -} diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala new file mode 100644 index 0000000000..ebbd226b3e --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -0,0 +1,95 @@ +package spark.api.python + +import java.io.{DataInputStream, IOException} +import java.net.{Socket, SocketException, InetAddress} + +import scala.collection.JavaConversions._ + +import spark._ + +private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends Logging { + var daemon: Process = null + val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) + var daemonPort: Int = 0 + + def create(): Socket = { + synchronized { + // Start the daemon if it hasn't been started + startDaemon + + // Attempt to connect, restart and retry once if it fails + try { + new Socket(daemonHost, daemonPort) + } catch { + case exc: SocketException => { + logWarning("Python daemon unexpectedly quit, attempting to restart") + stopDaemon + startDaemon + new Socket(daemonHost, daemonPort) + } + case e => throw e + } + } + } + + def stop() { + stopDaemon + } + + private def startDaemon() { + synchronized { + // Is it already running? + if (daemon != null) { + return + } + + try { + // Create and start the daemon + val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") + val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) + val workerEnv = pb.environment() + workerEnv.putAll(envVars) + daemon = pb.start() + daemonPort = new DataInputStream(daemon.getInputStream).readInt + + // Redirect the stderr to ours + new Thread("stderr reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + val in = daemon.getErrorStream + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() + } catch { + case e => { + stopDaemon + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. + } + } + + private def stopDaemon() { + synchronized { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy + } + + daemon = null + daemonPort = 0 + } + } +} -- cgit v1.2.3 From 7c5ff733ee1d3729b4b26f7c5542ca00c4d64139 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 23 May 2013 11:50:24 -0700 Subject: PySpark daemon: fix deadlock, improve error handling --- python/pyspark/daemon.py | 67 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index ab9c19df57..2b5e9b3581 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -1,6 +1,7 @@ import os import sys import multiprocessing +from ctypes import c_bool from errno import EINTR, ECHILD from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -12,7 +13,12 @@ try: except NotImplementedError: POOLSIZE = 4 -should_exit = multiprocessing.Event() +exit_flag = multiprocessing.Value(c_bool, False) + + +def should_exit(): + global exit_flag + return exit_flag.value def worker(listen_sock): @@ -20,14 +26,29 @@ def worker(listen_sock): os.dup2(2, 1) # Manager sends SIGHUP to request termination of workers in the pool - def handle_sighup(signum, frame): - assert should_exit.is_set() + def handle_sighup(*args): + assert should_exit() signal(SIGHUP, handle_sighup) - while not should_exit.is_set(): + # Cleanup zombie children + def handle_sigchld(*args): + pid = status = None + try: + while (pid, status) != (0, 0): + pid, status = os.waitpid(0, os.WNOHANG) + except EnvironmentError as err: + if err.errno == EINTR: + # retry + handle_sigchld() + elif err.errno != ECHILD: + raise + signal(SIGCHLD, handle_sigchld) + + # Handle clients + while not should_exit(): # Wait until a client arrives or we have to exit sock = None - while not should_exit.is_set() and sock is None: + while not should_exit() and sock is None: try: sock, addr = listen_sock.accept() except EnvironmentError as err: @@ -35,8 +56,10 @@ def worker(listen_sock): raise if sock is not None: - # Fork to handle the client - if os.fork() != 0: + # Fork a child to handle the client. + # The client is handled in the child so that the manager + # never receives SIGCHLD unless a worker crashes. + if os.fork() == 0: # Leave the worker pool signal(SIGHUP, SIG_DFL) listen_sock.close() @@ -49,8 +72,18 @@ def worker(listen_sock): else: sock.close() - assert should_exit.is_set() - os._exit(0) + +def launch_worker(listen_sock): + if os.fork() == 0: + try: + worker(listen_sock) + except Exception as err: + import traceback + traceback.print_exc() + os._exit(1) + else: + assert should_exit() + os._exit(0) def manager(): @@ -66,23 +99,22 @@ def manager(): # Launch initial worker pool for idx in range(POOLSIZE): - if os.fork() == 0: - worker(listen_sock) - raise RuntimeError("worker() unexpectedly returned") + launch_worker(listen_sock) listen_sock.close() def shutdown(): - should_exit.set() + global exit_flag + exit_flag.value = True # Gracefully exit on SIGTERM, don't die on SIGHUP signal(SIGTERM, lambda signum, frame: shutdown()) signal(SIGHUP, SIG_IGN) # Cleanup zombie children - def handle_sigchld(signum, frame): + def handle_sigchld(*args): try: pid, status = os.waitpid(0, os.WNOHANG) - if status != 0 and not should_exit.is_set(): + if status != 0 and not should_exit(): raise RuntimeError("worker crashed: %s, %s" % (pid, status)) except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): @@ -92,7 +124,7 @@ def manager(): # Initialization complete sys.stdout.close() try: - while not should_exit.is_set(): + while not should_exit(): try: # Spark tells us to exit by closing stdin if os.read(0, 512) == '': @@ -102,7 +134,8 @@ def manager(): shutdown() raise finally: - should_exit.set() + signal(SIGTERM, SIG_DFL) + exit_flag.value = True # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) -- cgit v1.2.3 From 1ba3c173034c37ef99fc312c84943d2ab8885670 Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Thu, 20 Jun 2013 12:49:10 -0400 Subject: use parens when calling method with side-effects --- core/src/main/scala/spark/SparkEnv.scala | 2 +- core/src/main/scala/spark/api/python/PythonRDD.scala | 4 ++-- .../main/scala/spark/api/python/PythonWorkerFactory.scala | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 0a23c45658..7ccde2e818 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -59,7 +59,7 @@ class SparkEnv ( def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { synchronized { - pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create + pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create() } } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 3c48071b3f..63140cf37f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -116,12 +116,12 @@ private[spark] class PythonRDD[T: ClassManifest]( // We've finished the data section of the output, but we can still // read some accumulator updates; let's do that, breaking when we // get a negative length record. - var len2 = stream.readInt + var len2 = stream.readInt() while (len2 >= 0) { val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt + len2 = stream.readInt() } new Array[Byte](0) } diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala index ebbd226b3e..8844411d73 100644 --- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -16,7 +16,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { synchronized { // Start the daemon if it hasn't been started - startDaemon + startDaemon() // Attempt to connect, restart and retry once if it fails try { @@ -24,8 +24,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } catch { case exc: SocketException => { logWarning("Python daemon unexpectedly quit, attempting to restart") - stopDaemon - startDaemon + stopDaemon() + startDaemon() new Socket(daemonHost, daemonPort) } case e => throw e @@ -34,7 +34,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stop() { - stopDaemon + stopDaemon() } private def startDaemon() { @@ -51,7 +51,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars) daemon = pb.start() - daemonPort = new DataInputStream(daemon.getInputStream).readInt + daemonPort = new DataInputStream(daemon.getInputStream).readInt() // Redirect the stderr to ours new Thread("stderr reader for " + pythonExec) { @@ -71,7 +71,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String }.start() } catch { case e => { - stopDaemon + stopDaemon() throw e } } @@ -85,7 +85,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String synchronized { // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { - daemon.destroy + daemon.destroy() } daemon = null -- cgit v1.2.3 From c75bed0eebb1f937db02eb98deecd380724f747d Mon Sep 17 00:00:00 2001 From: Jey Kottalam Date: Fri, 21 Jun 2013 12:13:48 -0400 Subject: Fix reporting of PySpark exceptions --- python/pyspark/daemon.py | 22 ++++++++++++++++++---- python/pyspark/worker.py | 2 +- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 2b5e9b3581..78a2da1e18 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -21,6 +21,15 @@ def should_exit(): return exit_flag.value +def compute_real_exit_code(exit_code): + # SystemExit's code can be integer or string, but os._exit only accepts integers + import numbers + if isinstance(exit_code, numbers.Integral): + return exit_code + else: + return 1 + + def worker(listen_sock): # Redirect stdout to stderr os.dup2(2, 1) @@ -65,10 +74,15 @@ def worker(listen_sock): listen_sock.close() # Handle the client then exit sockfile = sock.makefile() - worker_main(sockfile, sockfile) - sockfile.close() - sock.close() - os._exit(0) + exit_code = 0 + try: + worker_main(sockfile, sockfile) + except SystemExit as exc: + exit_code = exc.code + finally: + sockfile.close() + sock.close() + os._exit(compute_real_exit_code(exit_code)) else: sock.close() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f76ee3c236..379bbfd4c2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -55,7 +55,7 @@ def main(infile, outfile): except Exception as e: write_int(-2, outfile) write_with_length(traceback.format_exc(), outfile) - raise + sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output -- cgit v1.2.3 From b350f34703d4f29bbd0e603df852f7aae230b2a2 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 07:48:20 -0700 Subject: Increase memory for tests to prevent a crash on JDK 7 --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 24c8b734d0..faf6e2ae8e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -54,7 +54,7 @@ object SparkBuild extends Build { // Fork new JVMs for tests and set Java options for those fork := true, - javaOptions += "-Xmx2g", + javaOptions += "-Xmx2500m", // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), -- cgit v1.2.3 From d92d3f7938dec954ea31de232f50cafd4b644065 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 10:24:19 -0700 Subject: Fix resolution of example code with Maven builds --- run | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/run b/run index c0065c53f1..e656e38ccf 100755 --- a/run +++ b/run @@ -132,10 +132,14 @@ if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" fi CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" +# Add the shaded JAR for Maven builds if [ -e $REPL_BIN_DIR/target ]; then for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do CLASSPATH="$CLASSPATH:$jar" done + # The shaded JAR doesn't contain examples, so include those separately + EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` + CLASSPATH+=":$EXAMPLES_JAR" fi CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do @@ -148,9 +152,9 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; # Use the JAR from the SBT build export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar` fi -if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then +if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then # Use the JAR from the Maven build - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar` + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` fi # Add hadoop conf dir - else FileSystem.*, etc fail ! -- cgit v1.2.3 From b5df1cd668e45fd0cc22c1666136d05548cae3e9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 17:12:39 -0700 Subject: ADD_JARS environment variable for spark-shell --- docs/scala-programming-guide.md | 10 ++++++++-- repl/src/main/scala/spark/repl/SparkILoop.scala | 9 +++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index b0da130fcb..e9cf9ef36f 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -43,12 +43,18 @@ new SparkContext(master, appName, [sparkHome], [jars]) The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later. -In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable. For example, to run on four cores, use +In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `spark-shell` on four cores, use {% highlight bash %} $ MASTER=local[4] ./spark-shell {% endhighlight %} +Or, to also add `code.jar` to its classpath, use: + +{% highlight bash %} +$ MASTER=local[4] ADD_JARS=code.jar ./spark-shell +{% endhighlight %} + ### Master URLs The master URL passed to Spark can be in one of the following formats: @@ -78,7 +84,7 @@ If you want to run your job on a cluster, you will need to specify the two optio * `sparkHome`: The path at which Spark is installed on your worker machines (it should be the same on all of them). * `jars`: A list of JAR files on the local machine containing your job's code and any dependencies, which Spark will deploy to all the worker nodes. You'll need to package your job into a set of JARs using your build system. For example, if you're using SBT, the [sbt-assembly](https://github.com/sbt/sbt-assembly) plugin is a good way to make a single JAR with your code and dependencies. -If you run `spark-shell` on a cluster, any classes you define in the shell will automatically be distributed. +If you run `spark-shell` on a cluster, you can add JARs to it by specifying the `ADD_JARS` environment variable before you launch it. This variable should contain a comma-separated list of JARs. For example, `ADD_JARS=a.jar,b.jar ./spark-shell` will launch a shell with `a.jar` and `b.jar` on its classpath. In addition, any new classes you define in the shell will automatically be distributed. # Resilient Distributed Datasets (RDDs) diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala index 23556dbc8f..86eed090d0 100644 --- a/repl/src/main/scala/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/spark/repl/SparkILoop.scala @@ -822,7 +822,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: spark.repl.Main.interp.out.println("Spark context available as sc."); spark.repl.Main.interp.out.flush(); """) - command("import spark.SparkContext._"); + command("import spark.SparkContext._") } echo("Type in expressions to have them evaluated.") echo("Type :help for more information.") @@ -838,7 +838,8 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (prop != null) prop else "local" } } - sparkContext = new SparkContext(master, "Spark shell") + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) sparkContext } @@ -850,6 +851,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: printWelcome() echo("Initializing interpreter...") + // Add JARS specified in Spark's ADD_JARS variable to classpath + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + jars.foreach(settings.classpath.append(_)) + this.settings = settings createInterpreter() -- cgit v1.2.3 From 0e0f9d3069039f03bbf5eefe3b0637c89fddf0f1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 22 Jun 2013 17:44:04 -0700 Subject: Fix search path for REPL class loader to really find added JARs --- core/src/main/scala/spark/executor/Executor.scala | 38 +++++++++++++---------- repl/src/main/scala/spark/repl/SparkILoop.scala | 4 ++- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 8bebfafce4..2bf55ea9a9 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -42,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert // Create our ClassLoader and set it on this thread private val urlClassLoader = createClassLoader() - Thread.currentThread.setContextClassLoader(urlClassLoader) + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. @@ -88,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert override def run() { val startTime = System.currentTimeMillis() SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) @@ -153,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val urls = currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL }.toArray - loader = new URLClassLoader(urls, loader) + new ExecutorURLClassLoader(urls, loader) + } - // If the REPL is in use, add another ClassLoader that will read - // new classes defined by the REPL as the user types code + /** + * If the REPL is in use, add another ClassLoader that will read + * new classes defined by the REPL as the user types code + */ + private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { val classUri = System.getProperty("spark.repl.class.uri") if (classUri != null) { logInfo("Using REPL class URI: " + classUri) - loader = { - try { - val klass = Class.forName("spark.repl.ExecutorClassLoader") - .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - constructor.newInstance(classUri, loader) - } catch { - case _: ClassNotFoundException => loader - } + try { + val klass = Class.forName("spark.repl.ExecutorClassLoader") + .asInstanceOf[Class[_ <: ClassLoader]] + val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) + return constructor.newInstance(classUri, parent) + } catch { + case _: ClassNotFoundException => + logError("Could not find spark.repl.ExecutorClassLoader on classpath!") + System.exit(1) + null } + } else { + return parent } - - return new ExecutorURLClassLoader(Array(), loader) } /** diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala index 86eed090d0..59f9d05683 100644 --- a/repl/src/main/scala/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/spark/repl/SparkILoop.scala @@ -838,7 +838,9 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (prop != null) prop else "local" } } - val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')) + .getOrElse(new Array[String](0)) + .map(new java.io.File(_).getAbsolutePath) sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) sparkContext } -- cgit v1.2.3 From 78ffe164b33c6b11a2e511442605acd2f795a1b5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 23 Jun 2013 10:07:16 -0700 Subject: Clone the zero value for each key in foldByKey The old version reused the object within each task, leading to overwriting of the object when a mutable type is used, which is expected to be common in fold. Conflicts: core/src/test/scala/spark/ShuffleSuite.scala --- core/src/main/scala/spark/PairRDDFunctions.scala | 15 ++++++++++++--- core/src/test/scala/spark/ShuffleSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index fa4bbfc76f..7630fe7803 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -1,5 +1,6 @@ package spark +import java.nio.ByteBuffer import java.util.{Date, HashMap => JHashMap} import java.text.SimpleDateFormat @@ -64,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( throw new SparkException("Default partitioner cannot partition array keys.") } } - val aggregator = - new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { self.mapPartitions(aggregator.combineValuesByKey(_), true) } else if (mapSideCombine) { @@ -97,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner) + // Serialize the zero value to a byte array so that we can get a new clone of it on each key + val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) + val zeroArray = new Array[Byte](zeroBuffer.limit) + zeroBuffer.get(zeroArray) + + // When deserializing, use a lazy val to create just one instance of the serializer per task + lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() + def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + + combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) } /** diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 1916885a73..0c1ec29f96 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -392,6 +392,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(nonEmptyBlocks.size <= 4) } + test("foldByKey") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } } object ShuffleSuite { -- cgit v1.2.3 From 8955787a596216a35ad4ec52b57331aa40444bef Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Mon, 24 Jun 2013 09:15:17 +0100 Subject: Twitter API v1 is retired - username/password auth no longer possible --- .../main/scala/spark/streaming/StreamingContext.scala | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index f97e47ada0..05be6bd58a 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import twitter4j.Status -import twitter4j.auth.{Authorization, BasicAuthorization} +import twitter4j.auth.Authorization /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -361,20 +361,6 @@ class StreamingContext private ( fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } - /** - * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password - * @param filters Set of filter strings to get only those tweets that match them - * @param storageLevel Storage level to use for storing the received objects - */ - def twitterStream( - username: String, - password: String, - filters: Seq[String] = Nil, - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): DStream[Status] = twitterStream(new BasicAuthorization(username, password), filters, storageLevel) - /** * Create a input stream that returns tweets received from Twitter. * @param twitterAuth Twitter4J authentication -- cgit v1.2.3 From 48c7e373c62b2e8cf48157ceb0d92c38c3a40652 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 24 Jun 2013 23:11:04 -0700 Subject: Minor formatting fixes --- .../src/main/scala/spark/streaming/DStream.scala | 9 +++++-- .../scala/spark/streaming/StreamingContext.scala | 29 +++++++++++++--------- .../streaming/api/java/JavaStreamingContext.scala | 15 +++++++---- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index e125310861..9be7926a4a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,7 +441,12 @@ abstract class DStream[T: ClassManifest] ( * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2) + def count(): DStream[Long] = { + this.map(_ => (null, 1L)) + .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) + .reduceByKey(_ + _) + .map(_._2) + } /** * Return a new DStream in which each RDD contains the counts of each distinct value in @@ -457,7 +462,7 @@ abstract class DStream[T: ClassManifest] ( * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: RDD[T] => Unit) { - foreach((r: RDD[T], t: Time) => foreachFunc(r)) + this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) } /** diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 2c6326943d..03d2907323 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -171,10 +171,11 @@ class StreamingContext private ( * should be same. */ def actorStream[T: ClassManifest]( - props: Props, - name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, - supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = { + props: Props, + name: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy + ): DStream[T] = { networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) } @@ -182,9 +183,10 @@ class StreamingContext private ( * Create an input stream that receives messages pushed by a zeromq publisher. * @param publisherUrl Url of remote zeromq publisher * @param subscribe topic to subscribe to - * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence - * of byte thus it needs the converter(which might be deserializer of bytes) - * to translate from sequence of sequence of bytes, where sequence refer to a frame + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic + * and each frame has sequence of byte thus it needs the converter + * (which might be deserializer of bytes) to translate from sequence + * of sequence of bytes, where sequence refer to a frame * and sub sequence refer to its payload. * @param storageLevel RDD storage level. Defaults to memory-only. */ @@ -204,7 +206,7 @@ class StreamingContext private ( * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. + * in its own thread. * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ @@ -214,15 +216,17 @@ class StreamingContext private ( topics: Map[String, Int], storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[String] = { - val kafkaParams = Map[String, String]("zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000"); + val kafkaParams = Map[String, String]( + "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000") kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel) } /** * Create an input stream that pulls messages from a Kafka Broker. - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed - * in its own thread. + * in its own thread. * @param storageLevel Storage level to use for storing the received objects */ def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest]( @@ -395,7 +399,8 @@ class StreamingContext private ( * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. + * Set as null if no RDD should be returned when empty * @tparam T Type of objects in the RDD */ def queueStream[T: ClassManifest]( diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index b35d9032f1..fd5e06b50f 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -75,7 +75,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { : JavaDStream[String] = { implicit val cmt: ClassManifest[String] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), StorageLevel.MEMORY_ONLY_SER_2) + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + StorageLevel.MEMORY_ONLY_SER_2) } /** @@ -83,8 +84,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only - * in its own thread. + * */ def kafkaStream( zkQuorum: String, @@ -94,14 +96,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { : JavaDStream[String] = { implicit val cmt: ClassManifest[String] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]] - ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** * Create an input stream that pulls messages form a Kafka Broker. * @param typeClass Type of RDD * @param decoderClass Type of kafka decoder - * @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html + * @param kafkaParams Map of kafka configuration paramaters. + * See: http://kafka.apache.org/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. Defaults to memory-only @@ -113,7 +117,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { topics: JMap[String, JInt], storageLevel: StorageLevel) : JavaDStream[T] = { - implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cmt: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]] ssc.kafkaStream[T, D]( kafkaParams.toMap, -- cgit v1.2.3 From 7680ce0bd65fc44716c5bc03d5909a3ddbd43501 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 16:11:44 -0400 Subject: Fixed deprecated use of expect in SizeEstimatorSuite --- core/src/test/scala/spark/SizeEstimatorSuite.scala | 72 +++++++++++----------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index e235ef2f67..b5c8525f91 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -35,7 +35,7 @@ class SizeEstimatorSuite var oldOops: String = _ override def beforeAll() { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case oldArch = System.setProperty("os.arch", "amd64") oldOops = System.setProperty("spark.test.useCompressedOops", "true") } @@ -46,54 +46,54 @@ class SizeEstimatorSuite } test("simple classes") { - expect(16)(SizeEstimator.estimate(new DummyClass1)) - expect(16)(SizeEstimator.estimate(new DummyClass2)) - expect(24)(SizeEstimator.estimate(new DummyClass3)) - expect(24)(SizeEstimator.estimate(new DummyClass4(null))) - expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) + assert(SizeEstimator.estimate(new DummyClass1) === 16) + assert(SizeEstimator.estimate(new DummyClass2) === 16) + assert(SizeEstimator.estimate(new DummyClass3) === 24) + assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) + assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) } // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { - expect(40)(SizeEstimator.estimate(DummyString(""))) - expect(48)(SizeEstimator.estimate(DummyString("a"))) - expect(48)(SizeEstimator.estimate(DummyString("ab"))) - expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) } test("primitive arrays") { - expect(32)(SizeEstimator.estimate(new Array[Byte](10))) - expect(40)(SizeEstimator.estimate(new Array[Char](10))) - expect(40)(SizeEstimator.estimate(new Array[Short](10))) - expect(56)(SizeEstimator.estimate(new Array[Int](10))) - expect(96)(SizeEstimator.estimate(new Array[Long](10))) - expect(56)(SizeEstimator.estimate(new Array[Float](10))) - expect(96)(SizeEstimator.estimate(new Array[Double](10))) - expect(4016)(SizeEstimator.estimate(new Array[Int](1000))) - expect(8016)(SizeEstimator.estimate(new Array[Long](1000))) + assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) + assert(SizeEstimator.estimate(new Array[Char](10)) === 40) + assert(SizeEstimator.estimate(new Array[Short](10)) === 40) + assert(SizeEstimator.estimate(new Array[Int](10)) === 56) + assert(SizeEstimator.estimate(new Array[Long](10)) === 96) + assert(SizeEstimator.estimate(new Array[Float](10)) === 56) + assert(SizeEstimator.estimate(new Array[Double](10)) === 96) + assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) + assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) } test("object arrays") { // Arrays containing nulls should just have one pointer per element - expect(56)(SizeEstimator.estimate(new Array[String](10))) - expect(56)(SizeEstimator.estimate(new Array[AnyRef](10))) + assert(SizeEstimator.estimate(new Array[String](10)) === 56) + assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) // For object arrays with non-null elements, each object should take one pointer plus // however many bytes that class takes. (Note that Array.fill calls the code in its // second parameter separately for each object, so we get distinct objects.) - expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1))) - expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2))) - expect(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3))) - expect(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2))) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) + assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) // Past size 100, our samples 100 elements, but we should still get the right size. - expect(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) + assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - expect(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object - expect(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object + assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object + assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first @@ -111,10 +111,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - expect(40)(SizeEstimator.estimate(DummyString(""))) - expect(48)(SizeEstimator.estimate(DummyString("a"))) - expect(48)(SizeEstimator.estimate(DummyString("ab"))) - expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) resetOrClear("os.arch", arch) } @@ -128,10 +128,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - expect(56)(SizeEstimator.estimate(DummyString(""))) - expect(64)(SizeEstimator.estimate(DummyString("a"))) - expect(64)(SizeEstimator.estimate(DummyString("ab"))) - expect(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) + assert(SizeEstimator.estimate(DummyString("")) === 56) + assert(SizeEstimator.estimate(DummyString("a")) === 64) + assert(SizeEstimator.estimate(DummyString("ab")) === 64) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) -- cgit v1.2.3 From 15b00914c53f1f4f00a3313968f68a8f032e7cb7 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 17:17:27 -0400 Subject: Some fixes to the launch-java-directly change: - Split SPARK_JAVA_OPTS into multiple command-line arguments if it contains spaces; this splitting follows quoting rules in bash - Add the Scala JARs to the classpath if they're not in the CLASSPATH variable because the ExecutorRunner is launched with "scala" (this can happen when using local-cluster URLs in spark-shell) --- core/src/main/scala/spark/Utils.scala | 65 +++++++++++++++++++++- .../scala/spark/deploy/worker/ExecutorRunner.scala | 51 +++++++++++------ core/src/test/scala/spark/UtilsSuite.scala | 53 +++++++++++++----- 3 files changed, 138 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index f3621c6bee..bdc1494cc9 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -522,7 +522,7 @@ private object Utils extends Logging { execute(command, new File(".")) } - private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, val firstUserLine: Int, val firstUserClass: String) /** * When called inside a class in the spark package, returns the name of the user code class @@ -610,4 +610,67 @@ private object Utils extends Logging { } return false } + + def isSpace(c: Char): Boolean = { + " \t\r\n".indexOf(c) != -1 + } + + /** + * Split a string of potentially quoted arguments from the command line the way that a shell + * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', + * then it would be parsed as three arguments: 'a', 'b c' and 'd'. + */ + def splitCommandString(s: String): Seq[String] = { + val buf = new ArrayBuffer[String] + var inWord = false + var inSingleQuote = false + var inDoubleQuote = false + var curWord = new StringBuilder + def endWord() { + buf += curWord.toString + curWord.clear() + } + var i = 0 + while (i < s.length) { + var nextChar = s.charAt(i) + if (inDoubleQuote) { + if (nextChar == '"') { + inDoubleQuote = false + } else if (nextChar == '\\') { + if (i < s.length - 1) { + // Append the next character directly, because only " and \ may be escaped in + // double quotes after the shell's own expansion + curWord.append(s.charAt(i + 1)) + i += 1 + } + } else { + curWord.append(nextChar) + } + } else if (inSingleQuote) { + if (nextChar == '\'') { + inSingleQuote = false + } else { + curWord.append(nextChar) + } + // Backslashes are not treated specially in single quotes + } else if (nextChar == '"') { + inWord = true + inDoubleQuote = true + } else if (nextChar == '\'') { + inWord = true + inSingleQuote = true + } else if (!isSpace(nextChar)) { + curWord.append(nextChar) + inWord = true + } else if (inWord && isSpace(nextChar)) { + endWord() + inWord = false + } + i += 1 + } + if (inWord || inDoubleQuote || inSingleQuote) { + endWord() + } + return buf + } } diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 4d31657d9e..db580e39ab 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -40,7 +40,7 @@ private[spark] class ExecutorRunner( workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { + shutdownHook = new Thread() { override def run() { if (process != null) { logInfo("Shutdown hook killing child process.") @@ -87,25 +87,43 @@ private[spark] class ExecutorRunner( Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ command.arguments.map(substituteVariables) } - - /* - * Attention: this must always be aligned with the environment variables in the run scripts and the - * way the JAVA_OPTS are assembled there. + + /** + * Attention: this must always be aligned with the environment variables in the run scripts and + * the way the JAVA_OPTS are assembled there. */ def buildJavaOpts(): Seq[String] = { - val _javaLibPath = if (System.getenv("SPARK_LIBRARY_PATH") == null) { - "" + val libraryOpts = if (System.getenv("SPARK_LIBRARY_PATH") == null) { + Nil + } else { + List("-Djava.library.path=" + System.getenv("SPARK_LIBRARY_PATH")) + } + + val userOpts = if (System.getenv("SPARK_JAVA_OPTS") == null) { + Nil } else { - "-Djava.library.path=" + System.getenv("SPARK_LIBRARY_PATH") + Utils.splitCommandString(System.getenv("SPARK_JAVA_OPTS")) } - - Seq("-cp", - System.getenv("CLASSPATH"), - System.getenv("SPARK_JAVA_OPTS"), - _javaLibPath, - "-Xms" + memory.toString + "M", - "-Xmx" + memory.toString + "M") - .filter(_ != null) + + val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") + + var classPath = System.getenv("CLASSPATH") + if (System.getenv("SPARK_LAUNCH_WITH_SCALA") == "1") { + // Add the Scala library JARs to the classpath; this is needed when the ExecutorRunner + // was launched with "scala" as the runner (e.g. in spark-shell in local-cluster mode) + // and the Scala libraries won't be in the CLASSPATH environment variable by defalt. + if (System.getenv("SCALA_LIBRARY_PATH") == null && System.getenv("SCALA_HOME") == null) { + logError("Cloud not launch executors: neither SCALA_LIBRARY_PATH nor SCALA_HOME are set") + System.exit(1) + } + val scalaLib = Option(System.getenv("SCALA_LIBRARY_PATH")).getOrElse( + System.getenv("SCALA_HOME") + "/lib") + classPath += ":" + scalaLib + "/scala-library.jar" + + ":" + scalaLib + "/scala-compiler.jar" + + ":" + scalaLib + "/jline.jar" + } + + Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ @@ -136,6 +154,7 @@ private[spark] class ExecutorRunner( // Launch the process val command = buildCommandSeq() + println("COMMAND: " + command.mkString(" ")) val builder = new ProcessBuilder(command: _*).directory(executorDir) val env = builder.environment() for ((key, value) <- appDesc.command.environment) { diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala index ed4701574f..4a113e16bf 100644 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ b/core/src/test/scala/spark/UtilsSuite.scala @@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite { assert(os.toByteArray.toList.equals(bytes.toList)) } - test("memoryStringToMb"){ - assert(Utils.memoryStringToMb("1") == 0) - assert(Utils.memoryStringToMb("1048575") == 0) - assert(Utils.memoryStringToMb("3145728") == 3) + test("memoryStringToMb") { + assert(Utils.memoryStringToMb("1") === 0) + assert(Utils.memoryStringToMb("1048575") === 0) + assert(Utils.memoryStringToMb("3145728") === 3) - assert(Utils.memoryStringToMb("1024k") == 1) - assert(Utils.memoryStringToMb("5000k") == 4) - assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K")) + assert(Utils.memoryStringToMb("1024k") === 1) + assert(Utils.memoryStringToMb("5000k") === 4) + assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) - assert(Utils.memoryStringToMb("1024m") == 1024) - assert(Utils.memoryStringToMb("5000m") == 5000) - assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M")) + assert(Utils.memoryStringToMb("1024m") === 1024) + assert(Utils.memoryStringToMb("5000m") === 5000) + assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) - assert(Utils.memoryStringToMb("2g") == 2048) - assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G")) + assert(Utils.memoryStringToMb("2g") === 2048) + assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) - assert(Utils.memoryStringToMb("2t") == 2097152) - assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T")) + assert(Utils.memoryStringToMb("2t") === 2097152) + assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) + } + + test("splitCommandString") { + assert(Utils.splitCommandString("") === Seq()) + assert(Utils.splitCommandString("a") === Seq("a")) + assert(Utils.splitCommandString("aaa") === Seq("aaa")) + assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) + assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) + assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) + assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("'b c'") === Seq("b c")) + assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) + assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) + assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) + assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) + assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) + assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) + assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) + assert(Utils.splitCommandString("'a'b") === Seq("ab")) + assert(Utils.splitCommandString("'a''b'") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) + assert(Utils.splitCommandString("''") === Seq("")) + assert(Utils.splitCommandString("\"\"") === Seq("")) } } -- cgit v1.2.3 From 366572edcab87701fd795ca0142ac9829b312d36 Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Tue, 25 Jun 2013 22:59:34 +0100 Subject: Include a default OAuth implementation, and update examples and JavaStreamingContext --- .../streaming/examples/TwitterAlgebirdCMS.scala | 2 +- .../streaming/examples/TwitterAlgebirdHLL.scala | 2 +- .../streaming/examples/TwitterPopularTags.scala | 2 +- .../scala/spark/streaming/StreamingContext.scala | 2 +- .../streaming/api/java/JavaStreamingContext.scala | 69 +++++++++++++++------- .../streaming/dstream/TwitterInputDStream.scala | 32 +++++++++- 6 files changed, 81 insertions(+), 28 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala index a9642100e3..548190309e 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -45,7 +45,7 @@ object TwitterAlgebirdCMS { val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala index f3288bfb85..5a86c6318d 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -34,7 +34,7 @@ object TwitterAlgebirdHLL { val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) val users = stream.map(status => status.getUser.getId) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala index 9d4494c6f2..076c3878c8 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala @@ -23,7 +23,7 @@ object TwitterPopularTags { val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(username, password, filters) + val stream = ssc.twitterStream(None, filters) val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 05be6bd58a..0f36504c0d 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -368,7 +368,7 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects */ def twitterStream( - twitterAuth: Authorization, + twitterAuth: Option[Authorization] = None, filters: Seq[String] = Nil, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[Status] = { diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 3d149a742c..85390ef57e 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -4,23 +4,18 @@ import spark.streaming._ import receivers.{ActorReceiver, ReceiverSupervisorStrategy} import spark.streaming.dstream._ import spark.storage.StorageLevel - import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import spark.api.java.{JavaSparkContext, JavaRDD} - import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} - import twitter4j.Status - import akka.actor.Props import akka.actor.SupervisorStrategy import akka.zeromq.Subscribe - import scala.collection.JavaConversions._ - import java.lang.{Long => JLong, Integer => JInt} import java.io.InputStream import java.util.{Map => JMap} +import twitter4j.auth.Authorization /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -315,46 +310,78 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password + * @param twitterAuth Twitter4J Authorization * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ def twitterStream( - username: String, - password: String, + twitterAuth: Authorization, filters: Array[String], storageLevel: StorageLevel ): JavaDStream[Status] = { - ssc.twitterStream(username, password, filters, storageLevel) + ssc.twitterStream(Some(twitterAuth), filters, storageLevel) + } + + /** + * Create a input stream that returns tweets received from Twitter using + * java.util.Preferences to store OAuth token. OAuth key and secret should + * be provided using system properties twitter4j.oauth.consumerKey and + * twitter4j.oauth.consumerSecret + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + filters: Array[String], + storageLevel: StorageLevel + ): JavaDStream[Status] = { + ssc.twitterStream(None, filters, storageLevel) } /** * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password + * @param twitterAuth Twitter4J Authorization * @param filters Set of filter strings to get only those tweets that match them */ def twitterStream( - username: String, - password: String, + twitterAuth: Authorization, filters: Array[String] ): JavaDStream[Status] = { - ssc.twitterStream(username, password, filters) + ssc.twitterStream(Some(twitterAuth), filters) + } + + /** + * Create a input stream that returns tweets received from Twitter using + * java.util.Preferences to store OAuth token. OAuth key and secret should + * be provided using system properties twitter4j.oauth.consumerKey and + * twitter4j.oauth.consumerSecret + * @param filters Set of filter strings to get only those tweets that match them + */ + def twitterStream( + filters: Array[String] + ): JavaDStream[Status] = { + ssc.twitterStream(None, filters) } /** * Create a input stream that returns tweets received from Twitter. - * @param username Twitter username - * @param password Twitter password + * @param twitterAuth Twitter4J Authorization */ def twitterStream( - username: String, - password: String + twitterAuth: Authorization ): JavaDStream[Status] = { - ssc.twitterStream(username, password) + ssc.twitterStream(Some(twitterAuth)) } + /** + * Create a input stream that returns tweets received from Twitter using + * java.util.Preferences to store OAuth token. OAuth key and secret should + * be provided using system properties twitter4j.oauth.consumerKey and + * twitter4j.oauth.consumerSecret + */ + def twitterStream(): JavaDStream[Status] = { + ssc.twitterStream() + } + /** * Create an input stream with any arbitrary user implemented actor receiver. * @param props Props object defining creation of the actor diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index 0b01091a52..e0c654d385 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -3,27 +3,53 @@ package spark.streaming.dstream import spark._ import spark.streaming._ import storage.StorageLevel - import twitter4j._ import twitter4j.auth.BasicAuthorization import twitter4j.auth.Authorization +import java.util.prefs.Preferences +import twitter4j.conf.PropertyConfiguration +import twitter4j.auth.OAuthAuthorization +import twitter4j.auth.AccessToken /* A stream of Twitter statuses, potentially filtered by one or more keywords. * * @constructor create a new Twitter stream using the supplied username and password to authenticate. * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. +* +* Includes a simple implementation of OAuth using consumer key and secret provided using system +* properties twitter4j.oauth.consumerKey and twitter4j.oauth.consumerSecret */ private[streaming] class TwitterInputDStream( @transient ssc_ : StreamingContext, - twitterAuth: Authorization, + twitterAuth: Option[Authorization], filters: Seq[String], storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { + lazy val createOAuthAuthorization: Authorization = { + val userRoot = Preferences.userRoot(); + val token = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN, null)) + val tokenSecret = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, null)) + val oAuth = new OAuthAuthorization(new PropertyConfiguration(System.getProperties())) + if (token.isEmpty || tokenSecret.isEmpty) { + val requestToken = oAuth.getOAuthRequestToken() + println("Authorize application using URL: "+requestToken.getAuthorizationURL()) + println("Enter PIN: ") + val pin = Console.readLine + val accessToken = if (pin.length() > 0) oAuth.getOAuthAccessToken(requestToken, pin) else oAuth.getOAuthAccessToken() + userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN, accessToken.getToken()) + userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, accessToken.getTokenSecret()) + userRoot.flush() + } else { + oAuth.setOAuthAccessToken(new AccessToken(token.get, tokenSecret.get)); + } + oAuth + } + override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(twitterAuth, filters, storageLevel) + new TwitterReceiver(if (twitterAuth.isEmpty) createOAuthAuthorization else twitterAuth.get, filters, storageLevel) } } -- cgit v1.2.3 From 176193b1e8acdbe2f1cfaed16b8f42f89e226f79 Mon Sep 17 00:00:00 2001 From: James Phillpotts Date: Tue, 25 Jun 2013 23:06:15 +0100 Subject: Fix usage and parameter extraction --- .../main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala | 7 +++---- .../main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala | 7 +++---- .../main/scala/spark/streaming/examples/TwitterPopularTags.scala | 7 +++---- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala index 548190309e..528778ed72 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -26,8 +26,8 @@ import spark.SparkContext._ */ object TwitterAlgebirdCMS { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: TwitterAlgebirdCMS " + + if (args.length < 1) { + System.err.println("Usage: TwitterAlgebirdCMS " + " [filter1] [filter2] ... [filter n]") System.exit(1) } @@ -40,8 +40,7 @@ object TwitterAlgebirdCMS { // K highest frequency elements to take val TOPK = 10 - val Array(master, username, password) = args.slice(0, 3) - val filters = args.slice(3, args.length) + val (master, filters) = (args.head, args.tail) val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala index 5a86c6318d..896e9fd8af 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -21,16 +21,15 @@ import spark.streaming.dstream.TwitterInputDStream */ object TwitterAlgebirdHLL { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: TwitterAlgebirdHLL " + + if (args.length < 1) { + System.err.println("Usage: TwitterAlgebirdHLL " + " [filter1] [filter2] ... [filter n]") System.exit(1) } /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ val BIT_SIZE = 12 - val Array(master, username, password) = args.slice(0, 3) - val filters = args.slice(3, args.length) + val (master, filters) = (args.head, args.tail) val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala index 076c3878c8..65f0b6d352 100644 --- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala @@ -12,14 +12,13 @@ import spark.SparkContext._ */ object TwitterPopularTags { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: TwitterPopularTags " + + if (args.length < 1) { + System.err.println("Usage: TwitterPopularTags " + " [filter1] [filter2] ... [filter n]") System.exit(1) } - val Array(master, username, password) = args.slice(0, 3) - val filters = args.slice(3, args.length) + val (master, filters) = (args.head, args.tail) val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) -- cgit v1.2.3 From 6c8d1b2ca618a1a17566ede46821c0807a1b11f5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 18:21:00 -0400 Subject: Fix computation of classpath when we launch java directly The previous version assumed that a CLASSPATH environment variable was set by the "run" script when launching the process that starts the ExecutorRunner, but unfortunately this is not true in tests. Instead, we factor the classpath calculation into an extenral script and call that. NOTE: This includes a Windows version but hasn't yet been tested there. --- bin/compute-classpath.cmd | 52 +++++++++++++ bin/compute-classpath.sh | 89 ++++++++++++++++++++++ core/src/main/scala/spark/Utils.scala | 31 ++++++++ .../scala/spark/deploy/worker/ExecutorRunner.scala | 19 +---- run | 67 +++------------- run2.cmd | 38 +-------- 6 files changed, 189 insertions(+), 107 deletions(-) create mode 100644 bin/compute-classpath.cmd create mode 100755 bin/compute-classpath.sh diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd new file mode 100644 index 0000000000..1dff8fea22 --- /dev/null +++ b/bin/compute-classpath.cmd @@ -0,0 +1,52 @@ +@echo off + +rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" +rem script and the ExecutorRunner in standalone cluster mode. + +set SCALA_VERSION=2.9.3 + +rem Figure out where the Spark framework is installed +set FWDIR=%~dp0\.. + +rem Load environment variables from conf\spark-env.cmd, if it exists +if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" + +set CORE_DIR=%FWDIR%core +set REPL_DIR=%FWDIR%repl +set EXAMPLES_DIR=%FWDIR%examples +set BAGEL_DIR=%FWDIR%bagel +set STREAMING_DIR=%FWDIR%streaming +set PYSPARK_DIR=%FWDIR%python + +rem Build up classpath +set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes +set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources +set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes +set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\* +set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes +set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\* +set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\* +set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\* +set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\* +set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes + +rem Add hadoop conf dir - else FileSystem.*, etc fail +rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +rem the configurtion files. +if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir + set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% +:no_hadoop_conf_dir + +if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir + set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% +:no_yarn_conf_dir + +rem Add Scala standard library +set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar + +rem A bit of a hack to allow calling this script within run2.cmd without seeing output +if "x%DONT_PRINT_CLASSPATH%"=="x1" goto exit + +echo %CLASSPATH% + +:exit diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh new file mode 100755 index 0000000000..3a78880290 --- /dev/null +++ b/bin/compute-classpath.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# This script computes Spark's classpath and prints it to stdout; it's used by both the "run" +# script and the ExecutorRunner in standalone cluster mode. + +SCALA_VERSION=2.9.3 + +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +CORE_DIR="$FWDIR/core" +REPL_DIR="$FWDIR/repl" +REPL_BIN_DIR="$FWDIR/repl-bin" +EXAMPLES_DIR="$FWDIR/examples" +BAGEL_DIR="$FWDIR/bagel" +STREAMING_DIR="$FWDIR/streaming" +PYSPARK_DIR="$FWDIR/python" + +# Build up classpath +CLASSPATH="$SPARK_CLASSPATH" +CLASSPATH="$CLASSPATH:$FWDIR/conf" +CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes" +if [ -n "$SPARK_TESTING" ] ; then + CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes" +fi +CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources" +CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar +if [ -e "$FWDIR/lib_managed" ]; then + CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*" + CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" +fi +CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" +# Add the shaded JAR for Maven builds +if [ -e $REPL_BIN_DIR/target ]; then + for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do + CLASSPATH="$CLASSPATH:$jar" + done + # The shaded JAR doesn't contain examples, so include those separately + EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` + CLASSPATH+=":$EXAMPLES_JAR" +fi +CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" +for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do + CLASSPATH="$CLASSPATH:$jar" +done + +# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack +# to avoid the -sources and -doc packages that are built by publish-local. +if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; then + # Use the JAR from the SBT build + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar` +fi +if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then + # Use the JAR from the Maven build + export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` +fi + +# Add hadoop conf dir - else FileSystem.*, etc fail ! +# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +# the configurtion files. +if [ "x" != "x$HADOOP_CONF_DIR" ]; then + CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR" +fi +if [ "x" != "x$YARN_CONF_DIR" ]; then + CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" +fi + +# Add Scala standard library +if [ -z "$SCALA_LIBRARY_PATH" ]; then + if [ -z "$SCALA_HOME" ]; then + echo "SCALA_HOME is not set" >&2 + exit 1 + fi + SCALA_LIBRARY_PATH="$SCALA_HOME/lib" +fi +CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar" +CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar" +CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar" + +echo "$CLASSPATH" diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index bdc1494cc9..f41efa9d29 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -522,6 +522,37 @@ private object Utils extends Logging { execute(command, new File(".")) } + /** + * Execute a command and get its output, throwing an exception if it yields a code other than 0. + */ + def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = { + val process = new ProcessBuilder(command: _*) + .directory(workingDir) + .start() + new Thread("read stderr for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + val output = new StringBuffer + val stdoutThread = new Thread("read stdout for " + command(0)) { + override def run() { + for (line <- Source.fromInputStream(process.getInputStream).getLines) { + output.append(line) + } + } + } + stdoutThread.start() + val exitCode = process.waitFor() + stdoutThread.join() // Wait for it to finish reading output + if (exitCode != 0) { + throw new SparkException("Process " + command + " exited with code " + exitCode) + } + output.toString + } + private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, val firstUserLine: Int, val firstUserClass: String) /** diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index db580e39ab..4f8e1dcb26 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -107,21 +107,9 @@ private[spark] class ExecutorRunner( val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") - var classPath = System.getenv("CLASSPATH") - if (System.getenv("SPARK_LAUNCH_WITH_SCALA") == "1") { - // Add the Scala library JARs to the classpath; this is needed when the ExecutorRunner - // was launched with "scala" as the runner (e.g. in spark-shell in local-cluster mode) - // and the Scala libraries won't be in the CLASSPATH environment variable by defalt. - if (System.getenv("SCALA_LIBRARY_PATH") == null && System.getenv("SCALA_HOME") == null) { - logError("Cloud not launch executors: neither SCALA_LIBRARY_PATH nor SCALA_HOME are set") - System.exit(1) - } - val scalaLib = Option(System.getenv("SCALA_LIBRARY_PATH")).getOrElse( - System.getenv("SCALA_HOME") + "/lib") - classPath += ":" + scalaLib + "/scala-library.jar" + - ":" + scalaLib + "/scala-compiler.jar" + - ":" + scalaLib + "/jline.jar" - } + // Figure out our classpath with the external compute-classpath script + val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" + val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext)) Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts } @@ -154,7 +142,6 @@ private[spark] class ExecutorRunner( // Launch the process val command = buildCommandSeq() - println("COMMAND: " + command.mkString(" ")) val builder = new ProcessBuilder(command: _*).directory(executorDir) val env = builder.environment() for ((key, value) <- appDesc.command.environment) { diff --git a/run b/run index 0fb15f8b24..7c06a55062 100755 --- a/run +++ b/run @@ -49,6 +49,12 @@ case "$1" in ;; esac +# Figure out whether to run our class with java or with the scala launcher. +# In most cases, we'd prefer to execute our process with java because scala +# creates a shell script as the parent of its Java process, which makes it +# hard to kill the child with stuff like Process.destroy(). However, for +# the Spark shell, the wrapper is necessary to properly reset the terminal +# when we exit, so we allow it to set a variable to launch with scala. if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then if [ "$SCALA_HOME" ]; then RUNNER="${SCALA_HOME}/bin/scala" @@ -98,12 +104,8 @@ export JAVA_OPTS # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! CORE_DIR="$FWDIR/core" -REPL_DIR="$FWDIR/repl" -REPL_BIN_DIR="$FWDIR/repl-bin" EXAMPLES_DIR="$FWDIR/examples" -BAGEL_DIR="$FWDIR/bagel" -STREAMING_DIR="$FWDIR/streaming" -PYSPARK_DIR="$FWDIR/python" +REPL_DIR="$FWDIR/repl" # Exit if the user hasn't compiled Spark if [ ! -e "$CORE_DIR/target" ]; then @@ -118,37 +120,9 @@ if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then exit 1 fi -# Build up classpath -CLASSPATH="$SPARK_CLASSPATH" -CLASSPATH="$CLASSPATH:$FWDIR/conf" -CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes" -if [ -n "$SPARK_TESTING" ] ; then - CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes" -fi -CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources" -CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes" -CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" -CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" -CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar -if [ -e "$FWDIR/lib_managed" ]; then - CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*" -fi -CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*" -# Add the shaded JAR for Maven builds -if [ -e $REPL_BIN_DIR/target ]; then - for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do - CLASSPATH="$CLASSPATH:$jar" - done - # The shaded JAR doesn't contain examples, so include those separately - EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` - CLASSPATH+=":$EXAMPLES_JAR" -fi -CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" -for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do - CLASSPATH="$CLASSPATH:$jar" -done +# Compute classpath using external script +CLASSPATH=`$FWDIR/bin/compute-classpath.sh` +export CLASSPATH # Figure out the JAR file that our examples were packaged into. This includes a bit of a hack # to avoid the -sources and -doc packages that are built by publish-local. @@ -161,32 +135,11 @@ if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar` fi -# Add hadoop conf dir - else FileSystem.*, etc fail ! -# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -# the configurtion files. -if [ "x" != "x$HADOOP_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR" -fi -if [ "x" != "x$YARN_CONF_DIR" ]; then - CLASSPATH="$CLASSPATH:$YARN_CONF_DIR" -fi - - -# Figure out whether to run our class with java or with the scala launcher. -# In most cases, we'd prefer to execute our process with java because scala -# creates a shell script as the parent of its Java process, which makes it -# hard to kill the child with stuff like Process.destroy(). However, for -# the Spark shell, the wrapper is necessary to properly reset the terminal -# when we exit, so we allow it to set a variable to launch with scala. if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS else - CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar" - CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar" - CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar" # The JVM doesn't read JAVA_OPTS by default so we need to pass it in EXTRA_ARGS="$JAVA_OPTS" fi -export CLASSPATH # Needed for spark-shell exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@" diff --git a/run2.cmd b/run2.cmd index bf76844d11..25e4f3b57c 100644 --- a/run2.cmd +++ b/run2.cmd @@ -33,51 +33,21 @@ if not "x%SCALA_HOME%"=="x" goto scala_exists goto exit :scala_exists -rem If the user specifies a Mesos JAR, put it before our included one on the classpath -set MESOS_CLASSPATH= -if not "x%MESOS_JAR%"=="x" set MESOS_CLASSPATH=%MESOS_JAR% - rem Figure out how much memory to use per executor and set it as an environment rem variable so that our process sees it and can report it to Mesos if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m rem Set JAVA_OPTS to be able to load native libraries and to set heap size set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM% -rem Load extra JAVA_OPTS from conf/java-opts, if it exists -if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd" rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! set CORE_DIR=%FWDIR%core -set REPL_DIR=%FWDIR%repl set EXAMPLES_DIR=%FWDIR%examples -set BAGEL_DIR=%FWDIR%bagel -set STREAMING_DIR=%FWDIR%streaming -set PYSPARK_DIR=%FWDIR%python - -rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes -set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources -set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes -set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\* -set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes -set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\* -set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\* -set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\* -set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\* -set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes - -rem Add hadoop conf dir - else FileSystem.*, etc fail -rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -rem the configurtion files. -if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir - set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% -:no_hadoop_conf_dir - -if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir - set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% -:no_yarn_conf_dir - +set REPL_DIR=%FWDIR%repl +rem Compute classpath using external script +set DONT_PRINT_CLASSPATH=1 +call "%FWDIR%bin\compute-classpath.cmd" rem Figure out the JAR file that our examples were packaged into. rem First search in the build path from SBT: -- cgit v1.2.3 From f2263350eda780aba45f383b722e20702c310e6a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 18:35:35 -0400 Subject: Added a local-cluster mode test to ReplSuite --- repl/src/test/scala/spark/repl/ReplSuite.scala | 31 +++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index 1c64f9b98d..72ed8aca5b 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -35,17 +35,17 @@ class ReplSuite extends FunSuite { System.clearProperty("spark.hostPort") return out.toString } - + def assertContains(message: String, output: String) { assert(output contains message, "Interpreter output did not contain '" + message + "':\n" + output) } - + def assertDoesNotContain(message: String, output: String) { assert(!(output contains message), "Interpreter output contained '" + message + "':\n" + output) } - + test ("simple foreach with accumulator") { val output = runInterpreter("local", """ val accum = sc.accumulator(0) @@ -56,7 +56,7 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) assertContains("res1: Int = 55", output) } - + test ("external vars") { val output = runInterpreter("local", """ var v = 7 @@ -105,7 +105,7 @@ class ReplSuite extends FunSuite { assertContains("res0: Int = 70", output) assertContains("res1: Int = 100", output) } - + test ("broadcast vars") { // Test that the value that a broadcast var had when it was created is used, // even if that variable is then modified in the driver program @@ -143,6 +143,27 @@ class ReplSuite extends FunSuite { assertContains("res2: Long = 3", output) } + test ("local-cluster mode") { + val output = runInterpreter("local-cluster[1,1,512]", """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + var array = new Array[Int](5) + val broadcastArray = sc.broadcast(array) + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + array(0) = 5 + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { test ("running on Mesos") { val output = runInterpreter("localquiet", """ -- cgit v1.2.3 From 2bd04c3513ffa6deabc290a3931be946b1c18713 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 18:37:14 -0400 Subject: Formatting --- repl/src/test/scala/spark/repl/ReplSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index 72ed8aca5b..f46e6d8be4 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -28,8 +28,9 @@ class ReplSuite extends FunSuite { val separator = System.getProperty("path.separator") interp.process(Array("-classpath", paths.mkString(separator))) spark.repl.Main.interp = null - if (interp.sparkContext != null) + if (interp.sparkContext != null) { interp.sparkContext.stop() + } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") System.clearProperty("spark.hostPort") @@ -37,12 +38,12 @@ class ReplSuite extends FunSuite { } def assertContains(message: String, output: String) { - assert(output contains message, + assert(output.contains(message), "Interpreter output did not contain '" + message + "':\n" + output) } def assertDoesNotContain(message: String, output: String) { - assert(!(output contains message), + assert(!output.contains(message), "Interpreter output contained '" + message + "':\n" + output) } -- cgit v1.2.3 From 9f0d91329516c829c86fc8e95d02071ca7d1f186 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 19:18:30 -0400 Subject: Refactored tests to share SparkContexts in some of them Creating these seems to take a while and clutters the output with Akka stuff, so it would be nice to share them. --- core/src/test/scala/spark/CheckpointSuite.scala | 10 + .../test/scala/spark/PairRDDFunctionsSuite.scala | 287 ++++++++++++++++++++ core/src/test/scala/spark/PartitioningSuite.scala | 19 +- core/src/test/scala/spark/PipedRDDSuite.scala | 16 +- core/src/test/scala/spark/RDDSuite.scala | 87 +----- core/src/test/scala/spark/SharedSparkContext.scala | 25 ++ core/src/test/scala/spark/ShuffleSuite.scala | 298 +-------------------- core/src/test/scala/spark/SortingSuite.scala | 23 +- core/src/test/scala/spark/UnpersistSuite.scala | 30 +++ .../test/scala/spark/ZippedPartitionsSuite.scala | 3 +- 10 files changed, 373 insertions(+), 425 deletions(-) create mode 100644 core/src/test/scala/spark/PairRDDFunctionsSuite.scala create mode 100644 core/src/test/scala/spark/SharedSparkContext.scala create mode 100644 core/src/test/scala/spark/UnpersistSuite.scala diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index ca385972fb..28a7b21b92 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -27,6 +27,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } } + test("basic checkpointing") { + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + } + test("RDDs with one-to-one dependencies") { testCheckpointing(_.map(x => x.toString)) testCheckpointing(_.flatMap(x => 1 to x)) diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala new file mode 100644 index 0000000000..682d2745bf --- /dev/null +++ b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala @@ -0,0 +1,287 @@ +package spark + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import com.google.common.io.Files + +import spark.rdd.ShuffledRDD +import spark.SparkContext._ + +class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("groupByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with duplicates") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with negative key hash codes") { + val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesForMinus1 = groups.find(_._1 == -1).get._2 + assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with many output partitions") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey(10).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("reduceByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with collectAsMap") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collectAsMap() + assert(sums.size === 2) + assert(sums(1) === 7) + assert(sums(2) === 1) + } + + test("reduceByKey with many output partitons") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_, 10).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with partitioner") { + val p = new Partitioner() { + def numPartitions = 2 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) + assert(sums.collect().toSet === Set((1, 4), (0, 1))) + assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection + } + + test("join") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("join all-to-all") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) + val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (1, 'y')), + (1, (2, 'x')), + (1, (2, 'y')), + (1, (3, 'x')), + (1, (3, 'y')) + )) + } + + test("leftOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.leftOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (1, Some('x'))), + (1, (2, Some('x'))), + (2, (1, Some('y'))), + (2, (1, Some('z'))), + (3, (1, None)) + )) + } + + test("rightOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.rightOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (Some(1), 'x')), + (1, (Some(2), 'x')), + (2, (Some(1), 'y')), + (2, (Some(1), 'z')), + (4, (None, 'w')) + )) + } + + test("join with no matches") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 0) + } + + test("join with many output partitions") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2, 10).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("groupWith") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.groupWith(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), + (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), + (3, (ArrayBuffer(1), ArrayBuffer())), + (4, (ArrayBuffer(), ArrayBuffer('w'))) + )) + } + + test("zero-partition RDD") { + val emptyDir = Files.createTempDir() + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.size == 0) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } + + test("keys and values") { + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } + + test("default partitioner uses partition size") { + // specify 2000 partitions + val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + // do a map, which loses the partitioner + val b = a.map(a => (a, (a * 2).toString)) + // then a group by, and see we didn't revert to 2 partitions + val c = b.groupByKey() + assert(c.partitions.size === 2000) + } + + test("default partitioner uses largest partitioner") { + val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) + val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) + val c = a.join(b) + assert(c.partitions.size === 2000) + } + + test("subtract") { + val a = sc.parallelize(Array(1, 2, 3), 2) + val b = sc.parallelize(Array(2, 3, 4), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set(1)) + assert(c.partitions.size === a.partitions.size) + } + + test("subtract with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set((1, "a"), (3, "c"))) + // Ideally we could keep the original partitioner... + assert(c.partitioner === None) + } + + test("subtractByKey") { + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) + val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitions.size === a.partitions.size) + } + + test("subtractByKey with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitioner.get === p) + } + + test("foldByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } +} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index 16f93e71a3..99e433e3bd 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -6,8 +6,8 @@ import SparkContext._ import spark.util.StatCounter import scala.math.abs -class PartitioningSuite extends FunSuite with LocalSparkContext { - +class PartitioningSuite extends FunSuite with SharedSparkContext { + test("HashPartitioner equality") { val p2 = new HashPartitioner(2) val p4 = new HashPartitioner(4) @@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("RangePartitioner equality") { - sc = new SparkContext("local", "test") - // Make an RDD where all the elements are the same so that the partition range bounds // are deterministically all the same. val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x)) @@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("HashPartitioner not equal to RangePartitioner") { - sc = new SparkContext("local", "test") val rdd = sc.parallelize(1 to 10).map(x => (x, x)) val rangeP2 = new RangePartitioner(2, rdd) val hashP2 = new HashPartitioner(2) @@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("partitioner preservation") { - sc = new SparkContext("local", "test") - val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) val grouped2 = rdd.groupByKey(2) @@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { } test("partitioning Java arrays should fail") { - sc = new SparkContext("local", "test") val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) val arrPairs: RDD[(Array[Int], Int)] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) @@ -120,21 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) } - - test("Zero-length partitions should be correctly handled") { + + test("zero-length partitions should be correctly handled") { // Create RDD with some consecutive empty partitions (including the "first" one) - sc = new SparkContext("local", "test") val rdd: RDD[Double] = sc .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) .filter(_ >= 0.0) - + // Run the partitions, including the consecutive empty ones, through StatCounter val stats: StatCounter = rdd.stats(); assert(abs(6.0 - stats.sum) < 0.01); assert(abs(6.0/2 - rdd.mean) < 0.01); assert(abs(1.0 - rdd.variance) < 0.01); assert(abs(1.0 - rdd.stdev) < 0.01); - + // Add other tests here for classes that should be able to handle empty partitions correctly } } diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index ed075f93ec..1c9ca50811 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -3,10 +3,9 @@ package spark import org.scalatest.FunSuite import SparkContext._ -class PipedRDDSuite extends FunSuite with LocalSparkContext { - +class PipedRDDSuite extends FunSuite with SharedSparkContext { + test("basic pipe") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe(Seq("cat")) @@ -20,12 +19,11 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { } test("advanced pipe") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val bl = sc.broadcast(List("0")) - val piped = nums.pipe(Seq("cat"), - Map[String, String](), + val piped = nums.pipe(Seq("cat"), + Map[String, String](), (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, (i:Int, f: String=> Unit) => f(i + "_")) @@ -43,8 +41,8 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) val d = nums1.groupBy(str=>str.split("\t")(0)). - pipe(Seq("cat"), - Map[String, String](), + pipe(Seq("cat"), + Map[String, String](), (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() assert(d.size === 8) @@ -59,7 +57,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { } test("pipe with env variable") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) val c = piped.collect() @@ -69,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext { } test("pipe with non-zero exit status") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val piped = nums.pipe("cat nonexistent_file") intercept[SparkException] { diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 67f3332d44..d8db69b1c9 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -7,10 +7,9 @@ import org.scalatest.time.{Span, Millis} import spark.SparkContext._ import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD} -class RDDSuite extends FunSuite with LocalSparkContext { +class RDDSuite extends FunSuite with SharedSparkContext { test("basic operations") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) @@ -46,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("SparkContext.union") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) @@ -55,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("aggregate") { - sc = new SparkContext("local", "test") val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] val emptyMap = new StringMap { @@ -75,57 +72,14 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } - test("basic checkpointing") { - import java.io.File - val checkpointDir = File.createTempFile("temp", "") - checkpointDir.delete() - - sc = new SparkContext("local", "test") - sc.setCheckpointDir(checkpointDir.toString) - val parCollection = sc.makeRDD(1 to 4) - val flatMappedRDD = parCollection.flatMap(x => 1 to x) - flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) - val result = flatMappedRDD.collect() - Thread.sleep(1000) - assert(flatMappedRDD.dependencies.head.rdd != parCollection) - assert(flatMappedRDD.collect() === result) - - checkpointDir.deleteOnExit() - } - test("basic caching") { - sc = new SparkContext("local", "test") val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) assert(rdd.collect().toList === List(1, 2, 3, 4)) assert(rdd.collect().toList === List(1, 2, 3, 4)) } - test("unpersist RDD") { - sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - rdd.count - assert(sc.persistentRdds.isEmpty === false) - rdd.unpersist() - assert(sc.persistentRdds.isEmpty === true) - - failAfter(Span(3000, Millis)) { - try { - while (! sc.getRDDStorageInfo.isEmpty) { - Thread.sleep(200) - } - } catch { - case _ => { Thread.sleep(10) } - // Do nothing. We might see exceptions because block manager - // is racing this thread to remove entries from the driver. - } - } - assert(sc.getRDDStorageInfo.isEmpty === true) - } - test("caching with failures") { - sc = new SparkContext("local", "test") val onlySplit = new Partition { override def index: Int = 0 } var shouldFail = true val rdd = new RDD[Int](sc, Nil) { @@ -148,7 +102,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("empty RDD") { - sc = new SparkContext("local", "test") val empty = new EmptyRDD[Int](sc) assert(empty.count === 0) assert(empty.collect().size === 0) @@ -168,37 +121,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("cogrouped RDDs") { - sc = new SparkContext("local", "test") - val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2) - val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2) - - // Use cogroup function - val cogrouped = rdd1.cogroup(rdd2).collectAsMap() - assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1"))) - assert(cogrouped(2) === (Seq("two"), Seq("two1"))) - assert(cogrouped(3) === (Seq("three"), Seq())) - - // Construct CoGroupedRDD directly, with map side combine enabled - val cogrouped1 = new CoGroupedRDD[Int]( - Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]), - new HashPartitioner(3), - true).collectAsMap() - assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1"))) - assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1"))) - assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq())) - - // Construct CoGroupedRDD directly, with map side combine disabled - val cogrouped2 = new CoGroupedRDD[Int]( - Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]), - new HashPartitioner(3), - false).collectAsMap() - assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1"))) - assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1"))) - assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq())) - } - - test("coalesced RDDs") { - sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) val coalesced1 = data.coalesce(2) @@ -236,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("zipped RDDs") { - sc = new SparkContext("local", "test") val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val zipped = nums.zip(nums.map(_ + 1.0)) assert(zipped.glom().map(_.toList).collect().toList === @@ -248,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("partition pruning") { - sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) @@ -260,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("mapWith") { import java.util.Random - sc = new SparkContext("local", "test") val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) val randoms = ones.mapWith( (index: Int) => new Random(index + 42)) @@ -279,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("flatMapWith") { import java.util.Random - sc = new SparkContext("local", "test") val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) val randoms = ones.flatMapWith( (index: Int) => new Random(index + 42)) @@ -301,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("filterWith") { import java.util.Random - sc = new SparkContext("local", "test") val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) val sample = ints.filterWith( (index: Int) => new Random(index + 42)) @@ -319,7 +236,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("top with predefined ordering") { - sc = new SparkContext("local", "test") val nums = Array.range(1, 100000) val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) val topK = ints.top(5) @@ -328,7 +244,6 @@ class RDDSuite extends FunSuite with LocalSparkContext { } test("top with custom ordering") { - sc = new SparkContext("local", "test") val words = Vector("a", "b", "c", "d") implicit val ord = implicitly[Ordering[String]].reverse val rdd = sc.makeRDD(words, 2) diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala new file mode 100644 index 0000000000..1da79f9824 --- /dev/null +++ b/core/src/test/scala/spark/SharedSparkContext.scala @@ -0,0 +1,25 @@ +package spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterAll + +/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ +trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => + + @transient private var _sc: SparkContext = _ + + def sc: SparkContext = _sc + + override def beforeAll() { + _sc = new SparkContext("local", "test") + super.beforeAll() + } + + override def afterAll() { + if (_sc != null) { + LocalSparkContext.stop(_sc) + _sc = null + } + super.afterAll() + } +} diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 0c1ec29f96..950218fa28 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD import spark.SparkContext._ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { - - test("groupByKey") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with duplicates") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with negative key hash codes") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesForMinus1 = groups.find(_._1 == -1).get._2 - assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with many output partitions") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey(10).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - test("groupByKey with compression") { try { - System.setProperty("spark.blockManager.compress", "true") + System.setProperty("spark.shuffle.compress", "true") sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) val groups = pairs.groupByKey(4).collect() @@ -77,234 +32,6 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { } } - test("reduceByKey") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with collectAsMap") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() - assert(sums.size === 2) - assert(sums(1) === 7) - assert(sums(2) === 1) - } - - test("reduceByKey with many output partitons") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with partitioner") { - sc = new SparkContext("local", "test") - val p = new Partitioner() { - def numPartitions = 2 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) - assert(sums.collect().toSet === Set((1, 4), (0, 1))) - assert(sums.partitioner === Some(p)) - // count the dependencies to make sure there is only 1 ShuffledRDD - val deps = new HashSet[RDD[_]]() - def visit(r: RDD[_]) { - for (dep <- r.dependencies) { - deps += dep.rdd - visit(dep.rdd) - } - } - visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection - } - - test("join") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("join all-to-all") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) - val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 6) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (1, 'y')), - (1, (2, 'x')), - (1, (2, 'y')), - (1, (3, 'x')), - (1, (3, 'y')) - )) - } - - test("leftOuterJoin") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.leftOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (1, Some('x'))), - (1, (2, Some('x'))), - (2, (1, Some('y'))), - (2, (1, Some('z'))), - (3, (1, None)) - )) - } - - test("rightOuterJoin") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.rightOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (Some(1), 'x')), - (1, (Some(2), 'x')), - (2, (Some(1), 'y')), - (2, (Some(1), 'z')), - (4, (None, 'w')) - )) - } - - test("join with no matches") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 0) - } - - test("join with many output partitions") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2, 10).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("groupWith") { - sc = new SparkContext("local", "test") - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.groupWith(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), - (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), - (3, (ArrayBuffer(1), ArrayBuffer())), - (4, (ArrayBuffer(), ArrayBuffer('w'))) - )) - } - - test("zero-partition RDD") { - sc = new SparkContext("local", "test") - val emptyDir = Files.createTempDir() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - } - - test("keys and values") { - sc = new SparkContext("local", "test") - val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) - assert(rdd.keys.collect().toList === List(1, 2)) - assert(rdd.values.collect().toList === List("a", "b")) - } - - test("default partitioner uses partition size") { - sc = new SparkContext("local", "test") - // specify 2000 partitions - val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) - // do a map, which loses the partitioner - val b = a.map(a => (a, (a * 2).toString)) - // then a group by, and see we didn't revert to 2 partitions - val c = b.groupByKey() - assert(c.partitions.size === 2000) - } - - test("default partitioner uses largest partitioner") { - sc = new SparkContext("local", "test") - val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) - val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) - val c = a.join(b) - assert(c.partitions.size === 2000) - } - - test("subtract") { - sc = new SparkContext("local", "test") - val a = sc.parallelize(Array(1, 2, 3), 2) - val b = sc.parallelize(Array(2, 3, 4), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set(1)) - assert(c.partitions.size === a.partitions.size) - } - - test("subtract with narrow dependency") { - sc = new SparkContext("local", "test") - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set((1, "a"), (3, "c"))) - // Ideally we could keep the original partitioner... - assert(c.partitioner === None) - } - - test("subtractByKey") { - sc = new SparkContext("local", "test") - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) - val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitions.size === a.partitions.size) - } - - test("subtractByKey with narrow dependency") { - sc = new SparkContext("local", "test") - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitioner.get === p) - } - test("shuffle non-zero block size") { sc = new SparkContext("local-cluster[2,1,512]", "test") val NUM_BLOCKS = 3 @@ -391,29 +118,6 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // We should have at most 4 non-zero sized partitions assert(nonEmptyBlocks.size <= 4) } - - test("foldByKey") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("foldByKey with mutable result type") { - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() - // Fold the values using in-place mutation - val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() - assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) - // Check that the mutable objects in the original RDD were not changed - assert(bufs.collect().toSet === Set( - (1, ArrayBuffer(1)), - (1, ArrayBuffer(2)), - (1, ArrayBuffer(3)), - (1, ArrayBuffer(1)), - (2, ArrayBuffer(1)))) - } } object ShuffleSuite { diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index 495f957e53..f7bf207c68 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging { - +class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { + test("sortByKey") { - sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) } test("large array") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("large array with one split") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w assert(sorted.partitions.size === 1) assert(sorted.collect() === pairArr.sortBy(_._1)) } - + test("large array with many partitions") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w assert(sorted.partitions.size === 20) assert(sorted.collect() === pairArr.sortBy(_._1)) } - + test("sort descending") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("sort descending with one split") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 1) assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } - + test("sort descending with many partitions") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) @@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("more partitions than elements") { - sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 30) @@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("empty RDD") { - sc = new SparkContext("local", "test") val pairArr = new Array[(Int, Int)](0) val pairs = sc.parallelize(pairArr, 2) assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) } test("partition balancing") { - sc = new SparkContext("local", "test") val pairArr = (1 to 1000).map(x => (x, x)).toArray val sorted = sc.parallelize(pairArr, 4).sortByKey() assert(sorted.collect() === pairArr.sortBy(_._1)) @@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w } test("partition balancing for descending sort") { - sc = new SparkContext("local", "test") val pairArr = (1 to 1000).map(x => (x, x)).toArray val sorted = sc.parallelize(pairArr, 4).sortByKey(false) assert(sorted.collect() === pairArr.sortBy(_._1).reverse) diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala new file mode 100644 index 0000000000..94776e7572 --- /dev/null +++ b/core/src/test/scala/spark/UnpersistSuite.scala @@ -0,0 +1,30 @@ +package spark + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import spark.SparkContext._ + +class UnpersistSuite extends FunSuite with LocalSparkContext { + test("unpersist RDD") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + rdd.count + assert(sc.persistentRdds.isEmpty === false) + rdd.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + assert(sc.getRDDStorageInfo.isEmpty === true) + } +} diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala index 5f60aa75d7..96cb295f45 100644 --- a/core/src/test/scala/spark/ZippedPartitionsSuite.scala +++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala @@ -17,9 +17,8 @@ object ZippedPartitionsSuite { } } -class ZippedPartitionsSuite extends FunSuite with LocalSparkContext { +class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { test("print sizes") { - sc = new SparkContext("local", "test") val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) val data3 = sc.makeRDD(Array(1.0, 2.0), 2) -- cgit v1.2.3 From 32370da4e40062b88c921417cd7418d804e16f23 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 25 Jun 2013 22:08:19 -0400 Subject: Don't use forward slash in exclusion for JAR signature files --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 484f97d992..07572201de 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -236,7 +236,7 @@ object SparkBuild extends Build { def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard - case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard + case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } -- cgit v1.2.3 From d11025dc6aedd9763cfd2390e8daf24747d17258 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 26 Jun 2013 09:53:35 -0500 Subject: Be cute with Option and getenv. --- .../scala/spark/deploy/worker/ExecutorRunner.scala | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 4f8e1dcb26..a9b12421e6 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -1,6 +1,7 @@ package spark.deploy.worker import java.io._ +import java.lang.System.getenv import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription} import akka.actor.ActorRef import spark.{Utils, Logging} @@ -77,11 +78,7 @@ private[spark] class ExecutorRunner( def buildCommandSeq(): Seq[String] = { val command = appDesc.command - val runner = if (System.getenv("JAVA_HOME") == null) { - "java" - } else { - System.getenv("JAVA_HOME") + "/bin/java" - } + val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java") // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ @@ -93,18 +90,8 @@ private[spark] class ExecutorRunner( * the way the JAVA_OPTS are assembled there. */ def buildJavaOpts(): Seq[String] = { - val libraryOpts = if (System.getenv("SPARK_LIBRARY_PATH") == null) { - Nil - } else { - List("-Djava.library.path=" + System.getenv("SPARK_LIBRARY_PATH")) - } - - val userOpts = if (System.getenv("SPARK_JAVA_OPTS") == null) { - Nil - } else { - Utils.splitCommandString(System.getenv("SPARK_JAVA_OPTS")) - } - + val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH")).map(p => List("-Djava.library.path=" + p)).getOrElse(Nil) + val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") // Figure out our classpath with the external compute-classpath script -- cgit v1.2.3 From d7011632d15b6de2129d360277e304fddb8f2aac Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 26 Jun 2013 12:35:57 -0500 Subject: Wrap lines. --- core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index a9b12421e6..d7f58b2cb1 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -90,7 +90,9 @@ private[spark] class ExecutorRunner( * the way the JAVA_OPTS are assembled there. */ def buildJavaOpts(): Seq[String] = { - val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH")).map(p => List("-Djava.library.path=" + p)).getOrElse(Nil) + val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH")) + .map(p => List("-Djava.library.path=" + p)) + .getOrElse(Nil) val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") -- cgit v1.2.3 From 03906f7f0a8c93b09e3b47ccaad5b5f72c29302b Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 26 Jun 2013 17:40:22 -0700 Subject: Fixes to compute-classpath on Windows --- bin/compute-classpath.cmd | 4 ++-- run2.cmd | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 1dff8fea22..6e7efbd334 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -6,7 +6,7 @@ rem script and the ExecutorRunner in standalone cluster mode. set SCALA_VERSION=2.9.3 rem Figure out where the Spark framework is installed -set FWDIR=%~dp0\.. +set FWDIR=%~dp0..\ rem Load environment variables from conf\spark-env.cmd, if it exists if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" @@ -45,7 +45,7 @@ rem Add Scala standard library set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar rem A bit of a hack to allow calling this script within run2.cmd without seeing output -if "x%DONT_PRINT_CLASSPATH%"=="x1" goto exit +if "%DONT_PRINT_CLASSPATH%"=="1" goto exit echo %CLASSPATH% diff --git a/run2.cmd b/run2.cmd index 25e4f3b57c..a9c4df180f 100644 --- a/run2.cmd +++ b/run2.cmd @@ -48,6 +48,7 @@ set REPL_DIR=%FWDIR%repl rem Compute classpath using external script set DONT_PRINT_CLASSPATH=1 call "%FWDIR%bin\compute-classpath.cmd" +set DONT_PRINT_CLASSPATH=0 rem Figure out the JAR file that our examples were packaged into. rem First search in the build path from SBT: -- cgit v1.2.3 From aea727f68d5fe5e81fc04ece97ad94c6f12c7270 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 26 Jun 2013 21:14:38 -0700 Subject: Simplify Python docs a little to do substring search --- docs/python-programming-guide.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 3a7a8db4a6..7f1e7cf93d 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -27,14 +27,14 @@ Short functions can be passed to RDD methods using Python's [`lambda`](http://ww {% highlight python %} logData = sc.textFile(logFile).cache() -errors = logData.filter(lambda s: 'ERROR' in s.split()) +errors = logData.filter(lambda line: "ERROR" in line) {% endhighlight %} You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`: {% highlight python %} def is_error(line): - return 'ERROR' in line.split() + return "ERROR" in line errors = logData.filter(is_error) {% endhighlight %} @@ -43,8 +43,7 @@ Functions can access objects in enclosing scopes, although modifications to thos {% highlight python %} error_keywords = ["Exception", "Error"] def is_error(line): - words = line.split() - return any(keyword in words for keyword in error_keywords) + return any(keyword in line for keyword in error_keywords) errors = logData.filter(is_error) {% endhighlight %} -- cgit v1.2.3 From c767e7437059aa35bca3e8bb93264b35853c7a8f Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 27 Jun 2013 21:48:58 -0700 Subject: Removing incorrect test statement --- core/src/test/scala/spark/scheduler/JobLoggerSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala index 4000c4d520..699901f1a1 100644 --- a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala @@ -41,7 +41,6 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID) joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4)) - joblogger.getEventQueue.size should be (1) joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) parentRdd.setName("MyRDD") joblogger.getRddNameTest(parentRdd) should be ("MyRDD") -- cgit v1.2.3 From 4974b658edf2716ff3c6f2e6863cddb2a4ddf891 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 27 Jun 2013 22:16:40 -0700 Subject: Look at JAVA_HOME before PATH to determine Java executable --- run | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/run b/run index 7c06a55062..805466ea2c 100755 --- a/run +++ b/run @@ -67,14 +67,15 @@ if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then fi fi else - if [ `command -v java` ]; then - RUNNER="java" + if [ -n "${JAVA_HOME}" ]; then + RUNNER="${JAVA_HOME}/bin/java" else - if [ -z "$JAVA_HOME" ]; then + if [ `command -v java` ]; then + RUNNER="java" + else echo "JAVA_HOME is not set" >&2 exit 1 fi - RUNNER="${JAVA_HOME}/bin/java" fi if [ -z "$SCALA_LIBRARY_PATH" ]; then if [ -z "$SCALA_HOME" ]; then -- cgit v1.2.3 From 4358acfe07e991090fbe009aafe3f5110fbf0c40 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 29 Jun 2013 15:25:06 -0700 Subject: Initialize Twitter4J OAuth from system properties instead of prompting --- .../scala/spark/streaming/StreamingContext.scala | 4 ++- .../streaming/api/java/JavaStreamingContext.scala | 23 +++++++--------- .../streaming/dstream/TwitterInputDStream.scala | 32 ++++++---------------- .../test/java/spark/streaming/JavaAPISuite.java | 2 +- 4 files changed, 23 insertions(+), 38 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index e61438fe3a..36b841af8f 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -381,7 +381,9 @@ class StreamingContext private ( /** * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J authentication + * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth + * authorization; this uses the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret. * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index c4a223b419..ed7b789d98 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -307,7 +307,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create a input stream that returns tweets received from Twitter. - * @param twitterAuth Twitter4J Authorization + * @param twitterAuth Twitter4J Authorization object * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ @@ -320,10 +320,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Create a input stream that returns tweets received from Twitter using - * java.util.Preferences to store OAuth token. OAuth key and secret should - * be provided using system properties twitter4j.oauth.consumerKey and - * twitter4j.oauth.consumerSecret + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. * @param filters Set of filter strings to get only those tweets that match them * @param storageLevel Storage level to use for storing the received objects */ @@ -347,10 +346,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Create a input stream that returns tweets received from Twitter using - * java.util.Preferences to store OAuth token. OAuth key and secret should - * be provided using system properties twitter4j.oauth.consumerKey and - * twitter4j.oauth.consumerSecret + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. * @param filters Set of filter strings to get only those tweets that match them */ def twitterStream( @@ -370,10 +368,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Create a input stream that returns tweets received from Twitter using - * java.util.Preferences to store OAuth token. OAuth key and secret should - * be provided using system properties twitter4j.oauth.consumerKey and - * twitter4j.oauth.consumerSecret + * Create a input stream that returns tweets received from Twitter using Twitter4J's default + * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, + * .consumerSecret, .accessToken and .accessTokenSecret to be set. */ def twitterStream(): JavaDStream[Status] = { ssc.twitterStream() diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index e0c654d385..ff7a58be45 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -4,21 +4,21 @@ import spark._ import spark.streaming._ import storage.StorageLevel import twitter4j._ -import twitter4j.auth.BasicAuthorization import twitter4j.auth.Authorization import java.util.prefs.Preferences +import twitter4j.conf.ConfigurationBuilder import twitter4j.conf.PropertyConfiguration import twitter4j.auth.OAuthAuthorization import twitter4j.auth.AccessToken /* A stream of Twitter statuses, potentially filtered by one or more keywords. * -* @constructor create a new Twitter stream using the supplied username and password to authenticate. +* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. * -* Includes a simple implementation of OAuth using consumer key and secret provided using system -* properties twitter4j.oauth.consumerKey and twitter4j.oauth.consumerSecret +* If no Authorization object is provided, initializes OAuth authorization using the system +* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. */ private[streaming] class TwitterInputDStream( @@ -28,28 +28,14 @@ class TwitterInputDStream( storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { - lazy val createOAuthAuthorization: Authorization = { - val userRoot = Preferences.userRoot(); - val token = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN, null)) - val tokenSecret = Option(userRoot.get(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, null)) - val oAuth = new OAuthAuthorization(new PropertyConfiguration(System.getProperties())) - if (token.isEmpty || tokenSecret.isEmpty) { - val requestToken = oAuth.getOAuthRequestToken() - println("Authorize application using URL: "+requestToken.getAuthorizationURL()) - println("Enter PIN: ") - val pin = Console.readLine - val accessToken = if (pin.length() > 0) oAuth.getOAuthAccessToken(requestToken, pin) else oAuth.getOAuthAccessToken() - userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN, accessToken.getToken()) - userRoot.put(PropertyConfiguration.OAUTH_ACCESS_TOKEN_SECRET, accessToken.getTokenSecret()) - userRoot.flush() - } else { - oAuth.setOAuthAccessToken(new AccessToken(token.get, tokenSecret.get)); - } - oAuth + private def createOAuthAuthorization(): Authorization = { + new OAuthAuthorization(new ConfigurationBuilder().build()) } + + private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) override def getReceiver(): NetworkReceiver[Status] = { - new TwitterReceiver(if (twitterAuth.isEmpty) createOAuthAuthorization else twitterAuth.get, filters, storageLevel) + new TwitterReceiver(authorization, filters, storageLevel) } } diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index e5fdbe1b7a..4cf10582a9 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -1267,7 +1267,7 @@ public class JavaAPISuite implements Serializable { @Test public void testTwitterStream() { String[] filters = new String[] { "good", "bad", "ugly" }; - JavaDStream test = ssc.twitterStream("username", "password", filters, StorageLevel.MEMORY_ONLY()); + JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY()); } @Test -- cgit v1.2.3 From 5cfcd3c336cc13e9fd448ae122216e4b583b77b4 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 29 Jun 2013 15:37:27 -0700 Subject: Remove Twitter4J specific repo since it's in Maven central --- pom.xml | 11 ----------- project/SparkBuild.scala | 3 +-- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/pom.xml b/pom.xml index 3bcb2a3f34..7a31be98b2 100644 --- a/pom.xml +++ b/pom.xml @@ -109,17 +109,6 @@ false - - twitter4j-repo - Twitter4J Repository - http://twitter4j.org/maven2/ - - true - - - false - - diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 07572201de..5e4692162e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -138,8 +138,7 @@ object SparkBuild extends Build { resolvers ++= Seq( "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", "Spray Repository" at "http://repo.spray.cc/", - "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/", - "Twitter4J Repository" at "http://twitter4j.org/maven2/" + "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" ), libraryDependencies ++= Seq( -- cgit v1.2.3 From 03d0b858c807339b4221bedffa29ac76eef5352e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 30 Jun 2013 15:38:58 -0700 Subject: Made use of spark.executor.memory setting consistent and documented it Conflicts: core/src/main/scala/spark/SparkContext.scala --- core/src/main/scala/spark/SparkContext.scala | 24 +++++++++++------ .../spark/scheduler/cluster/SchedulerBackend.scala | 11 ++------ docs/configuration.md | 31 ++++++++++++++-------- docs/ec2-scripts.md | 5 ++-- docs/tuning.md | 6 ++--- 5 files changed, 43 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 70a9d7698c..366afb2a2a 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -115,13 +115,14 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner - for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value } } + // Since memory can be set with a system property too, use that + executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" if (environment != null) { executorEnvs ++= environment } @@ -156,14 +157,12 @@ class SparkContext( scheduler case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang. + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. val memoryPerSlaveInt = memoryPerSlave.toInt - val sparkMemEnv = System.getenv("SPARK_MEM") - val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512 - if (sparkMemEnvInt > memoryPerSlaveInt) { + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { throw new SparkException( - "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format( - memoryPerSlaveInt, sparkMemEnvInt)) + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) } val scheduler = new ClusterScheduler(this) @@ -881,6 +880,15 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) + + /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ + private[spark] val executorMemoryRequested = { + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala index 9ac875de3a..8844057a5c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala @@ -1,6 +1,6 @@ package spark.scheduler.cluster -import spark.Utils +import spark.{SparkContext, Utils} /** * A backend interface for cluster scheduling systems that allows plugging in different ones under @@ -14,14 +14,7 @@ private[spark] trait SchedulerBackend { def defaultParallelism(): Int // Memory used by each executor (in megabytes) - protected val executorMemory = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } - + protected val executorMemory: Int = SparkContext.executorMemoryRequested // TODO: Probably want to add a killTask too } diff --git a/docs/configuration.md b/docs/configuration.md index 2de512f896..ae61769e31 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -25,23 +25,25 @@ Inside `spark-env.sh`, you *must* set at least the following two variables: * `SCALA_HOME`, to point to your Scala installation. * `MESOS_NATIVE_LIBRARY`, if you are [running on a Mesos cluster](running-on-mesos.html). -In addition, there are four other variables that control execution. These can be set *either in `spark-env.sh` -or in each job's driver program*, because they will automatically be propagated to workers from the driver. -For a multi-user environment, we recommend setting the in the driver program instead of `spark-env.sh`, so -that different user jobs can use different amounts of memory, JVM options, etc. +In addition, there are four other variables that control execution. These should be set *in the environment that +launches the job's driver program* instead of `spark-env.sh`, because they will be automatically propagated to +workers. Setting these per-job instead of in `spark-env.sh` ensures that different jobs can have different settings +for these variables. -* `SPARK_MEM`, to set the amount of memory used per node (this should be in the same format as the - JVM's -Xmx option, e.g. `300m` or `1g`) * `SPARK_JAVA_OPTS`, to add JVM options. This includes any system properties that you'd like to pass with `-D`. * `SPARK_CLASSPATH`, to add elements to Spark's classpath. * `SPARK_LIBRARY_PATH`, to add search directories for native libraries. +* `SPARK_MEM`, to set the amount of memory used per node. This should be in the same format as the + JVM's -Xmx option, e.g. `300m` or `1g`. Note that this option will soon be deprecated in favor of + the `spark.executor.memory` system property, so we recommend using that in new code. -Note that if you do set these in `spark-env.sh`, they will override the values set by user programs, which -is undesirable; you can choose to have `spark-env.sh` set them only if the user program hasn't, as follows: +Beware that if you do set these variables in `spark-env.sh`, they will override the values set by user programs, +which is undesirable; if you prefer, you can choose to have `spark-env.sh` set them only if the user program +hasn't, as follows: {% highlight bash %} -if [ -z "$SPARK_MEM" ] ; then - SPARK_MEM="1g" +if [ -z "$SPARK_JAVA_OPTS" ] ; then + SPARK_JAVA_OPTS="-verbose:gc" fi {% endhighlight %} @@ -55,10 +57,17 @@ val sc = new SparkContext(...) {% endhighlight %} Most of the configurable system properties control internal settings that have reasonable default values. However, -there are at least four properties that you will commonly want to control: +there are at least five properties that you will commonly want to control: + + + + + diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index dc57035eba..eab8a0ff20 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -106,9 +106,8 @@ permissions on your private key file, you can run `launch` with the # Configuration You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such -as JVM options and, most crucially, the amount of memory to use per machine (`SPARK_MEM`). -This file needs to be copied to **every machine** to reflect the change. The easiest way to do this -is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master, +as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to +do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master, then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers. The [configuration guide](configuration.html) describes the available configuration options. diff --git a/docs/tuning.md b/docs/tuning.md index 32c7ab86e9..5ffca54481 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -157,9 +157,9 @@ their work directories), *not* on your driver program. **Cache Size Tuning** -One important configuration parameter for GC is the amount of memory that should be used for -caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that - 33% of memory is available for any objects created during task execution. +One important configuration parameter for GC is the amount of memory that should be used for caching RDDs. +By default, Spark uses 66% of the configured executor memory (`spark.executor.memory` or `SPARK_MEM`) to +cache RDDs. This means that 33% of memory is available for any objects created during task execution. In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call -- cgit v1.2.3 From 5bbd0eec84867937713ceb8438f25a943765a084 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 30 Jun 2013 17:00:26 -0700 Subject: Update docs on SCALA_LIBRARY_PATH --- conf/spark-env.sh.template | 18 ++++++------------ docs/configuration.md | 4 +++- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 37565ca827..b8936314ec 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,8 +3,10 @@ # This file contains environment variables required to run Spark. Copy it as # spark-env.sh and edit that to configure Spark for your site. At a minimum, # the following two variables should be set: -# - MESOS_NATIVE_LIBRARY, to point to your Mesos native library (libmesos.so) -# - SCALA_HOME, to point to your Scala installation +# - SCALA_HOME, to point to your Scala installation, or SCALA_LIBRARY_PATH to +# point to the directory for Scala library JARs (if you install Scala as a +# Debian or RPM package, these are in a separate path, often /usr/share/java) +# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos # # If using the standalone deploy mode, you can also set variables for it: # - SPARK_MASTER_IP, to bind the master to a different IP address @@ -12,14 +14,6 @@ # - SPARK_WORKER_CORES, to set the number of cores to use on this machine # - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g) # - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT -# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes to be spawned on every slave machine -# -# Finally, Spark also relies on the following variables, but these can be set -# on just the *master* (i.e. in your driver program), and will automatically -# be propagated to workers: -# - SPARK_MEM, to change the amount of memory used per node (this should -# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g) -# - SPARK_CLASSPATH, to add elements to Spark's classpath -# - SPARK_JAVA_OPTS, to add JVM options -# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries. +# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes +# to be spawned on every slave machine diff --git a/docs/configuration.md b/docs/configuration.md index ae61769e31..3266db7af1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -22,7 +22,9 @@ the copy executable. Inside `spark-env.sh`, you *must* set at least the following two variables: -* `SCALA_HOME`, to point to your Scala installation. +* `SCALA_HOME`, to point to your Scala installation, or `SCALA_LIBRARY_PATH` to point to the directory for Scala + library JARs (if you install Scala as a Debian or RPM package, there is no `SCALA_HOME`, but these libraries + are in a separate path, typically /usr/share/java; look for `scala-library.jar`). * `MESOS_NATIVE_LIBRARY`, if you are [running on a Mesos cluster](running-on-mesos.html). In addition, there are four other variables that control execution. These should be set *in the environment that -- cgit v1.2.3 From 39ae073b5cd0dcfe4a00d9f205c88bad9df37870 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 30 Jun 2013 17:11:14 -0700 Subject: Increase SLF4j version in Maven too --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 7a31be98b2..48e623fa1c 100644 --- a/pom.xml +++ b/pom.xml @@ -56,7 +56,7 @@ 2.0.3 1.0-M2.1 1.1.1 - 1.6.1 + 1.7.2 4.1.2 1.2.17 -- cgit v1.2.3 From 3296d132b6ce042843de6e7384800e089b49e5fa Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Jul 2013 02:45:00 +0000 Subject: Fix performance bug with new Python code not using buffered streams --- core/src/main/scala/spark/SparkEnv.scala | 3 +- .../main/scala/spark/api/python/PythonRDD.scala | 33 +++++++++++----------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 7ccde2e818..ec59b4f48f 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -59,7 +59,8 @@ class SparkEnv ( def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { synchronized { - pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create() + val key = (pythonExec, envVars) + pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() } } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 63140cf37f..3f283afa62 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -45,37 +45,38 @@ private[spark] class PythonRDD[T: ClassManifest]( new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val out = new PrintWriter(worker.getOutputStream) - val dOut = new DataOutputStream(worker.getOutputStream) + val stream = new BufferedOutputStream(worker.getOutputStream) + val dataOut = new DataOutputStream(stream) + val printOut = new PrintWriter(stream) // Partition index - dOut.writeInt(split.index) + dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) // Broadcast variables - dOut.writeInt(broadcastVars.length) + dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { - dOut.writeLong(broadcast.id) - dOut.writeInt(broadcast.value.length) - dOut.write(broadcast.value) - dOut.flush() + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) } + dataOut.flush() // Serialized user code for (elem <- command) { - out.println(elem) + printOut.println(elem) } - out.flush() + printOut.flush() // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dOut) + PythonRDD.writeAsPickle(elem, dataOut) } - dOut.flush() - out.flush() + dataOut.flush() + printOut.flush() worker.shutdownOutput() } }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(worker.getInputStream) + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream)) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj @@ -288,7 +289,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) // This happens on the master, where we pass the updates to Python through a socket val socket = new Socket(serverHost, serverPort) val in = socket.getInputStream - val out = new DataOutputStream(socket.getOutputStream) + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream)) out.writeInt(val2.size) for (array <- val2) { out.writeInt(array.length) -- cgit v1.2.3 From ec31e68d5df259e6df001529235d8c906ff02a6f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Jul 2013 06:20:14 +0000 Subject: Fixed PySpark perf regression by not using socket.makefile(), and improved debuggability by letting "print" statements show up in the executor's stderr Conflicts: core/src/main/scala/spark/api/python/PythonRDD.scala --- .../main/scala/spark/api/python/PythonRDD.scala | 10 ++++-- .../spark/api/python/PythonWorkerFactory.scala | 20 ++++++++++- python/pyspark/daemon.py | 42 ++++++++++++---------- 3 files changed, 50 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 3f283afa62..31d8ea89d4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest]( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String, envVars: JMap[String, String], @@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest]( new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val stream = new BufferedOutputStream(worker.getOutputStream) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) val printOut = new PrintWriter(stream) // Partition index @@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest]( }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream)) + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj @@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") + + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList @@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) // This happens on the master, where we pass the updates to Python through a socket val socket = new Socket(serverHost, serverPort) val in = socket.getInputStream - val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream)) + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) for (array <- val2) { out.writeInt(array.length) diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala index 8844411d73..85d1dfeac8 100644 --- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars) daemon = pb.start() - daemonPort = new DataInputStream(daemon.getInputStream).readInt() // Redirect the stderr to ours new Thread("stderr reader for " + pythonExec) { @@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } }.start() + + val in = new DataInputStream(daemon.getInputStream) + daemonPort = in.readInt() + + // Redirect further stdout output to our stderr + new Thread("stdout reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() } catch { case e => { stopDaemon() diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 78a2da1e18..78c9457b84 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -1,10 +1,13 @@ import os +import signal +import socket import sys +import traceback import multiprocessing from ctypes import c_bool from errno import EINTR, ECHILD -from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN -from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN +from socket import AF_INET, SOCK_STREAM, SOMAXCONN +from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main from pyspark.serializers import write_int @@ -33,11 +36,12 @@ def compute_real_exit_code(exit_code): def worker(listen_sock): # Redirect stdout to stderr os.dup2(2, 1) + sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 # Manager sends SIGHUP to request termination of workers in the pool def handle_sighup(*args): assert should_exit() - signal(SIGHUP, handle_sighup) + signal.signal(SIGHUP, handle_sighup) # Cleanup zombie children def handle_sigchld(*args): @@ -51,7 +55,7 @@ def worker(listen_sock): handle_sigchld() elif err.errno != ECHILD: raise - signal(SIGCHLD, handle_sigchld) + signal.signal(SIGCHLD, handle_sigchld) # Handle clients while not should_exit(): @@ -70,19 +74,22 @@ def worker(listen_sock): # never receives SIGCHLD unless a worker crashes. if os.fork() == 0: # Leave the worker pool - signal(SIGHUP, SIG_DFL) + signal.signal(SIGHUP, SIG_DFL) listen_sock.close() - # Handle the client then exit - sockfile = sock.makefile() + # Read the socket using fdopen instead of socket.makefile() because the latter + # seems to be very slow; note that we need to dup() the file descriptor because + # otherwise writes also cause a seek that makes us miss data on the read side. + infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - worker_main(sockfile, sockfile) + worker_main(infile, outfile) except SystemExit as exc: - exit_code = exc.code + exit_code = exc.code finally: - sockfile.close() - sock.close() - os._exit(compute_real_exit_code(exit_code)) + outfile.flush() + sock.close() + os._exit(compute_real_exit_code(exit_code)) else: sock.close() @@ -92,7 +99,6 @@ def launch_worker(listen_sock): try: worker(listen_sock) except Exception as err: - import traceback traceback.print_exc() os._exit(1) else: @@ -105,7 +111,7 @@ def manager(): os.setpgid(0, 0) # Create a listening socket on the AF_INET loopback interface - listen_sock = socket(AF_INET, SOCK_STREAM) + listen_sock = socket.socket(AF_INET, SOCK_STREAM) listen_sock.bind(('127.0.0.1', 0)) listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() @@ -121,8 +127,8 @@ def manager(): exit_flag.value = True # Gracefully exit on SIGTERM, don't die on SIGHUP - signal(SIGTERM, lambda signum, frame: shutdown()) - signal(SIGHUP, SIG_IGN) + signal.signal(SIGTERM, lambda signum, frame: shutdown()) + signal.signal(SIGHUP, SIG_IGN) # Cleanup zombie children def handle_sigchld(*args): @@ -133,7 +139,7 @@ def manager(): except EnvironmentError as err: if err.errno not in (ECHILD, EINTR): raise - signal(SIGCHLD, handle_sigchld) + signal.signal(SIGCHLD, handle_sigchld) # Initialization complete sys.stdout.close() @@ -148,7 +154,7 @@ def manager(): shutdown() raise finally: - signal(SIGTERM, SIG_DFL) + signal.signal(SIGTERM, SIG_DFL) exit_flag.value = True # Send SIGHUP to notify workers of shutdown os.kill(0, SIGHUP) -- cgit v1.2.3 From 7cd490ef5ba28df31f5e061eff83c855731dfca4 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 1 Jul 2013 06:25:17 +0000 Subject: Clarify that PySpark is not supported on Windows --- docs/python-programming-guide.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 7f1e7cf93d..e8aaac74d0 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -17,10 +17,9 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. * PySpark does not currently support the following Spark features: - Special functions on RDDs of doubles, such as `mean` and `stdev` - - `lookup` + - `lookup`, `sample` and `sort` - `persist` at storage levels other than `MEMORY_ONLY` - - `sample` - - `sort` + - Execution on Windows -- this is slated for a future release In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types. Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax: -- cgit v1.2.3 From 6fdbc68f2c5b220d1618d5a78d46aa0f844cae45 Mon Sep 17 00:00:00 2001 From: Konstantin Boudnik Date: Mon, 1 Jul 2013 16:05:55 -0700 Subject: Fixing missed hbase dependency in examples hadoop2-yarn profile --- examples/pom.xml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/pom.xml b/examples/pom.xml index 3e5271ec2f..78ec58729b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -193,6 +193,11 @@ hadoop-yarn-common provided + + org.apache.hbase + hbase + 0.94.6 + -- cgit v1.2.3
    Property NameDefaultMeaning
    spark.executor.memory512m + Amount of memory to use per executor process, in the same format as JVM memory strings (e.g. `512m`, `2g`). +
    spark.serializer spark.JavaSerializer