diff options
author | Kay Ousterhout <kayousterhout@gmail.com> | 2013-12-19 21:37:15 -0800 |
---|---|---|
committer | Kay Ousterhout <kayousterhout@gmail.com> | 2013-12-19 21:37:15 -0800 |
commit | 9228ec847e841a17c7dff7e75bc2e06bea799ea4 (patch) | |
tree | 36da47eed091273fe5b59012c2a87f1c3d5f54e1 /core/src | |
parent | 2b0a6e7d9210ed828395243027c7001f7dae77a4 (diff) | |
parent | 40f63eb034ee5669dba87deb5f8f37c10bf5df0c (diff) | |
download | spark-9228ec847e841a17c7dff7e75bc2e06bea799ea4.tar.gz spark-9228ec847e841a17c7dff7e75bc2e06bea799ea4.tar.bz2 spark-9228ec847e841a17c7dff7e75bc2e06bea799ea4.zip |
Merge pull request #1 from aarondav/127
Merge master into 127
Diffstat (limited to 'core/src')
60 files changed, 2024 insertions, 657 deletions
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1ad9240cfa..c6b4ac5192 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { if (!atMost.isFinite()) { awaitResult() - } else { + } else jobWaiter.synchronized { val finishTime = System.currentTimeMillis() + atMost.toMillis while (!isCompleted) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5e465fa22c..b4d0b7017c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -244,12 +244,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { case Some(bytes) => return bytes case None => - statuses = mapStatuses(shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch } } // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that + // out a snapshot of the locations as "statuses"; let's serialize and return that val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working @@ -274,6 +274,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { override def updateEpoch(newEpoch: Long) { // This might be called on the MapOutputTrackerMaster if we're running in local mode. } + + def has(shuffleId: Int): Boolean = { + cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) + } } private[spark] object MapOutputTracker { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d5616c274d..1bc3b1972f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.generic.Growable -import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -81,7 +80,7 @@ class SparkContext( val sparkHome: String = null, val jars: Seq[String] = Nil, val environment: Map[String, String] = Map(), - // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc) // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set // of data-local splits on host val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = @@ -153,110 +152,11 @@ class SparkContext( executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler - private[spark] var taskScheduler: TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster - val MESOS_REGEX = """mesos://(.*)""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r - - // When running locally, don't try to re-execute tasks on failure. - val MAX_LOCAL_TASK_FAILURES = 0 - - master match { - case "local" => - val scheduler = new ClusterScheduler(this, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) - scheduler.initialize(backend) - scheduler - - case LOCAL_N_REGEX(threads) => - val scheduler = new ClusterScheduler(this, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) - scheduler.initialize(backend) - scheduler - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - val scheduler = new ClusterScheduler(this, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) - scheduler.initialize(backend) - scheduler - - case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) - val masterUrls = sparkUrl.split(",").map("spark://" + _) - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - scheduler - - case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(this) - val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) - scheduler.initialize(backend) - scheduler - - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. - val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { - throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) - } - - val scheduler = new ClusterScheduler(this) - val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val masterUrls = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { - localCluster.stop() - } - scheduler - - case "yarn-standalone" => - val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - } catch { - // TODO: Enumerate the exact reasons why it can fail - // But irrespective of it, it means we cannot proceed ! - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem) - scheduler.initialize(backend) - scheduler - - case MESOS_REGEX(mesosUrl) => - MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) - } else { - new MesosSchedulerBackend(scheduler, this, mesosUrl, appName) - } - scheduler.initialize(backend) - scheduler - - case _ => - throw new SparkException("Could not parse Master URL: '" + master + "'") - } - } + private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName) taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() ui.start() @@ -1121,6 +1021,136 @@ object SparkContext { .map(Utils.memoryStringToMb) .getOrElse(512) } + + // Creates a task scheduler based on a given master URL. Extracted for testing. + private + def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = { + // Regular expression used for local[N] master format + val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r + + // When running locally, don't try to re-execute tasks on failure. + val MAX_LOCAL_TASK_FAILURES = 0 + + master match { + case "local" => + val scheduler = new ClusterScheduler(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val backend = new LocalBackend(scheduler, 1) + scheduler.initialize(backend) + scheduler + + case LOCAL_N_REGEX(threads) => + val scheduler = new ClusterScheduler(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler + + case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + val scheduler = new ClusterScheduler(sc, maxFailures.toInt, isLocal = true) + val backend = new LocalBackend(scheduler, threads.toInt) + scheduler.initialize(backend) + scheduler + + case SPARK_REGEX(sparkUrl) => + val scheduler = new ClusterScheduler(sc) + val masterUrls = sparkUrl.split(",").map("spark://" + _) + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + scheduler + + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. + val memoryPerSlaveInt = memoryPerSlave.toInt + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + throw new SparkException( + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + } + + val scheduler = new ClusterScheduler(sc) + val localCluster = new LocalSparkCluster( + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + val masterUrls = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + localCluster.stop() + } + scheduler + + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + scheduler.initialize(backend) + scheduler + + case "yarn-client" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + val backend = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + scheduler.initialize(backend) + scheduler + + case mesosUrl @ MESOS_REGEX(_) => + MesosNativeLibrary.load() + val scheduler = new ClusterScheduler(sc) + val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs + val backend = if (coarseGrained) { + new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) + } else { + new MesosSchedulerBackend(scheduler, sc, url, appName) + } + scheduler.initialize(backend) + scheduler + + case SIMR_REGEX(simrUrl) => + val scheduler = new ClusterScheduler(sc) + val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) + scheduler.initialize(backend) + scheduler + + case _ => + throw new SparkException("Could not parse Master URL: '" + master + "'") + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 043cb183ba..9f02a9b7d3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -26,6 +26,8 @@ import org.apache.spark.storage.StorageLevel import java.lang.Double import org.apache.spark.Partitioner +import scala.collection.JavaConverters._ + class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] @@ -182,6 +184,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. + */ + def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { + val result = srdd.histogram(bucketCount) + (result._1, result._2) + } + + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. + */ + def histogram(buckets: Array[scala.Double]): Array[Long] = { + srdd.histogram(buckets, false) + } + + def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble), evenBuckets) + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b4d94a56..132e4fb0d2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + dataOut.writeUTF(SparkFiles.getRootDirectory) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } + pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) + PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => @@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj - case -3 => + case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() @@ -143,24 +125,24 @@ private[spark] class PythonRDD[T: ClassManifest]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) read - case -2 => + case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - case -1 => + case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt() + } - new Array[Byte](0) + Array.empty[Byte] } } catch { case eof: EOFException => { @@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private[spark] object PythonRDD { - - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } +private object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 +} - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } +private[spark] object PythonRDD { - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -270,15 +205,32 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + def writeToStream(elem: Any, dataOut: DataOutputStream) { + elem match { + case bytes: Array[Byte] => + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + case pair: (Array[Byte], Array[Byte]) => + dataOut.writeInt(pair._1.length) + dataOut.write(pair._1) + dataOut.writeInt(pair._2.length) + dataOut.write(pair._2) + case str: String => + dataOut.writeUTF(str) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + } + + def writeToFile[T](items: java.util.Iterator[T], filename: String) { import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) + writeToFile(items.asScala, filename) } - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + def writeToFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { - writeAsPickle(item, file) + writeToStream(item, file) } file.close() } @@ -289,17 +241,6 @@ private[spark] object PythonRDD { } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' -} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 609464e38d..47db720416 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.URL +import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -83,6 +84,8 @@ private object HttpBroadcast extends Logging { private val files = new TimeStampedHashSet[String] private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) + private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt + private lazy val compressionCodec = CompressionCodec.createCodec() def initialize(isDriver: Boolean) { @@ -138,10 +141,13 @@ private object HttpBroadcast extends Logging { def read[T](id: Long): T = { val url = serverUri + "/" + BroadcastBlockId(id).name val in = { + val httpConnection = new URL(url).openConnection() + httpConnection.setReadTimeout(httpReadTimeout) + val inputStream = httpConnection.getInputStream() if (compress) { - compressionCodec.compressedInputStream(new URL(url).openStream()) + compressionCodec.compressedInputStream(inputStream) } else { - new FastBufferedInputStream(new URL(url).openStream(), bufferSize) + new FastBufferedInputStream(inputStream, bufferSize) } } val ser = SparkEnv.get.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 668032a3a2..0aa8852649 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -1,19 +1,19 @@ /* * - * * Licensed to the Apache Software Foundation (ASF) under one or more - * * contributor license agreements. See the NOTICE file distributed with - * * this work for additional information regarding copyright ownership. - * * The ASF licenses this file to You under the Apache License, Version 2.0 - * * (the "License"); you may not use this file except in compliance with - * * the License. You may obtain a copy of the License at - * * - * * http://www.apache.org/licenses/LICENSE-2.0 - * * - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, - * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * * See the License for the specific language governing permissions and - * * limitations under the License. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 308a2bfa22..a724900943 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -17,12 +17,12 @@ package org.apache.spark.deploy -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} +import akka.actor.ActorSystem import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.{Logging} +import org.apache.spark.util.Utils +import org.apache.spark.Logging import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index caee6b01ab..8332631838 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor.{ActorRef, Actor, Props, Terminated} import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.Logging import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 34ed9c8f73..97176e4f5b 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -20,8 +20,6 @@ package org.apache.spark.executor import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.hdfs.DistributedFileSystem -import org.apache.hadoop.fs.LocalFileSystem import scala.collection.JavaConversions._ diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0b4892f98f..c0ce46e379 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -61,50 +61,53 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** - * Time when shuffle finishs + * Absolute time when this task finished reading shuffle data */ var shuffleFinishTime: Long = _ /** - * Total number of blocks fetched in a shuffle (remote or local) + * Number of blocks fetched in this shuffle by this task (remote or local) */ var totalBlocksFetched: Int = _ /** - * Number of remote blocks fetched in a shuffle + * Number of remote blocks fetched in this shuffle by this task */ var remoteBlocksFetched: Int = _ /** - * Local blocks fetched in a shuffle + * Number of local blocks fetched in this shuffle by this task */ var localBlocksFetched: Int = _ /** - * Total time that is spent blocked waiting for shuffle to fetch data + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. */ var fetchWaitTime: Long = _ /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time + * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all + * input blocks. Since block fetches are both pipelined and parallelized, this can + * exceed fetchWaitTime and executorRunTime. */ var remoteFetchTime: Long = _ /** - * Total number of remote bytes read from a shuffle + * Total number of remote bytes read from the shuffle by this task */ var remoteBytesRead: Long = _ } class ShuffleWriteMetrics extends Serializable { /** - * Number of bytes written for a shuffle + * Number of bytes written for the shuffle by this task */ var shuffleBytesWritten: Long = _ /** - * Time spent blocking on writes to disk or buffer cache, in nanoseconds. + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala new file mode 100644 index 0000000000..cdcfec8ca7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit +import java.net.InetSocketAddress + +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} + +import org.apache.spark.metrics.MetricsSystem + +class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val GRAPHITE_DEFAULT_PERIOD = 10 + val GRAPHITE_DEFAULT_UNIT = "SECONDS" + val GRAPHITE_DEFAULT_PREFIX = "" + + val GRAPHITE_KEY_HOST = "host" + val GRAPHITE_KEY_PORT = "port" + val GRAPHITE_KEY_PERIOD = "period" + val GRAPHITE_KEY_UNIT = "unit" + val GRAPHITE_KEY_PREFIX = "prefix" + + def propertyToOption(prop: String) = Option(property.getProperty(prop)) + + if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) { + throw new Exception("Graphite sink requires 'host' property.") + } + + if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) { + throw new Exception("Graphite sink requires 'port' property.") + } + + val host = propertyToOption(GRAPHITE_KEY_HOST).get + val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt + + val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match { + case Some(s) => s.toInt + case None => GRAPHITE_DEFAULT_PERIOD + } + + val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) + } + + val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX) + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + + val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .prefixedWith(prefix) + .build(graphite) + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 481ff8c3e0..b1e1576dad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging { extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9b0c882481..0de22f0e06 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def compute(split: Partition, context: TaskContext) = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); - y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } override def getDependencies: Seq[Dependency[_]] = List( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a4bec41752..02d75eccc5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator import org.apache.spark.util.StatCounter import org.apache.spark.{TaskContext, Logging} +import scala.collection.immutable.NumericRange + /** * Extra functions available on RDDs of Doubles through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. @@ -76,4 +78,128 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val evaluator = new SumEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + // Compute the minimum and the maxium + val (max: Double, min: Double) = self.mapPartitions { items => + Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) => + (x._1.max(e), x._2.min(e)))) + }.reduce { (maxmin1, maxmin2) => + (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) + } + if (max.isNaN() || max.isInfinity || min.isInfinity ) { + throw new UnsupportedOperationException( + "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") + } + val increment = (max-min)/bucketCount.toDouble + val range = if (increment != 0) { + Range.Double.inclusive(min, max, increment) + } else { + List(min, min) + } + val buckets = range.toArray + (buckets, histogram(buckets, true)) + } + + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. + */ + def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { + if (buckets.length < 2) { + throw new IllegalArgumentException("buckets array must have at least two elements") + } + // The histogramPartition function computes the partail histogram for a given + // partition. The provided bucketFunction determines which bucket in the array + // to increment or returns None if there is no bucket. This is done so we can + // specialize for uniformly distributed buckets and save the O(log n) binary + // search cost. + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): + Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length - 1) + while (iter.hasNext) { + bucketFunction(iter.next()) match { + case Some(x: Int) => {counters(x) += 1} + case _ => {} + } + } + Iterator(counters) + } + // Merge the counters. + def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = { + a1.indices.foreach(i => a1(i) += a2(i)) + a1 + } + // Basic bucket function. This works using Java's built in Array + // binary search. Takes log(size(buckets)) + def basicBucketFunction(e: Double): Option[Int] = { + val location = java.util.Arrays.binarySearch(buckets, e) + if (location < 0) { + // If the location is less than 0 then the insertion point in the array + // to keep it sorted is -location-1 + val insertionPoint = -location-1 + // If we have to insert before the first element or after the last one + // its out of bounds. + // We do this rather than buckets.lengthCompare(insertionPoint) + // because Array[Double] fails to override it (for now). + if (insertionPoint > 0 && insertionPoint < buckets.length) { + Some(insertionPoint-1) + } else { + None + } + } else if (location < buckets.length - 1) { + // Exact match, just insert here + Some(location) + } else { + // Exact match to the last element + Some(location - 1) + } + } + // Determine the bucket function in constant time. Requires that buckets are evenly spaced + def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + // If our input is not a number unless the increment is also NaN then we fail fast + if (e.isNaN()) { + return None + } + val bucketNumber = (e - min)/(increment) + // We do this rather than buckets.lengthCompare(bucketNumber) + // because Array[Double] fails to override it (for now). + if (bucketNumber > count || bucketNumber < 0) { + None + } else { + Some(bucketNumber.toInt.min(count - 1)) + } + } + // Decide which bucket function to pass to histogramPartition. We decide here + // rather than having a general function so that the decission need only be made + // once rather than once per shard + val bucketFunction = if (evenBuckets) { + fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + } else { + basicBucketFunction _ + } + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 203179c4ea..ae70d55951 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -20,18 +20,16 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( +private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: Iterator[T] => Iterator[U], + f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None override def getPartitions: Array[Partition] = firstParent[T].partitions override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) + f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala deleted file mode 100644 index aea08ff81b..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.{Partition, TaskContext} - - -/** - * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the - * TaskContext, the closure can either get access to the interruptible flag or get the index - * of the partition in the RDD. - */ -private[spark] -class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean - ) extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = if (preservesPartitioning) prev.partitioner else None - - override def compute(split: Partition, context: TaskContext) = - f(context, firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 165cd412fc..574dd4233f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -33,11 +33,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) + val partitions: Array[Partition] = rdd.partitions + .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = List(partitions(partitionId).index) + override def getParents(partitionId: Int) = { + List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6e88be6f6a..893708f8f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -101,7 +101,7 @@ abstract class RDD[T: ClassManifest]( protected def getPreferredLocations(split: Partition): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + @transient val partitioner: Option[Partitioner] = None // ======================================================================= // Methods and fields available on all RDDs @@ -114,7 +114,7 @@ abstract class RDD[T: ClassManifest]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - var name: String = null + @transient var name: String = null /** Assign a name to this RDD */ def setName(_name: String) = { @@ -123,7 +123,7 @@ abstract class RDD[T: ClassManifest]( } /** User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo.firstUserClass + @transient var generator = Utils.getCallSiteInfo.firstUserClass /** Reset generator*/ def setGenerator(_generator: String) = { @@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** * Return an RDD created by piping elements to a forked external process. * The print behavior can be customized by providing two functions. @@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest]( */ def mapPartitions[U: ClassManifest]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest]( */ def mapPartitionsWithIndex[U: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) - new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest]( def mapPartitionsWithContext[U: ClassManifest]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest]( def mapWith[A: ClassManifest, U: ClassManifest] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => U): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.map(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest]( def flatMapWith[A: ClassManifest, U: ClassManifest] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => Seq[U]): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest]( * partition with the index of that partition. */ def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex { (index, iter) => + val a = constructA(index) iter.map(t => {f(t, a); t}) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) + }.foreach(_ => {}) } /** @@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest]( * partition with the index of that partition. */ def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.filter(t => p(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) + }, preservesPartitioning = true) } /** @@ -546,19 +543,34 @@ abstract class RDD[T: ClassManifest]( * of elements in each partition. */ def zipPartitions[B: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning) + + def zipPartitions[B: ClassManifest, V: ClassManifest] (rdd2: RDD[B]) (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false) + + def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning) def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] (rdd2: RDD[B], rdd3: RDD[C]) (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false) + + def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning) def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false) // Actions (launch a job to return a value to the user program) @@ -928,7 +940,7 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite + @transient private[spark] val origin = Utils.formatSparkCallSite private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] @@ -943,7 +955,7 @@ abstract class RDD[T: ClassManifest]( def context = sc // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false + @transient private var doCheckpointCalled = false /** * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 31e6fd519d..a97d2a01c8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -22,7 +22,8 @@ import java.io.{ObjectOutputStream, IOException} private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]]) + @transient rdds: Seq[RDD[_]], + @transient val preferredLocations: Seq[String]) extends Partition { override val index: Int = idx @@ -39,31 +40,29 @@ private[spark] class ZippedPartitionsPartition( abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( sc: SparkContext, - var rdds: Seq[RDD[_]]) + var rdds: Seq[RDD[_]], + preservesPartitioning: Boolean = false) extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + override val partitioner = + if (preservesPartitioning) firstParent[Any].partitioner else None + override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { + val numParts = rdds.head.partitions.size + if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitionsPartition(i, rdds) + Array.tabulate[Partition](numParts) { i => + val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i))) + // Check whether there are any hosts that match all RDDs; otherwise return the union + val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) + val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct + new ZippedPartitionsPartition(i, rdds, locs) } - array } override def getPreferredLocations(s: Partition): Seq[String] = { - val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions - val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } - // Check whether there are any hosts that match all RDDs; otherwise return the union - val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - prefs.flatten.distinct - } + s.asInstanceOf[ZippedPartitionsPartition].preferredLocations } override def clearDependencies() { @@ -76,8 +75,9 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest] sc: SparkContext, f: (Iterator[A], Iterator[B]) => Iterator[V], var rdd1: RDD[A], - var rdd2: RDD[B]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + var rdd2: RDD[B], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -97,8 +97,9 @@ class ZippedPartitionsRDD3 f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], - var rdd3: RDD[C]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + var rdd3: RDD[C], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -122,8 +123,9 @@ class ZippedPartitionsRDD4 var rdd1: RDD[A], var rdd2: RDD[B], var rdd3: RDD[C], - var rdd4: RDD[D]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + var rdd4: RDD[D], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala index 2e4ba53d9b..a4c2e31012 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala @@ -101,8 +101,8 @@ private[spark] class ClusterScheduler( this.dagScheduler = dagScheduler } - def initialize(context: SchedulerBackend) { - backend = context + def initialize(backend: SchedulerBackend) { + this.backend = backend // temporarily set rootPool name to empty rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { @@ -174,7 +174,9 @@ private[spark] class ClusterScheduler( backend.killTask(tid, execId) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + tsm.removeAllRunningTasks() + taskSetFinished(tsm) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 42bb3884c8..f9cd021dd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -110,30 +110,10 @@ class DAGScheduler( // resubmit failed stages val POLL_TIMEOUT = 10L - private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { - override def preStart() { - context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { - if (failed.size > 0) { - resubmitFailedStages() - } - } - } - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - def receive = { - case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) + // Warns the user if a stage contains a task with size greater than this value (in KB) + val TASK_SIZE_TO_WARN = 100 - if (!processEvent(event)) - submitWaitingStages() - else - context.stop(self) - } - })) + private var eventProcessActor: ActorRef = _ private[scheduler] val nextJobId = new AtomicInteger(0) @@ -141,9 +121,13 @@ class DAGScheduler( private val nextStageId = new AtomicInteger(0) - private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] - private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] + + private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -174,6 +158,56 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + /** + * Starts the event processing actor. The actor has two responsibilities: + * + * 1. Waits for events like job submission, task finished, task failure etc., and calls + * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them. + * 2. Schedules a periodical task to resubmit failed stages. + * + * NOTE: the actor cannot be started in the constructor, because the periodical task references + * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus + * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed. + */ + def start() { + eventProcessActor = env.actorSystem.actorOf(Props(new Actor { + /** + * A handle to the periodical task, used to cancel the task when the actor is stopped. + */ + var resubmissionTask: Cancellable = _ + + override def preStart() { + /** + * A message is sent to the actor itself periodically to remind the actor to resubmit failed + * stages. In this way, stage resubmission can be done within the same thread context of + * other event processing logic to avoid unnecessary synchronization overhead. + */ + resubmissionTask = context.system.scheduler.schedule( + RESUBMIT_TIMEOUT.millis, RESUBMIT_TIMEOUT.millis, self, ResubmitFailedStages) + } + + /** + * The main event loop of the DAG scheduler. + */ + def receive = { + case event: DAGSchedulerEvent => + logDebug("Got event of type " + event.getClass.getName) + + /** + * All events are forwarded to `processEvent()`, so that the event processing logic can + * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite` + * for details. + */ + if (!processEvent(event)) { + submitWaitingStages() + } else { + resubmissionTask.cancel() + context.stop(self) + } + } + })) + } + def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } @@ -202,16 +236,16 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId) + val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } } /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId. + * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation + * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. */ private def newStage( rdd: RDD[_], @@ -221,21 +255,45 @@ class DAGScheduler( callSite: Option[String] = None) : Stage = { - if (shuffleDep != None) { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) - } val id = nextStageId.getAndIncrement() val stage = new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) stageToInfos(stage) = new StageInfo(stage) stage } /** + * Create a shuffle map Stage for the given RDD. The stage will also be associated with the + * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * present in the MapOutputTracker, then the number and location of available outputs are + * recovered from the MapOutputTracker + */ + private def newOrUsedStage( + rdd: RDD[_], + numTasks: Int, + shuffleDep: ShuffleDependency[_,_], + jobId: Int, + callSite: Option[String] = None) + : Stage = + { + val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + if (mapOutputTracker.has(shuffleDep.shuffleId)) { + val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) + val locs = MapOutputTracker.deserializeMapStatuses(serLocs) + for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i)) + stage.numAvailableOutputs = locs.size + } else { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of partitions is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + } + stage + } + + /** * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided jobId if they haven't already been created with a lower jobId. */ @@ -287,6 +345,89 @@ class DAGScheduler( } /** + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + def updateJobIdStageIdMapsList(stages: List[Stage]) { + if (!stages.isEmpty) { + val s = stages.head + stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId + jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id + val parents = getParentStages(s.rdd, jobId) + val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) + } + } + updateJobIdStageIdMapsList(List(stage)) + } + + /** + * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that + * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + */ + private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { + val registeredStages = jobIdToStageIds(jobId) + val independentStages = new HashSet[Int]() + if (registeredStages.isEmpty) { + logError("No stages registered for job " + jobId) + } else { + stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + case (stageId, jobSet) => + if (!jobSet.contains(jobId)) { + logError("Job %d not registered for stage %d even though that stage was registered for the job" + .format(jobId, stageId)) + } else { + def removeStage(stageId: Int) { + // data structures based on Stage + stageIdToStage.get(stageId).foreach { s => + if (running.contains(s)) { + logDebug("Removing running stage %d".format(stageId)) + running -= s + } + stageToInfos -= s + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { + logDebug("Removing pending status for stage %d".format(stageId)) + } + pendingTasks -= s + if (waiting.contains(s)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waiting -= s + } + if (failed.contains(s)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failed -= s + } + } + // data structures based on StageId + stageIdToStage -= stageId + stageIdToJobIds -= stageId + + logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + } + + jobSet -= jobId + if (jobSet.isEmpty) { // no other job needs this stage + independentStages += stageId + removeStage(stageId) + } + } + } + } + independentStages.toSet + } + + private def jobIdToStageIdsRemove(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to remove unregistered job " + jobId) + } else { + removeJobAndIndependentStages(jobId) + jobIdToStageIds -= jobId + } + } + + /** * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. */ @@ -375,13 +516,25 @@ class DAGScheduler( } /** - * Process one event retrieved from the event queue. - * Returns true if we should stop the event loop. + * Process one event retrieved from the event processing actor. + * + * @param event The event to be processed. + * @return `true` if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + var finalStage: Stage = null + try { + // New stage creation at times and if its not protected, the scheduler thread is killed. + // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted + finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return false + } val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -391,37 +544,31 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { // Compute very short actions like first() or take() with no parent stages locally. + listenerBus.post(SparkListenerJobStart(job, Array(), properties)) runLocally(job) } else { - listenerBus.post(SparkListenerJobStart(job, properties)) idToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job + listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties)) submitStage(finalStage) } case JobCancelled(jobId) => - // Cancel a job: find all the running stages that are linked to this job, and cancel them. - running.filter(_.jobId == jobId).foreach { stage => - taskSched.cancelTasks(stage.id) - } + handleJobCancellation(jobId) case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) - .map(_.jobId) - if (!jobIds.isEmpty) { - running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => - taskSched.cancelTasks(stage.id) - } - } + val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val jobIds = activeInGroup.map(_.jobId) + jobIds.foreach { handleJobCancellation } case AllJobsCancelled => // Cancel all running jobs. - running.foreach { stage => - taskSched.cancelTasks(stage.id) - } + running.map(_.jobId).foreach { handleJobCancellation } + activeJobs.clear() // These should already be empty by this point, + idToActiveJob.clear() // but just in case we lost track of some jobs... case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -430,6 +577,18 @@ class DAGScheduler( handleExecutorLost(execId) case BeginEvent(task, taskInfo) => + for ( + job <- idToActiveJob.get(task.stageId); + stage <- stageIdToStage.get(task.stageId); + stageInfo <- stageToInfos.get(stage) + ) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + stageInfo.emittedTaskSizeWarning = true + logWarning(("Stage %d (%s) contains a task of very large " + + "size (%d KB). The maximum recommended task size is %d KB.").format( + task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN)) + } + } listenerBus.post(SparkListenerTaskStart(task, taskInfo)) case GettingResultEvent(task, taskInfo) => @@ -440,7 +599,12 @@ class DAGScheduler( handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => - abortStage(stageIdToStage(taskSet.stageId), reason) + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } + + case ResubmitFailedStages => + if (failed.size > 0) { + resubmitFailedStages() + } case StopDAGScheduler => // Cancel any active jobs @@ -502,6 +666,7 @@ class DAGScheduler( // Broken out for easier testing in DAGSchedulerSuite. protected def runLocallyWithinThread(job: ActiveJob) { + var jobResult: JobResult = JobSucceeded try { SparkEnv.set(env) val rdd = job.finalStage.rdd @@ -516,31 +681,59 @@ class DAGScheduler( } } catch { case e: Exception => + jobResult = JobFailed(e, Some(job.finalStage)) job.listener.jobFailed(e) + } finally { + val s = job.finalStage + stageIdToJobIds -= s.id // clean up data structures that were populated for a local job, + stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through + stageToInfos -= s // completion events or stage abort + jobIdToStageIds -= job.jobId + listenerBus.post(SparkListenerJobEnd(job, jobResult)) + } + } + + /** Finds the earliest-created active job that needs the stage */ + // TODO: Probably should actually find among the active jobs that need this + // stage the one with the highest priority (highest-priority pool, earliest created). + // That should take care of at least part of the priority inversion problem with + // cross-job dependencies. + private def activeJobForStage(stage: Stage): Option[Int] = { + if (stageIdToJobIds.contains(stage.id)) { + val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted + jobsThatUseStage.find(idToActiveJob.contains(_)) + } else { + None } } /** Submits stage, but first recursively submits any missing parents. */ private def submitStage(stage: Stage) { - logDebug("submitStage(" + stage + ")") - if (!waiting(stage) && !running(stage) && !failed(stage)) { - val missing = getMissingParentStages(stage).sortBy(_.id) - logDebug("missing: " + missing) - if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) + val jobId = activeJobForStage(stage) + if (jobId.isDefined) { + logDebug("submitStage(" + stage + ")") + if (!waiting(stage) && !running(stage) && !failed(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing == Nil) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage, jobId.get) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage } - waiting += stage } + } else { + abortStage(stage, "No active job for stage " + stage.id) } } + /** Called when stage's parents are available and we can now do its task. */ - private def submitMissingTasks(stage: Stage) { + private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -561,7 +754,7 @@ class DAGScheduler( } } - val properties = if (idToActiveJob.contains(stage.jobId)) { + val properties = if (idToActiveJob.contains(jobId)) { idToActiveJob(stage.jobId).properties } else { //this stage will be assigned to "default" pool @@ -643,6 +836,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + jobIdToStageIdsRemove(job.jobId) listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) @@ -679,7 +873,7 @@ class DAGScheduler( changeEpoch = true) } clearCacheLocs() - if (stage.outputLocs.count(_ == Nil) != 0) { + if (stage.outputLocs.exists(_ == Nil)) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + stage + " (" + stage.name + @@ -696,9 +890,12 @@ class DAGScheduler( } waiting --= newlyRunnable running ++= newlyRunnable - for (stage <- newlyRunnable.sortBy(_.id)) { + for { + stage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(stage) + } { logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage) + submitMissingTasks(stage, jobId) } } } @@ -782,21 +979,42 @@ class DAGScheduler( } } + private def handleJobCancellation(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to cancel unregistered job " + jobId) + } else { + val independentStages = removeJobAndIndependentStages(jobId) + independentStages.foreach { taskSched.cancelTasks } + val error = new SparkException("Job %d cancelled".format(jobId)) + val job = idToActiveJob(jobId) + job.listener.jobFailed(error) + jobIdToStageIds -= jobId + activeJobs -= job + idToActiveJob -= jobId + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) + } + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ private def abortStage(failedStage: Stage, reason: String) { + if (!stageIdToStage.contains(failedStage.id)) { + // Skip all the actions if the stage has been removed. + return + } val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) + jobIdToStageIdsRemove(job.jobId) idToActiveJob -= resultStage.jobId activeJobs -= job resultStageToJob -= resultStage + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -867,25 +1085,24 @@ class DAGScheduler( } private def cleanup(cleanupTime: Long) { - var sizeBefore = stageIdToStage.size - stageIdToStage.clearOldValues(cleanupTime) - logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) - - sizeBefore = shuffleToMapStage.size - shuffleToMapStage.clearOldValues(cleanupTime) - logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - - sizeBefore = pendingTasks.size - pendingTasks.clearOldValues(cleanupTime) - logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) - - sizeBefore = stageToInfos.size - stageToInfos.clearOldValues(cleanupTime) - logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) + Map( + "stageIdToStage" -> stageIdToStage, + "shuffleToMapStage" -> shuffleToMapStage, + "pendingTasks" -> pendingTasks, + "stageToInfos" -> stageToInfos, + "jobIdToStageIds" -> jobIdToStageIds, + "stageIdToJobIds" -> stageIdToJobIds). + foreach { case(s, t) => { + val sizeBefore = t.size + t.clearOldValues(cleanupTime) + logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) + }} } def stop() { - eventProcessActor ! StopDAGScheduler + if (eventProcessActor != null) { + eventProcessActor ! StopDAGScheduler + } metadataCleaner.cancel() taskSched.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 708d221d60..add1187613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -65,12 +65,13 @@ private[scheduler] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[scheduler] -case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 58f238d8cf..b026f860a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -31,6 +31,7 @@ private[spark] class JobWaiter[T]( private var finishedTasks = 0 // Is the job as a whole finished (succeeded or failed)? + @volatile private var _jobFinished = totalTasks == 0 def jobFinished = _jobFinished diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 1dc71a0428..0f2deb4bcb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask( var totalTime = 0L val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => writer.commit() + writer.close() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask( } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. - if (shuffle != null) { - shuffle.writers.foreach(_.revertPartialWrites()) + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } } throw e } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && shuffle.writers != null) { - shuffle.writers.foreach(_.close()) shuffle.releaseWriters(success) } // Execute the callbacks on task completion. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index a35081f7b1..3841b5616d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -37,7 +37,7 @@ case class SparkListenerTaskGettingResult( case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents -case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) +case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null) extends SparkListenerEvents case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 93599dfdc8..e9f2198a00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,4 +33,5 @@ class StageInfo( val name = stage.name val numPartitions = stage.numPartitions val numTasks = stage.numTasks + var emittedTaskSizeWarning = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4bae26f3a6..3c22edd524 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -46,6 +46,8 @@ class TaskInfo( var failed = false + var serializedSize: Int = 0 + def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7989e6ab32..5a279f970c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -583,7 +583,7 @@ private[spark] class TaskSetManager( runningTasks = runningTasksSet.size } - private def removeAllRunningTasks() { + private[scheduler] def removeAllRunningTasks() { val numRunningTasks = runningTasksSet.size runningTasksSet.clear() if (parent != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 3af02b42b2..6f637f4613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -200,6 +200,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac } override def stop() { + stopExecutors() try { if (driverActor != null) { val future = driverActor.ask(StopDriver)(timeout) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 1caa88e61f..ee762fe621 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -64,7 +64,6 @@ private[spark] class SimrSchedulerBackend( val conf = new Configuration() val fs = FileSystem.get(conf) fs.delete(new Path(driverFilePath), false) - super.stopExecutors() super.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index d9b941d694..6b5f1a5dc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -78,7 +78,7 @@ private[spark] class LocalActor( * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ -private[spark] class LocalBackend(scheduler: ClusterScheduler, private val totalCores: Int) +private[spark] class LocalBackend(scheduler: ClusterScheduler, val totalCores: Int) extends SchedulerBackend with ExecutorBackend { var localActor: ActorRef = null diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 469e68fed7..b4451fc7b8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -93,6 +93,8 @@ class DiskBlockObjectWriter( def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + override def close() = out.close() + override def flush() = out.flush() } private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 2f1b049ce4..e828e1d1c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -62,7 +62,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean + System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 632ff047d1..b5596dffd3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -101,7 +101,7 @@ class StorageLevel private( var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") - result += (if (deserialized) "Deserialized " else "Serialized") + result += (if (deserialized) "Deserialized " else "Serialized ") result += "%sx Replicated".format(replication) result } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 42e9be6e19..e596690bc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -76,7 +76,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { </tr> } - val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execTable = UIUtils.listingTable(execHead, execRow, execInfo) val content = @@ -99,16 +99,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) } - def getExecInfo(a: Int): Seq[String] = { - val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId - val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort - val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString - val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString - val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString - val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString - val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) - val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) - val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + def getExecInfo(statusId: Int): Seq[String] = { + val status = sc.getExecutorStorageStatus(statusId) + val execId = status.blockManagerId.executorId + val hostPort = status.blockManagerId.hostPort + val rddBlocks = status.blocks.size.toString + val memUsed = status.memUsed().toString + val maxMem = status.maxMem.toString + val diskUsed = status.diskUsed().toString + val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size + val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) + val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks Seq( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c1c7aa70e6..69f9446bab 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -60,11 +60,13 @@ private[spark] class StagePage(parent: JobProgressUI) { var activeTime = 0L listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished) + val summary = <div> <ul class="unstyled"> <li> - <strong>CPU time: </strong> + <strong>Total duration across all tasks: </strong> {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} </li> {if (hasShuffleRead) @@ -104,6 +106,33 @@ private[spark] class StagePage(parent: JobProgressUI) { val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( ms => parent.formatDuration(ms.toLong)) + val gettingResultTimes = validTasks.map{case (info, metrics, exception) => + if (info.gettingResultTime > 0) { + (info.finishTime - info.gettingResultTime).toDouble + } else { + 0.0 + } + } + val gettingResultQuantiles = ("Time spent fetching task results" +: + Distribution(gettingResultTimes).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelays = validTasks.map{case (info, metrics, exception) => + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime).toDouble + } else { + (info.finishTime - info.launchTime).toDouble + } + } + totalExecutionTime - metrics.get.executorRunTime + } + val schedulerDelayQuantiles = ("Scheduler delay" +: + Distribution(schedulerDelays).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + def getQuantileCols(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) @@ -119,7 +148,10 @@ private[spark] class StagePage(parent: JobProgressUI) { } val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) - val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + val listings: Seq[Seq[String]] = Seq( + serviceQuantiles, + gettingResultQuantiles, + schedulerDelayQuantiles, if (hasShuffleRead) shuffleReadQuantiles else Nil, if (hasShuffleWrite) shuffleWriteQuantiles else Nil) @@ -133,7 +165,7 @@ private[spark] class StagePage(parent: JobProgressUI) { summary ++ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ - <h4>Tasks</h4> ++ taskTable; + <h4>Tasks</h4> ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) } @@ -152,21 +184,18 @@ private[spark] class StagePage(parent: JobProgressUI) { else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) - var shuffleReadSortable: String = "" - var shuffleReadReadable: String = "" - if (shuffleRead) { - shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString() - shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("") - } + val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead} + val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") + val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("") - var shuffleWriteSortable: String = "" - var shuffleWriteReadable: String = "" - if (shuffleWrite) { - shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString() - shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("") - } + val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten} + val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("") + + val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime} + val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") + val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms => + if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("") <tr> <td>{info.index}</td> @@ -187,8 +216,8 @@ private[spark] class StagePage(parent: JobProgressUI) { </td> }} {if (shuffleWrite) { - <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")} + <td sorttable_customkey={writeTimeSortable}> + {writeTimeReadable} </td> <td sorttable_customkey={shuffleWriteSortable}> {shuffleWriteReadable} diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index f60deafc6f..8bb4ee3bfa 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 + private var growThreshold = LOAD_FACTOR * capacity // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { return data(2 * pos + 1).asInstanceOf[V] } else if (curKey.eq(null)) { return null.asInstanceOf[V] @@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi haveNullValue = true return } - val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef]) - if (isNewEntry) { - incrementSize() + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = k + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + incrementSize() // Since we added a new key + return + } else if (k.eq(curKey) || k.equals(curKey)) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } } } @@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] return newValue @@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi /** Increase table size by 1, rehashing if necessary */ private def incrementSize() { curSize += 1 - if (curSize > LOAD_FACTOR * capacity) { + if (curSize > growThreshold) { growTable() } } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ private def rehash(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } - - /** - * Put an entry into a table represented by data, returning true if - * this increases the size of the table or false otherwise. Assumes - * that "data" has at least one empty slot. - */ - private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = { - val mask = (data.length / 2) - 1 - var pos = rehash(key.hashCode) & mask - var i = 1 - while (true) { - val curKey = data(2 * pos) - if (curKey.eq(null)) { - data(2 * pos) = key - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return true - } else if (curKey.eq(key) || curKey == key) { - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return false - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - return false // Never reached but needed to keep compiler happy + it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) } /** Double the table's size and re-hash everything */ @@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi throw new Exception("Can't make capacity bigger than 2^29 elements") } val newData = new Array[AnyRef](2 * newCapacity) - var pos = 0 - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - putInto(newData, data(2 * pos), data(2 * pos + 1)) + val newMask = newCapacity - 1 + // Insert all our old values into the new array. Note that because our old keys are + // unique, there's no need to check for equality here when we insert. + var oldPos = 0 + while (oldPos < capacity) { + if (!data(2 * oldPos).eq(null)) { + val key = data(2 * oldPos) + val value = data(2 * oldPos + 1) + var newPos = rehash(key.hashCode) & newMask + var i = 1 + var keepGoing = true + while (keepGoing) { + val curKey = newData(2 * newPos) + if (curKey.eq(null)) { + newData(2 * newPos) = key + newData(2 * newPos + 1) = value + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } data = newData capacity = newCapacity - mask = newCapacity - 1 + mask = newMask + growThreshold = LOAD_FACTOR * newCapacity } private def nextPowerOf2(n: Int): Int = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fe932d8ede..a79e64e810 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -823,4 +823,28 @@ private[spark] object Utils extends Logging { return System.getProperties().clone() .asInstanceOf[java.util.Properties].toMap[String, String] } + + /** + * Method executed for repeating a task for side effects. + * Unlike a for comprehension, it permits JVM JIT optimization + */ + def times(numIters: Int)(f: => Unit): Unit = { + var i = 0 + while (i < numIters) { + f + i += 1 + } + } + + /** + * Timing method based on iterations that permit JVM JIT optimization. + * @param numIters number of iterations + * @param f function to be executed + */ + def timeIt(numIters: Int)(f: => Unit): Long = { + val start = System.currentTimeMillis + times(numIters)(f) + System.currentTimeMillis - start + } + } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala new file mode 100644 index 0000000000..e9907e6c85 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.{Random => JavaRandom} +import org.apache.spark.util.Utils.timeIt + +/** + * This class implements a XORShift random number generator algorithm + * Source: + * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. + * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a> + * This implementation is approximately 3.5 times faster than + * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG + * for each thread. + */ +private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { + + def this() = this(System.nanoTime) + + private var seed = init + + // we need to just override next - this will be called by nextInt, nextDouble, + // nextGaussian, nextLong, etc. + override protected def next(bits: Int): Int = { + var nextSeed = seed ^ (seed << 21) + nextSeed ^= (nextSeed >>> 35) + nextSeed ^= (nextSeed << 4) + seed = nextSeed + (nextSeed & ((1L << bits) -1)).asInstanceOf[Int] + } +} + +/** Contains benchmark method and main method to run benchmark of the RNG */ +private[spark] object XORShiftRandom { + + /** + * Main method for running benchmark + * @param args takes one argument - the number of random numbers to generate + */ + def main(args: Array[String]): Unit = { + if (args.length != 1) { + println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") + println("Usage: XORShiftRandom number_of_random_numbers_to_generate") + System.exit(1) + } + println(benchmark(args(0).toInt)) + } + + /** + * @param numIters Number of random numbers to generate while running the benchmark + * @return Map of execution times for {@link java.util.Random java.util.Random} + * and XORShift + */ + def benchmark(numIters: Int) = { + + val seed = 1L + val million = 1e6.toInt + val javaRand = new JavaRandom(seed) + val xorRand = new XORShiftRandom(seed) + + // this is just to warm up the JIT - we're not timing anything + timeIt(1e6.toInt) { + javaRand.nextInt() + xorRand.nextInt() + } + + val iters = timeIt(numIters)(_) + + /* Return results as a map instead of just printing to screen + in case the user wants to do something with them */ + Map("javaTime" -> iters {javaRand.nextInt()}, + "xorTime" -> iters {xorRand.nextInt()}) + + } + +}
\ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4592e4f939..40986e3731 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -79,6 +79,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 + protected var _growThreshold = (loadFactor * _capacity).toInt protected var _bitset = new BitSet(_capacity) @@ -115,7 +116,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * @return The position where the key is placed, plus the highest order bit is set if the key * exists previously. */ - def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + def addWithoutResize(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + // This is a new key. + _data(pos) = k + _bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (_data(pos) == k) { + // Found an existing key. + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } /** * Rehash the set if it is overloaded. @@ -126,7 +149,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_size > loadFactor * _capacity) { + if (_size > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -161,37 +184,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) /** - * Put an entry into the set. Return the position where the key is placed. In addition, the - * highest bit in the returned position is set if the key exists prior to this put. - * - * This function assumes the data array has at least one empty slot. - */ - private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { - val mask = data.length - 1 - var pos = hashcode(hasher.hash(k)) & mask - var i = 1 - while (true) { - if (!bitset.get(pos)) { - // This is a new key. - data(pos) = k - bitset.set(pos) - _size += 1 - return pos | NONEXISTENCE_MASK - } else if (data(pos) == k) { - // Found an existing key. - return pos - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS - } - - /** * Double the table's size and re-hash everything. We are not really using k, but it is declared * so Scala compiler can specialize this method (which leads to calling the specialized version * of putInto). @@ -204,34 +196,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 - require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") - allocateFunc(newCapacity) - val newData = new Array[T](newCapacity) val newBitset = new BitSet(newCapacity) - var pos = 0 - _size = 0 - while (pos < _capacity) { - if (_bitset.get(pos)) { - val newPos = putInto(newBitset, newData, _data(pos)) - moveFunc(pos, newPos & POSITION_MASK) + val newData = new Array[T](newCapacity) + val newMask = newCapacity - 1 + + var oldPos = 0 + while (oldPos < capacity) { + if (_bitset.get(oldPos)) { + val key = _data(oldPos) + var newPos = hashcode(hasher.hash(key)) & newMask + var i = 1 + var keepGoing = true + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!newBitset.get(newPos)) { + // Inserting the key at newPos + newData(newPos) = key + newBitset.set(newPos) + moveFunc(oldPos, newPos) + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } + _bitset = newBitset _data = newData _capacity = newCapacity - _mask = newCapacity - 1 + _mask = newMask + _growThreshold = (loadFactor * newCapacity).toInt } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def hashcode(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } + private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index 369519c559..20554f0aab 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -17,35 +17,51 @@ package org.apache.spark.util.collection -/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */ +/** + * An append-only, non-threadsafe, array-backed vector that is optimized for primitive types. + */ private[spark] class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) { - private var numElements = 0 - private var array: Array[V] = _ + private var _numElements = 0 + private var _array: Array[V] = _ // NB: This must be separate from the declaration, otherwise the specialized parent class - // will get its own array with the same initial size. TODO: Figure out why... - array = new Array[V](initialSize) + // will get its own array with the same initial size. + _array = new Array[V](initialSize) def apply(index: Int): V = { - require(index < numElements) - array(index) + require(index < _numElements) + _array(index) } def +=(value: V) { - if (numElements == array.length) { resize(array.length * 2) } - array(numElements) = value - numElements += 1 + if (_numElements == _array.length) { + resize(_array.length * 2) + } + _array(_numElements) = value + _numElements += 1 } - def length = numElements + def capacity: Int = _array.length + + def length: Int = _numElements + + def size: Int = _numElements + + /** Gets the underlying array backing this vector. */ + def array: Array[V] = _array - def getUnderlyingArray = array + /** Trims this vector so that the capacity is equal to the size. */ + def trim(): PrimitiveVector[V] = resize(size) /** Resizes the array, dropping elements if the total length decreases. */ - def resize(newLength: Int) { + def resize(newLength: Int): PrimitiveVector[V] = { val newArray = new Array[V](newLength) - array.copyToArray(newArray) - array = newArray + _array.copyToArray(newArray) + _array = newArray + if (newLength < _numElements) { + _numElements = newLength + } + this } } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f26c44d3e7..d2226aa5a5 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithContextRDD(r, - (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 352036f182..4234f6eac7 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable { } @Test + public void javaDoubleRDDHistoGram() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + // Test using generated buckets + Tuple2<double[], long[]> results = rdd.histogram(2); + double[] expected_buckets = {1.0, 2.5, 4.0}; + long[] expected_counts = {2, 2}; + Assert.assertArrayEquals(expected_buckets, results._1, 0.1); + Assert.assertArrayEquals(expected_counts, results._2); + // Test with provided buckets + long[] histogram = rdd.histogram(expected_buckets); + Assert.assertArrayEquals(expected_counts, histogram); + } + + @Test public void map() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() { diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index d8a0e983b2..1121e06e2e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -114,7 +114,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf // Once A is cancelled, job B should finish fairly quickly. assert(jobB.get() === 100) } - +/* test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // sem2: make sure the first stage is not finished until cancel is issued @@ -148,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf intercept[SparkException] { f1.get() } intercept[SparkException] { f2.get() } } - + */ def testCount() { // Cancel before launching any tasks { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 459e257d79..8dd5786da6 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala deleted file mode 100644 index 21f16ef2c6..0000000000 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, PartitionPruningRDD} - - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - - test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) - } - def compute(split: Partition, context: TaskContext) = {Iterator()} - } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) - assert(prunedRDD.partitions.length == 1) - } -} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7d938917f2..1374d01774 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext { .filter(_ >= 0.0) // Run the partitions, including the consecutive empty ones, through StatCounter - val stats: StatCounter = rdd.stats(); - assert(abs(6.0 - stats.sum) < 0.01); - assert(abs(6.0/2 - rdd.mean) < 0.01); - assert(abs(1.0 - rdd.variance) < 0.01); - assert(abs(1.0 - rdd.stdev) < 0.01); + val stats: StatCounter = rdd.stats() + assert(abs(6.0 - stats.sum) < 0.01) + assert(abs(6.0/2 - rdd.mean) < 0.01) + assert(abs(1.0 - rdd.variance) < 0.01) + assert(abs(1.0 - rdd.stdev) < 0.01) // Add other tests here for classes that should be able to handle empty partitions correctly } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala new file mode 100644 index 0000000000..d4a7a11515 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.scheduler.{ClusterScheduler, TaskScheduler} +import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import org.apache.spark.scheduler.local.LocalBackend + +class SparkContextSchedulerCreationSuite + extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + + def createTaskScheduler(master: String): ClusterScheduler = { + // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the + // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. + sc = new SparkContext("local", "test") + val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) + val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + sched.asInstanceOf[ClusterScheduler] + } + + test("bad-master") { + val e = intercept[SparkException] { + createTaskScheduler("localhost:1234") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("local") { + val sched = createTaskScheduler("local") + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 1) + case _ => fail() + } + } + + test("local-n") { + val sched = createTaskScheduler("local[5]") + assert(sched.maxTaskFailures === 0) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 5) + case _ => fail() + } + } + + test("local-n-failures") { + val sched = createTaskScheduler("local[4, 2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === 4) + case _ => fail() + } + } + + test("simr") { + createTaskScheduler("simr://uri").backend match { + case s: SimrSchedulerBackend => // OK + case _ => fail() + } + } + + test("local-cluster") { + createTaskScheduler("local-cluster[3, 14, 512]").backend match { + case s: SparkDeploySchedulerBackend => // OK + case _ => fail() + } + } + + def testYarn(master: String, expectedClassName: String) { + try { + val sched = createTaskScheduler(master) + assert(sched.getClass === Class.forName(expectedClassName)) + } catch { + case e: SparkException => + assert(e.getMessage.contains("YARN mode not available")) + logWarning("YARN not available, could not test actual YARN scheduler creation") + case e: Throwable => fail(e) + } + } + + test("yarn-standalone") { + testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + } + + test("yarn-client") { + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + } + + def testMesos(master: String, expectedClass: Class[_]) { + try { + val sched = createTaskScheduler(master) + assert(sched.backend.getClass === expectedClass) + } catch { + case e: UnsatisfiedLinkError => + assert(e.getMessage.contains("no mesos in")) + logWarning("Mesos not available, could not test actual Mesos scheduler creation") + case e: Throwable => fail(e) + } + } + + test("mesos fine-grained") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend]) + } + + test("mesos coarse-grained") { + System.setProperty("spark.mesos.coarse", "true") + testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend]) + } + + test("mesos with zookeeper") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend]) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index da032b17d9..0d4c10db8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore +import scala.concurrent.{Await, TimeoutException} +import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -173,4 +175,28 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts sem.acquire(2) } } + + /** + * Awaiting FutureAction results + */ + test("FutureAction result, infinite wait") { + val f = sc.parallelize(1 to 100, 4) + .countAsync() + assert(Await.result(f, Duration.Inf) === 100) + } + + test("FutureAction result, finite wait") { + val f = sc.parallelize(1 to 100, 4) + .countAsync() + assert(Await.result(f, Duration(30, "seconds")) === 100) + } + + test("FutureAction result, timeout") { + val f = sc.parallelize(1 to 100, 4) + .mapPartitions(itr => { Thread.sleep(20); itr }) + .countAsync() + intercept[TimeoutException] { + Await.result(f, Duration(20, "milliseconds")) + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala new file mode 100644 index 0000000000..7f50a5a47c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.math.abs +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark._ + +class DoubleRDDSuite extends FunSuite with SharedSparkContext { + // Verify tests on the histogram functionality. We test with both evenly + // and non-evenly spaced buckets as the bucket lookup function changes. + test("WorksOnEmpty") { + // Make sure that it works on an empty input + val rdd: RDD[Double] = sc.parallelize(Seq()) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithOneBucket") { + // Verify that if all of the elements are out of range the counts are zero + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithOneBucket") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithOneBucketExactMatch") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(1.0, 4.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithTwoBuckets") { + // Verify that out of range works with two buckets + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0, 0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { + // Verify that out of range works with two un even buckets + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 4.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(0, 0) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksInRangeWithTwoBuckets") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithTwoBucketsAndNaN") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 3) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithFourUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithUnevenBucketsAndNaN") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket and an inifity + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 4) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithInfiniteBuckets") { + // Verify that out of range works with two buckets + val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) + val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(1, 1) + assert(histogramResults === expectedHistogramResults) + } + // Test the failure mode with an invalid bucket array + test("ThrowsExceptionOnInvalidBucketArray") { + val rdd = sc.parallelize(Seq(1.0)) + // Empty array + intercept[IllegalArgumentException] { + val buckets = Array.empty[Double] + val result = rdd.histogram(buckets) + } + // Single element array + intercept[IllegalArgumentException] { + val buckets = Array(1.0) + val result = rdd.histogram(buckets) + } + } + + // Test automatic histogram function + test("WorksWithoutBucketsBasic") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicSingleElement") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(1) + val expectedHistogramBuckets = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicNoRange") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 1, 1, 1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsBasicTwo") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val (histogramBuckets, histogramResults) = rdd.histogram(2) + val expectedHistogramResults = Array(2, 2) + val expectedHistogramBuckets = Array(1.0, 2.5, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithMoreRequestedThanElements") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2)) + val (histogramBuckets, histogramResults) = rdd.histogram(10) + val expectedHistogramResults = + Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) + val expectedHistogramBuckets = + Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + // Test the failure mode with an invalid RDD + test("ThrowsExceptionOnInvalidRDDs") { + // infinity + intercept[UnsupportedOperationException] { + val rdd = sc.parallelize(Seq(1, 1.0/0.0)) + val result = rdd.histogram(1) + } + // NaN + intercept[UnsupportedOperationException] { + val rdd = sc.parallelize(Seq(1, Double.NaN)) + val result = rdd.histogram(1) + } + // Empty + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq()) + val result = rdd.histogram(1) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..53a7b7c44d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.scalatest.FunSuite +import org.apache.spark.{TaskContext, Partition, SharedSparkContext} + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + + test("Pruned Partitions inherit locality prefs correctly") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 1), + new TestPartition(1, 1), + new TestPartition(2, 1)) + } + + def compute(split: Partition, context: TaskContext) = { + Iterator() + } + } + val prunedRDD = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + assert(prunedRDD.partitions.length == 1) + val p = prunedRDD.partitions(0) + assert(p.index == 0) + assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) + } + + + test("Pruned Partitions can be unioned ") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 4), + new TestPartition(1, 5), + new TestPartition(2, 6)) + } + + def compute(split: Partition, context: TaskContext) = { + List(split.asInstanceOf[TestPartition].testValue).iterator + } + } + val prunedRDD1 = PartitionPruningRDD.create(rdd, { + x => if (x == 0) true else false + }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + + val merged = prunedRDD1 ++ prunedRDD2 + assert(merged.count() == 2) + val take = merged.take(2) + assert(take.apply(0) == 4) + assert(take.apply(1) == 6) + } +} + +class TestPartition(i: Int, value: Int) extends Partition with Serializable { + def index = i + + def testValue = this.value + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a4d41ebbff..706d84a58b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(rdd, Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("local job") { @@ -219,6 +220,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial job w/ dependency") { @@ -227,6 +229,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cache location preferences w/ dependency") { @@ -239,12 +242,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted: some failure") + assertDataStructuresEmpty } test("run trivial shuffle") { @@ -260,6 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial shuffle with fetch failure") { @@ -285,6 +291,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("ignore late map task completions") { @@ -313,6 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("run trivial shuffle with out-of-band failure and retry") { @@ -329,15 +337,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("recursive shuffle failures") { + // have hostC complete the resubmitted task + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + + test("recursive shuffle failures") { val shuffleOneRdd = makeRdd(2, Nil) val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) @@ -363,6 +372,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cached post-shuffle") { @@ -394,6 +404,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } /** @@ -413,4 +424,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345, 0) + private def assertDataStructuresEmpty = { + assert(scheduler.pendingTasks.isEmpty) + assert(scheduler.activeJobs.isEmpty) + assert(scheduler.failed.isEmpty) + assert(scheduler.idToActiveJob.isEmpty) + assert(scheduler.jobIdToStageIds.isEmpty) + assert(scheduler.stageIdToJobIds.isEmpty) + assert(scheduler.stageIdToStage.isEmpty) + assert(scheduler.stageToInfos.isEmpty) + assert(scheduler.resultStageToJob.isEmpty) + assert(scheduler.running.isEmpty) + assert(scheduler.shuffleToMapStage.isEmpty) + assert(scheduler.waiting.isEmpty) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 984881861c..002368ff55 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + val WAIT_TIMEOUT_MILLIS = 10000 test("inner method") { sc = new SparkContext("local", "joblogger") @@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) joblogger.getLogDir should be ("/tmp/spark-%s".format(user)) @@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() - + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + joblogger.onJobStartCount should be (1) joblogger.onJobEndCount should be (1) joblogger.onTaskEndCount should be (8) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 0b9056344c..ef4c4c0f14 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -5,9 +5,9 @@ import java.io.{FileWriter, File} import scala.collection.mutable import com.google.common.io.Files -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { +class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { val rootDir0 = Files.createTempDir() rootDir0.deleteOnExit() @@ -16,6 +16,12 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val rootDirs = rootDir0.getName + "," + rootDir1.getName println("Created root dirs: " + rootDirs) + // This suite focuses primarily on consolidation features, + // so we coerce consolidation if not already enabled. + val consolidateProp = "spark.shuffle.consolidateFiles" + val oldConsolidate = Option(System.getProperty(consolidateProp)) + System.setProperty(consolidateProp, "true") + val shuffleBlockManager = new ShuffleBlockManager(null) { var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -23,6 +29,10 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { var diskBlockManager: DiskBlockManager = _ + override def afterAll() { + oldConsolidate.map(c => System.setProperty(consolidateProp, c)) + } + override def beforeEach() { diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) shuffleBlockManager.idToSegmentMap.clear() diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala new file mode 100644 index 0000000000..b78367b6ca --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.Random +import org.scalatest.FlatSpec +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.Utils.times + +class XORShiftRandomSuite extends FunSuite with ShouldMatchers { + + def fixture = new { + val seed = 1L + val xorRand = new XORShiftRandom(seed) + val hundMil = 1e8.toInt + } + + /* + * This test is based on a chi-squared test for randomness. The values are hard-coded + * so as not to create Spark's dependency on apache.commons.math3 just to call one + * method for calculating the exact p-value for a given number of random numbers + * and bins. In case one would want to move to a full-fledged test based on + * apache.commons.math3, the relevant class is here: + * org.apache.commons.math3.stat.inference.ChiSquareTest + */ + test ("XORShift generates valid random numbers") { + + val f = fixture + + val numBins = 10 + // create 10 bins + val bins = Array.fill(numBins)(0) + + // populate bins based on modulus of the random number + times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} + + /* since the seed is deterministic, until the algorithm is changed, we know the result will be + * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) + * significance level. However, should the RNG implementation change, the test should still + * pass at the same significance level. The chi-squared test done in R gave the following + * results: + * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699)) + * Chi-squared test for given probabilities + * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, + * 10002286, 9998699) + * X-squared = 11.975, df = 9, p-value = 0.2147 + * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million + * random numbers + * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared + * is greater than or equal to that number. + */ + val binSize = f.hundMil/numBins + val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum + xSquared should be < (16.9196) + + } + +}
\ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ca3f684668..63e874fed3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -2,8 +2,20 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite - -class OpenHashMapSuite extends FunSuite { +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator + +class OpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive value (int)") { + val capacity = 1024 + val map = new OpenHashMap[String, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset. + val expectedSize = capacity * (64 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new OpenHashMap[String, Int](1) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 4e11e8a628..4768a1e60b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -1,9 +1,27 @@ package org.apache.spark.util.collection import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite { + +class OpenHashSetSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive int") { + val loadFactor = 0.7 + val set = new OpenHashSet[Int](64, loadFactor) + for (i <- 0 until 1024) { + set.add(i) + } + assert(set.size === 1024) + assert(set.capacity > 1024) + val actualSize = SizeEstimator.estimate(set) + // 32 bits for the ints + 1 bit for the bitset + val expectedSize = set.capacity * (32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("primitive int") { val set = new OpenHashSet[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index dfd6aed2c4..2220b4f0d5 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -2,8 +2,20 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashSetSuite extends FunSuite { +class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive key, value (int, int)") { + val capacity = 1024 + val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 32 bit for keys, 32 bit for values, and 1 bit for the bitset. + val expectedSize = capacity * (32 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala new file mode 100644 index 0000000000..970dade628 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + +import org.apache.spark.util.SizeEstimator + +class PrimitiveVectorSuite extends FunSuite { + + test("primitive value") { + val vector = new PrimitiveVector[Int] + + for (i <- 0 until 1000) { + vector += i + assert(vector(i) === i) + } + + assert(vector.size === 1000) + assert(vector.size == vector.length) + intercept[IllegalArgumentException] { + vector(1000) + } + + for (i <- 0 until 1000) { + assert(vector(i) == i) + } + } + + test("non-primitive value") { + val vector = new PrimitiveVector[String] + + for (i <- 0 until 1000) { + vector += i.toString + assert(vector(i) === i.toString) + } + + assert(vector.size === 1000) + assert(vector.size == vector.length) + intercept[IllegalArgumentException] { + vector(1000) + } + + for (i <- 0 until 1000) { + assert(vector(i) == i.toString) + } + } + + test("ideal growth") { + val vector = new PrimitiveVector[Long](initialSize = 1) + vector += 1 + for (i <- 1 until 1024) { + vector += i + assert(vector.size === i + 1) + assert(vector.capacity === Integer.highestOneBit(i) * 2) + } + assert(vector.capacity === 1024) + vector += 1024 + assert(vector.capacity === 2048) + } + + test("ideal size") { + val vector = new PrimitiveVector[Long](8192) + for (i <- 0 until 8192) { + vector += i + } + assert(vector.size === 8192) + assert(vector.capacity === 8192) + val actualSize = SizeEstimator.estimate(vector) + val expectedSize = 8192 * 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array. + assert(actualSize < expectedSize * 1.1) + } + + test("resizing") { + val vector = new PrimitiveVector[Long] + for (i <- 0 until 4097) { + vector += i + } + assert(vector.size === 4097) + assert(vector.capacity === 8192) + vector.trim() + assert(vector.size === 4097) + assert(vector.capacity === 4097) + vector.resize(5000) + assert(vector.size === 4097) + assert(vector.capacity === 5000) + vector.resize(4000) + assert(vector.size === 4000) + assert(vector.capacity === 4000) + vector.resize(5000) + assert(vector.size === 4000) + assert(vector.capacity === 5000) + for (i <- 0 until 4000) { + assert(vector(i) == i) + } + intercept[IllegalArgumentException] { + vector(4000) + } + } +} |